from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
from transformers import CLIPTextModel, CLIPTokenizer

+from ...configuration_utils import ConfigMixin, register_to_config
+from ...modeling_utils import ModelMixin
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...utils import logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin):
+    """
+    Utility class for storing learned text embeddings for classifier free sampling
+    """
+
+    @register_to_config
+    def __init__(self, learnable: bool, hidden_size: Optional[int] = None, length: Optional[int] = None):
+        super().__init__()
+
+        self.learnable = learnable
+
+        if self.learnable:
+            assert hidden_size is not None, "learnable=True requires `hidden_size` to be set"
+            assert length is not None, "learnable=True requires `length` to be set"
+
+            embeddings = torch.zeros(length, hidden_size)
+        else:
+            embeddings = None
+
+        self.embeddings = torch.nn.Parameter(embeddings)
+
+
class VQDiffusionPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using VQ Diffusion
@@ -55,6 +79,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
    text_encoder: CLIPTextModel
    tokenizer: CLIPTokenizer
    transformer: Transformer2DModel
+    learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings
    scheduler: VQDiffusionScheduler

    def __init__(
@@ -64,6 +89,7 @@ def __init__(
        tokenizer: CLIPTokenizer,
        transformer: Transformer2DModel,
        scheduler: VQDiffusionScheduler,
+        learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings,
    ):
        super().__init__()

@@ -73,13 +99,78 @@ def __init__(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            scheduler=scheduler,
+            learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings,
+        )
+
+    def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+        # get prompt text embeddings
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
        )
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
+        # While CLIP does normalize the pooled output of the text transformer when combining
+        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
+        #
+        # CLIP normalizing the pooled output.
+        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
+        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
+
+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            if self.learned_classifier_free_sampling_embeddings.learnable:
+                uncond_embeddings = self.learned_classifier_free_sampling_embeddings.embeddings
+                uncond_embeddings = uncond_embeddings.unsqueeze(0).repeat(batch_size, 1, 1)
+            else:
+                uncond_tokens = [""] * batch_size
+
+                max_length = text_input_ids.shape[-1]
+                uncond_input = self.tokenizer(
+                    uncond_tokens,
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+                uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+                # See comment for normalizing text embeddings
+                uncond_embeddings = uncond_embeddings / uncond_embeddings.norm(dim=-1, keepdim=True)
+
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = uncond_embeddings.shape[1]
+            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        num_inference_steps: int = 100,
+        guidance_scale: float = 5.0,
        truncation_rate: float = 1.0,
        num_images_per_prompt: int = 1,
        generator: Optional[torch.Generator] = None,
@@ -98,6 +189,12 @@ def __call__(
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled by setting `guidance_scale > 1`. A
+                higher guidance scale encourages the model to generate images that are closely linked to the text
+                `prompt`, usually at the expense of lower image quality.
            truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)):
                Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at
                most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above
@@ -137,6 +234,10 @@ def __call__(

        batch_size = batch_size * num_images_per_prompt

+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        text_embeddings = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance)
+
        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
@@ -145,35 +246,6 @@ def __call__(
                f" {type(callback_steps)}."
            )

-        # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
-
-        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
-        # While CLIP does normalize the pooled output of the text transformer when combining
-        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
-        #
-        # CLIP normalizing the pooled output.
-        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
-        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
-
-        # duplicate text embeddings for each generation per prompt
-        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
-
        # get the initial completely masked latents unless the user supplied it

        latents_shape = (batch_size, self.transformer.num_latent_pixels)
@@ -198,9 +270,19 @@ def __call__(
        sample = latents

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+            # expand the sample if we are doing classifier free guidance
+            latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample
+
            # predict the un-noised image
            # model_output == `log_p_x_0`
-            model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample
+            model_output = self.transformer(
+                latent_model_input, encoder_hidden_states=text_embeddings, timestep=t
+            ).sample
+
+            if do_classifier_free_guidance:
+                model_output_uncond, model_output_text = model_output.chunk(2)
+                model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond)
+                model_output -= torch.logsumexp(model_output, dim=1, keepdim=True)

            model_output = self.truncate(model_output, truncation_rate)

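A note on the guidance step in the last hunk: the transformer predicts `log_p_x_0`, i.e. log-probabilities over the codebook classes along `dim=1`, so the classifier-free guidance combination `uncond + w * (text - uncond)` is taken in log space and no longer corresponds to a normalized distribution; subtracting `torch.logsumexp` over the class dimension renormalizes it so each latent pixel's class probabilities sum to one again. A minimal standalone sketch of that renormalization, with made-up tensor shapes (not part of this diff):

import torch

# toy shapes: batch of 2, 10 codebook classes, 4 latent pixels
log_p_uncond = torch.log_softmax(torch.randn(2, 10, 4), dim=1)
log_p_text = torch.log_softmax(torch.randn(2, 10, 4), dim=1)
guidance_scale = 5.0

# linear combination in log space, as in the diff
combined = log_p_uncond + guidance_scale * (log_p_text - log_p_uncond)
# renormalize so that each pixel's class distribution sums to 1 again
combined = combined - torch.logsumexp(combined, dim=1, keepdim=True)

assert torch.allclose(combined.exp().sum(dim=1), torch.ones(2, 4), atol=1e-5)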
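For reference, a rough usage sketch of the new `guidance_scale` argument. The checkpoint name is an assumption and not part of this diff; any value of `guidance_scale` above 1.0 enables classifier-free guidance:

import torch
from diffusers import VQDiffusionPipeline

# checkpoint name is an assumption; any VQ Diffusion checkpoint that ships a
# `learned_classifier_free_sampling_embeddings` component should work the same way
pipe = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# guidance_scale > 1.0 turns on classifier-free guidance (the new default is 5.0)
image = pipe("a teddy bear playing in the pool", guidance_scale=5.0).images[0]
image.save("teddy_bear.png")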