@@ -16,7 +16,7 @@
 
 import torch
 
-from diffusers import Transformer2DModel, VQModel
+from diffusers import LearnedClassifierFreeSamplingEmbeddings, Transformer2DModel, VQModel
 from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
 from transformers import CLIPTextModel, CLIPTokenizer
 
@@ -55,6 +55,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
     text_encoder: CLIPTextModel
     tokenizer: CLIPTokenizer
     transformer: Transformer2DModel
+    learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings
     scheduler: VQDiffusionScheduler
 
     def __init__(
@@ -64,6 +65,7 @@ def __init__(
         tokenizer: CLIPTokenizer,
         transformer: Transformer2DModel,
         scheduler: VQDiffusionScheduler,
+        learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings,
     ):
         super().__init__()
 
@@ -73,13 +75,78 @@ def __init__(
             text_encoder=text_encoder,
             tokenizer=tokenizer,
             scheduler=scheduler,
+            learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings,
         )
 
+    def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+        # get prompt text embeddings
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
+        # While CLIP does normalize the pooled output of the text transformer when combining
+        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
+        #
+        # CLIP normalizing the pooled output.
+        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
+        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
+
+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            if self.learned_classifier_free_sampling_embeddings.embeddings is None:
+                uncond_tokens = [""] * batch_size
+
+                max_length = text_input_ids.shape[-1]
+                uncond_input = self.tokenizer(
+                    uncond_tokens,
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+                uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+                # See comment for normalizing text embeddings
+                uncond_embeddings = uncond_embeddings / uncond_embeddings.norm(dim=-1, keepdim=True)
+            else:
+                uncond_embeddings = self.learned_classifier_free_sampling_embeddings.embeddings
+                uncond_embeddings = uncond_embeddings.unsqueeze(0).repeat(batch_size, 1, 1)
+
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = uncond_embeddings.shape[1]
+            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
     @torch.no_grad()
     def __call__(
         self,
         prompt: Union[str, List[str]],
         num_inference_steps: int = 100,
+        guidance_scale: float = 5.0,
         truncation_rate: float = 1.0,
         num_images_per_prompt: int = 1,
         generator: Optional[torch.Generator] = None,
@@ -137,6 +204,10 @@ def __call__(
 
         batch_size = batch_size * num_images_per_prompt
 
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        text_embeddings = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance)
+
         if (callback_steps is None) or (
             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
         ):
@@ -145,35 +216,6 @@ def __call__(
             f" {type(callback_steps)}."
         )
 
-        # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
-
-        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
-        # While CLIP does normalize the pooled output of the text transformer when combining
-        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
-        #
-        # CLIP normalizing the pooled output.
-        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
-        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
-
-        # duplicate text embeddings for each generation per prompt
-        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
-
         # get the initial completely masked latents unless the user supplied it
 
         latents_shape = (batch_size, self.transformer.num_latent_pixels)
@@ -198,9 +240,19 @@ def __call__(
         sample = latents
 
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+            # expand the sample if we are doing classifier free guidance
+            latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample
+
             # predict the un-noised image
             # model_output == `log_p_x_0`
-            model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample
+            model_output = self.transformer(
+                latent_model_input, encoder_hidden_states=text_embeddings, timestep=t
+            ).sample
+
+            if do_classifier_free_guidance:
+                model_output_uncond, model_output_text = model_output.chunk(2)
+                model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond)
+                model_output -= torch.logsumexp(model_output, dim=1, keepdim=True)
 
             model_output = self.truncate(model_output, truncation_rate)
 
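For reference, a minimal usage sketch of the guidance path added in this diff. The checkpoint id, prompt, and output filename below are illustrative assumptions, not part of the change; any Hub repo whose model index registers a learned_classifier_free_sampling_embeddings component should work. Passing guidance_scale greater than 1.0 enables classifier free guidance: the transformer runs on a doubled batch and the unconditional and conditional halves are recombined in log space, with the logsumexp subtraction renormalizing log_p_x_0 so it remains a valid log probability distribution over the codebook.

# Hypothetical usage sketch; checkpoint id and prompt are assumptions for illustration.
import torch

from diffusers import VQDiffusionPipeline

pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq")
if torch.cuda.is_available():
    pipeline = pipeline.to("cuda")

# guidance_scale > 1.0 turns on the classifier free guidance branch added above;
# guidance_scale <= 1.0 falls back to purely conditional sampling.
output = pipeline(
    "teddy bear playing in the pool",
    num_inference_steps=100,
    guidance_scale=5.0,
    num_images_per_prompt=1,
)
output.images[0].save("teddy_bear.png")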