@@ -791,6 +791,46 @@ def disable_freeu(self):
791791 """Disables the FreeU mechanism if enabled."""
792792 self .unet .disable_freeu ()
793793
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
# NOTE(review): the docstring and the robustness guards below diverge from the copy source;
# mirror them there (or drop the marker) so the copy check stays consistent.
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    Build sinusoidal embeddings for the guidance scale values in `w`.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance scale values to embed. Each value is multiplied by 1000 before the
            sine/cosine features are computed.
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. The first
        `embedding_dim // 2` columns are sines, the next `embedding_dim // 2` are cosines, and an odd
        `embedding_dim` is right-padded with a single zero column.

    Raises:
        ValueError: If `w` is not one-dimensional.
    """
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if w.ndim != 1:
        raise ValueError(f"`w` must be a 1-D tensor, but got a tensor with shape {tuple(w.shape)}.")
    w = w * 1000.0

    half_dim = embedding_dim // 2
    # Clamp the divisor: with half_dim == 1 the unclamped `half_dim - 1` is zero, which turns the
    # frequency into exp(0 * -inf) = NaN. The lowest frequency is always exp(0) = 1, so clamping
    # keeps the result well defined without changing any output for half_dim >= 2.
    log_step = torch.log(torch.tensor(10000.0)) / max(half_dim - 1, 1)
    # Build the frequency table on the same device as `w` so the broadcasted product below also
    # works when `w` does not live on the CPU.
    freqs = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -log_step)
    emb = w.to(dtype)[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    # Internal invariant, not input validation.
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
822+
@property
def guidance_scale(self):
    """Guidance scale stored on the pipeline as `_guidance_scale` (assigned inside `__call__`)."""
    current_scale = self._guidance_scale
    return current_scale
826+
# `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
    """
    Whether classifier-free guidance is active.

    True only when the stored guidance scale is greater than 1 and the UNet config has no
    `time_cond_proj_dim` set.
    """
    scale_above_one = self._guidance_scale > 1
    if not scale_above_one:
        # Short-circuit exactly like `and`: hand back the falsy comparison result
        # without touching the UNet config.
        return scale_above_one
    return self.unet.config.time_cond_proj_dim is None
833+
794834 @torch .no_grad ()
795835 @replace_example_docstring (EXAMPLE_DOC_STRING )
796836 def __call__ (
@@ -986,6 +1026,8 @@ def __call__(
9861026 control_guidance_end ,
9871027 )
9881028
1029+ self ._guidance_scale = guidance_scale
1030+
9891031 # 2. Define call parameters
9901032 if prompt is not None and isinstance (prompt , str ):
9911033 batch_size = 1
@@ -995,10 +1037,6 @@ def __call__(
9951037 batch_size = prompt_embeds .shape [0 ]
9961038
9971039 device = self ._execution_device
998- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
999- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1000- # corresponds to doing no classifier free guidance.
1001- do_classifier_free_guidance = guidance_scale > 1.0
10021040
10031041 if isinstance (controlnet , MultiControlNetModel ) and isinstance (controlnet_conditioning_scale , float ):
10041042 controlnet_conditioning_scale = [controlnet_conditioning_scale ] * len (controlnet .nets )
@@ -1024,7 +1062,7 @@ def __call__(
10241062 prompt_2 ,
10251063 device ,
10261064 num_images_per_prompt ,
1027- do_classifier_free_guidance ,
1065+ self . do_classifier_free_guidance ,
10281066 negative_prompt ,
10291067 negative_prompt_2 ,
10301068 prompt_embeds = prompt_embeds ,
@@ -1045,7 +1083,7 @@ def __call__(
10451083 num_images_per_prompt = num_images_per_prompt ,
10461084 device = device ,
10471085 dtype = controlnet .dtype ,
1048- do_classifier_free_guidance = do_classifier_free_guidance ,
1086+ do_classifier_free_guidance = self . do_classifier_free_guidance ,
10491087 guess_mode = guess_mode ,
10501088 )
10511089 height , width = image .shape [- 2 :]
@@ -1061,7 +1099,7 @@ def __call__(
10611099 num_images_per_prompt = num_images_per_prompt ,
10621100 device = device ,
10631101 dtype = controlnet .dtype ,
1064- do_classifier_free_guidance = do_classifier_free_guidance ,
1102+ do_classifier_free_guidance = self . do_classifier_free_guidance ,
10651103 guess_mode = guess_mode ,
10661104 )
10671105
@@ -1089,6 +1127,14 @@ def __call__(
10891127 latents ,
10901128 )
10911129
1130+ # 6.5 Optionally get Guidance Scale Embedding
1131+ timestep_cond = None
1132+ if self .unet .config .time_cond_proj_dim is not None :
1133+ guidance_scale_tensor = torch .tensor (self .guidance_scale - 1 ).repeat (batch_size * num_images_per_prompt )
1134+ timestep_cond = self .get_guidance_scale_embedding (
1135+ guidance_scale_tensor , embedding_dim = self .unet .config .time_cond_proj_dim
1136+ ).to (device = device , dtype = latents .dtype )
1137+
10921138 # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
10931139 extra_step_kwargs = self .prepare_extra_step_kwargs (generator , eta )
10941140
@@ -1133,7 +1179,7 @@ def __call__(
11331179 else :
11341180 negative_add_time_ids = add_time_ids
11351181
1136- if do_classifier_free_guidance :
1182+ if self . do_classifier_free_guidance :
11371183 prompt_embeds = torch .cat ([negative_prompt_embeds , prompt_embeds ], dim = 0 )
11381184 add_text_embeds = torch .cat ([negative_pooled_prompt_embeds , add_text_embeds ], dim = 0 )
11391185 add_time_ids = torch .cat ([negative_add_time_ids , add_time_ids ], dim = 0 )
@@ -1154,13 +1200,13 @@ def __call__(
11541200 if (is_unet_compiled and is_controlnet_compiled ) and is_torch_higher_equal_2_1 :
11551201 torch ._inductor .cudagraph_mark_step_begin ()
11561202 # expand the latents if we are doing classifier free guidance
1157- latent_model_input = torch .cat ([latents ] * 2 ) if do_classifier_free_guidance else latents
1203+ latent_model_input = torch .cat ([latents ] * 2 ) if self . do_classifier_free_guidance else latents
11581204 latent_model_input = self .scheduler .scale_model_input (latent_model_input , t )
11591205
11601206 added_cond_kwargs = {"text_embeds" : add_text_embeds , "time_ids" : add_time_ids }
11611207
11621208 # controlnet(s) inference
1163- if guess_mode and do_classifier_free_guidance :
1209+ if guess_mode and self . do_classifier_free_guidance :
11641210 # Infer ControlNet only for the conditional batch.
11651211 control_model_input = latents
11661212 control_model_input = self .scheduler .scale_model_input (control_model_input , t )
@@ -1193,7 +1239,7 @@ def __call__(
11931239 return_dict = False ,
11941240 )
11951241
1196- if guess_mode and do_classifier_free_guidance :
1242+ if guess_mode and self . do_classifier_free_guidance :
11971243 # Infered ControlNet only for the conditional batch.
11981244 # To apply the output of ControlNet to both the unconditional and conditional batches,
11991245 # add 0 to the unconditional batch to keep it unchanged.
@@ -1205,6 +1251,7 @@ def __call__(
12051251 latent_model_input ,
12061252 t ,
12071253 encoder_hidden_states = prompt_embeds ,
1254+ timestep_cond = timestep_cond ,
12081255 cross_attention_kwargs = cross_attention_kwargs ,
12091256 down_block_additional_residuals = down_block_res_samples ,
12101257 mid_block_additional_residual = mid_block_res_sample ,
@@ -1213,7 +1260,7 @@ def __call__(
12131260 )[0 ]
12141261
12151262 # perform guidance
1216- if do_classifier_free_guidance :
1263+ if self . do_classifier_free_guidance :
12171264 noise_pred_uncond , noise_pred_text = noise_pred .chunk (2 )
12181265 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond )
12191266
0 commit comments