@@ -142,6 +142,7 @@ def __call__(
142142 latents : Optional [torch .FloatTensor ] = None ,
143143 output_type : Optional [str ] = "pil" ,
144144 return_dict : bool = True ,
145+ compile_unet : bool = True ,
145146 ** kwargs ,
146147 ):
147148 r"""
@@ -179,6 +180,8 @@ def __call__(
179180 return_dict (`bool`, *optional*, defaults to `True`):
180181 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
181182 plain tuple.
183+ compile_unet (`bool`, *optional*, defaults to `True`):
184+ Whether or not to compile the UNet as an `nn.Graph` before denoising.
182185
183186 Returns:
184187 [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -278,15 +281,16 @@ def __call__(
278281
279282 compilation_start = timer ()
280283 compilation_time = 0
281- if self .unet_compiled == False :
282- print ("[oneflow]" , "compiling unet beforehand to make sure the progress bar is more accurate" )
283- i , t = list (enumerate (self .scheduler .timesteps ))[0 ]
284- latent_model_input = torch .cat ([latents ] * 2 ) if do_classifier_free_guidance else latents
285- self .unet_graph ._compile (latent_model_input , t , text_embeddings )
286- self .unet_compiled = True
287- self .unet_graph (latent_model_input , t , text_embeddings ) # warmup
288- compilation_time = timer () - compilation_start
289- print ("[oneflow]" , "[elapsed(s)]" , "[unet compilation]" , compilation_time )
284+ if compile_unet :
285+ if self .unet_compiled == False :
286+ print ("[oneflow]" , "compiling unet beforehand to make sure the progress bar is more accurate" )
287+ i , t = list (enumerate (self .scheduler .timesteps ))[0 ]
288+ latent_model_input = torch .cat ([latents ] * 2 ) if do_classifier_free_guidance else latents
289+ self .unet_graph ._compile (latent_model_input , t , text_embeddings )
290+ self .unet_compiled = True
291+ self .unet_graph (latent_model_input , t , text_embeddings ) # warmup
292+ compilation_time = timer () - compilation_start
293+ print ("[oneflow]" , "[elapsed(s)]" , "[unet compilation]" , compilation_time )
290294
291295 for i , t in enumerate (self .progress_bar (self .scheduler .timesteps )):
292296 torch ._oneflow_internal .profiler .RangePush (f"denoise-{ i } " )
@@ -298,9 +302,12 @@ def __call__(
298302 latent_model_input = latent_model_input / ((sigma ** 2 + 1 ) ** 0.5 )
299303
300304 # predict the noise residual
301- torch ._oneflow_internal .profiler .RangePush (f"denoise-{ i } -unet-graph" )
302- noise_pred = self .unet_graph (latent_model_input , t , text_embeddings )
303- torch ._oneflow_internal .profiler .RangePop ()
305+ if compile_unet :
306+ torch ._oneflow_internal .profiler .RangePush (f"denoise-{ i } -unet-graph" )
307+ noise_pred = self .unet_graph (latent_model_input , t , text_embeddings )
308+ torch ._oneflow_internal .profiler .RangePop ()
309+ else :
310+ noise_pred = self .unet (latent_model_input , t , encoder_hidden_states = text_embeddings ).sample
304311
305312 # perform guidance
306313 if do_classifier_free_guidance :
0 commit comments