@@ -32,6 +32,7 @@ def __init__(self, unet):
3232 self .config .enable_cudnn_conv_heuristic_search_algo (False )
3333
3434 def build (self , latent_model_input , t , text_embeddings ):
35+ text_embeddings = torch ._C .amp_white_identity (text_embeddings )
3536 return self .unet (latent_model_input , t , encoder_hidden_states = text_embeddings ).sample
3637
3738class OneFlowStableDiffusionPipeline (DiffusionPipeline ):
@@ -98,6 +99,8 @@ def __init__(
9899 safety_checker = safety_checker ,
99100 feature_extractor = feature_extractor ,
100101 )
102+ self .unet_graph = UNetGraph (self .unet )
103+ self .unet_compiled = False
101104
102105 def enable_attention_slicing (self , slice_size : Optional [Union [str , int ]] = "auto" ):
103106 r"""
@@ -185,6 +188,8 @@ def __call__(
185188 (nsfw) content, according to the `safety_checker`.
186189 """
187190
191+ from timeit import default_timer as timer
192+ start = timer ()
188193 if "torch_device" in kwargs :
189194 device = kwargs .pop ("torch_device" )
190195 warnings .warn (
@@ -271,12 +276,17 @@ def __call__(
271276 if accepts_eta :
272277 extra_step_kwargs ["eta" ] = eta
273278
274- unet_graph = UNetGraph (self .unet )
275-
276- print ("[oneflow]" , "compiling unet beforehand to make sure the progress bar is more accurate" )
277- i , t = list (enumerate (self .scheduler .timesteps ))[0 ]
278- latent_model_input = torch .cat ([latents ] * 2 ) if do_classifier_free_guidance else latents
279- unet_graph ._compile (latent_model_input , t , text_embeddings )
279+ compilation_start = timer ()
280+ compilation_time = 0
281+ if self .unet_compiled == False :
282+ print ("[oneflow]" , "compiling unet beforehand to make sure the progress bar is more accurate" )
283+ i , t = list (enumerate (self .scheduler .timesteps ))[0 ]
284+ latent_model_input = torch .cat ([latents ] * 2 ) if do_classifier_free_guidance else latents
285+ self .unet_graph ._compile (latent_model_input , t , text_embeddings )
286+ self .unet_compiled = True
287+ self .unet_graph (latent_model_input , t , text_embeddings ) # warmup
288+ compilation_time = timer () - compilation_start
289+ print ("[oneflow]" , "[elapsed(s)]" , "[unet compilation]" , compilation_time )
280290
281291 for i , t in enumerate (self .progress_bar (self .scheduler .timesteps )):
282292 torch ._oneflow_internal .profiler .RangePush (f"denoise-{ i } " )
@@ -289,7 +299,7 @@ def __call__(
289299
290300 # predict the noise residual
291301 torch ._oneflow_internal .profiler .RangePush (f"denoise-{ i } -unet-graph" )
292- noise_pred = unet_graph (latent_model_input , t , text_embeddings )
302+ noise_pred = self . unet_graph (latent_model_input , t , text_embeddings )
293303 torch ._oneflow_internal .profiler .RangePop ()
294304
295305 # perform guidance
@@ -310,6 +320,8 @@ def __call__(
310320 if isinstance (latents , np .ndarray ):
311321 latents = torch .from_numpy (latents )
312322 image = self .vae .decode (latents ).sample
323+ print ("[oneflow]" , "[elapsed(s)]" , "[image]" , timer () - start - compilation_time )
324+ post_process_start = timer ()
313325
314326 image = (image / 2 + 0.5 ).clamp (0 , 1 )
315327 image = image .cpu ().permute (0 , 2 , 3 , 1 ).numpy ()
@@ -328,4 +340,6 @@ def __call__(
328340 return (image , has_nsfw_concept )
329341 import torch as og_torch
330342 assert og_torch .cuda .is_initialized () is False
343+
344+ print ("[oneflow]" , "[elapsed(s)]" , "[post-process]" , timer () - post_process_start )
331345 return StableDiffusionPipelineOutput (images = image , nsfw_content_detected = has_nsfw_concept )
0 commit comments