 
 import torch
 
-from diffusers import Transformer2DModel, VQModel
+from diffusers import LearnedClassifierFreeSamplingEmbeddings, Transformer2DModel, VQModel
 from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
 from transformers import CLIPTextModel, CLIPTokenizer
 
@@ -55,6 +55,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
     text_encoder: CLIPTextModel
     tokenizer: CLIPTokenizer
     transformer: Transformer2DModel
+    learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings
     scheduler: VQDiffusionScheduler
 
     def __init__(
@@ -64,6 +65,7 @@ def __init__(
         tokenizer: CLIPTokenizer,
         transformer: Transformer2DModel,
         scheduler: VQDiffusionScheduler,
+        learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings,
     ):
         super().__init__()
 
@@ -73,13 +75,78 @@ def __init__(
             text_encoder=text_encoder,
             tokenizer=tokenizer,
             scheduler=scheduler,
+            learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings,
         )
 
+    def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+        # get prompt text embeddings
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
+        # While CLIP does normalize the pooled output of the text transformer when combining
+        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
+        #
+        # CLIP normalizing the pooled output:
+        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
+        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
+
+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            if self.learned_classifier_free_sampling_embeddings.learnable:
+                uncond_embeddings = self.learned_classifier_free_sampling_embeddings.embeddings
+                uncond_embeddings = uncond_embeddings.unsqueeze(0).repeat(batch_size, 1, 1)
+            else:
+                uncond_tokens = [""] * batch_size
+
+                max_length = text_input_ids.shape[-1]
+                uncond_input = self.tokenizer(
+                    uncond_tokens,
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+                uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+                # See comment above for normalizing text embeddings
+                uncond_embeddings = uncond_embeddings / uncond_embeddings.norm(dim=-1, keepdim=True)
+
+                # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+                seq_len = uncond_embeddings.shape[1]
+                uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+                uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes.
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
     @torch.no_grad()
     def __call__(
         self,
         prompt: Union[str, List[str]],
         num_inference_steps: int = 100,
+        guidance_scale: float = 5.0,
         truncation_rate: float = 1.0,
         num_images_per_prompt: int = 1,
         generator: Optional[torch.Generator] = None,
@@ -98,6 +165,12 @@ def __call__(
             num_inference_steps (`int`, *optional*, defaults to 100):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+                paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled by setting `guidance_scale > 1`.
+                A higher guidance scale encourages the model to generate images that are closely linked to the text
+                `prompt`, usually at the expense of lower image quality.
             truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)):
                 Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at
                 most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above
@@ -137,6 +210,10 @@ def __call__(
 
         batch_size = batch_size * num_images_per_prompt
 
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        text_embeddings = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance)
+
         if (callback_steps is None) or (
             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
         ):
@@ -145,35 +222,6 @@ def __call__(
                 f" {type(callback_steps)}."
             )
 
-        # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
-
-        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
-        # While CLIP does normalize the pooled output of the text transformer when combining
-        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
-        #
-        # CLIP normalizing the pooled output:
-        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
-        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
-
-        # duplicate text embeddings for each generation per prompt
-        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
-
         # get the initial completely masked latents unless the user supplied it
 
         latents_shape = (batch_size, self.transformer.num_latent_pixels)
@@ -198,9 +246,19 @@ def __call__(
         sample = latents
 
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+            # expand the sample if we are doing classifier free guidance
+            latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample
+
             # predict the un-noised image
             # model_output == `log_p_x_0`
-            model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample
+            model_output = self.transformer(
+                latent_model_input, encoder_hidden_states=text_embeddings, timestep=t
+            ).sample
+
+            if do_classifier_free_guidance:
+                model_output_uncond, model_output_text = model_output.chunk(2)
+                model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond)
+                model_output -= torch.logsumexp(model_output, dim=1, keepdim=True)
 
             model_output = self.truncate(model_output, truncation_rate)
 
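The guidance step added in the denoising loop above works on log-probabilities over the VQ codebook (the tensor's dim=1), so the weighted combination has to be renormalized with `logsumexp` to stay a valid log distribution. A standalone sketch of that arithmetic, with illustrative tensor shapes that are assumptions rather than the pipeline's actual ones:

import torch

# Illustrative shapes: (batch, codebook_size, num_latent_pixels); the real
# pipeline gets these from the transformer output. guidance_scale mirrors
# the new __call__ argument.
guidance_scale = 5.0
log_p_uncond = torch.log_softmax(torch.randn(1, 4096, 1024), dim=1)
log_p_text = torch.log_softmax(torch.randn(1, 4096, 1024), dim=1)

# Same arithmetic as in the diff: interpolate in log space, then renormalize
# so that each latent pixel's distribution over codebook entries sums to 1.
guided = log_p_uncond + guidance_scale * (log_p_text - log_p_uncond)
guided -= torch.logsumexp(guided, dim=1, keepdim=True)

print(torch.allclose(guided.exp().sum(dim=1), torch.ones(1, 1024), atol=1e-4))  # True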
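For completeness, a minimal usage sketch of the new guidance path. The checkpoint id, prompt, and output handling below are assumptions for illustration and are not part of this diff:

import torch
from diffusers import VQDiffusionPipeline

# Hypothetical checkpoint id -- any VQ-Diffusion checkpoint that ships
# learned classifier-free sampling embeddings would exercise the new code.
pipe = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# guidance_scale > 1.0 enables the two-pass classifier-free guidance added
# above; guidance_scale <= 1.0 keeps the original single-pass behavior.
image = pipe("teddy bear playing in the pool", guidance_scale=5.0).images[0]
image.save("teddy_bear.png")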