
Commit f1fcfde

vq diffusion classifier free sampling (#1294)
* vq diffusion classifier free sampling
* correct
* uP

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 09d0546 commit f1fcfde

File tree

4 files changed (+220 lines, -41 lines)

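In short, this commit teaches VQDiffusionPipeline to sample with classifier-free guidance and adds a small module for learned null-prompt embeddings. A minimal usage sketch, assuming a converted checkpoint such as microsoft/vq-diffusion-ithq is available on the Hub (the repo id, prompt, and guidance value are illustrative, not taken from this commit):

import torch
from diffusers import VQDiffusionPipeline

# Hypothetical checkpoint id; any output of the conversion script below would work the same way.
pipe = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# guidance_scale > 1.0 switches on the new classifier-free sampling path.
image = pipe("teddy bear playing in the pool", guidance_scale=5.0).images[0]
image.save("teddy_bear.png")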

scripts/convert_vq_diffusion_to_diffusers.py

Lines changed: 42 additions & 2 deletions
@@ -39,8 +39,8 @@
 
 import yaml
 from accelerate import init_empty_weights, load_checkpoint_and_dispatch
-from diffusers import VQDiffusionPipeline, VQDiffusionScheduler, VQModel
-from diffusers.models.attention import Transformer2DModel
+from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
+from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings
 from transformers import CLIPTextModel, CLIPTokenizer
 from yaml.loader import FullLoader
 
@@ -826,6 +826,20 @@ def read_config_file(filename):
         transformer_model, checkpoint
     )
 
+    # classifier free sampling embeddings interlude
+
+    # The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them on a separate
+    # model, so we pull them off the checkpoint before the checkpoint is deleted.
+
+    learnable_classifier_free_sampling_embeddings = diffusion_config.params.learnable_cf
+
+    if learnable_classifier_free_sampling_embeddings:
+        learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"]
+    else:
+        learned_classifier_free_sampling_embeddings_embeddings = None
+
+    # done classifier free sampling embeddings interlude
+
     with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file:
         torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name)
         del diffusers_transformer_checkpoint
@@ -871,13 +885,39 @@ def read_config_file(filename):
 
     # done scheduler
 
+    # learned classifier free sampling embeddings
+
+    with init_empty_weights():
+        learned_classifier_free_sampling_embeddings_model = LearnedClassifierFreeSamplingEmbeddings(
+            learnable_classifier_free_sampling_embeddings,
+            hidden_size=text_encoder_model.config.hidden_size,
+            length=tokenizer_model.model_max_length,
+        )
+
+    learned_classifier_free_sampling_checkpoint = {
+        "embeddings": learned_classifier_free_sampling_embeddings_embeddings.float()
+    }
+
+    with tempfile.NamedTemporaryFile() as learned_classifier_free_sampling_checkpoint_file:
+        torch.save(learned_classifier_free_sampling_checkpoint, learned_classifier_free_sampling_checkpoint_file.name)
+        del learned_classifier_free_sampling_checkpoint
+        del learned_classifier_free_sampling_embeddings_embeddings
+        load_checkpoint_and_dispatch(
+            learned_classifier_free_sampling_embeddings_model,
+            learned_classifier_free_sampling_checkpoint_file.name,
+            device_map="auto",
+        )
+
+    # done learned classifier free sampling embeddings
+
     print(f"saving VQ diffusion model, path: {args.dump_path}")
 
     pipe = VQDiffusionPipeline(
         vqvae=vqvae_model,
         transformer=transformer_model,
         tokenizer=tokenizer_model,
         text_encoder=text_encoder_model,
+        learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings_model,
         scheduler=scheduler_model,
     )
     pipe.save_pretrained(args.dump_path)
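The conversion above relies on an accelerate idiom: build the module with empty "meta" weights, write the extracted tensors to a temporary checkpoint, then let load_checkpoint_and_dispatch materialize and place them. A minimal sketch of the same pattern with a toy torch.nn.Linear (the toy module and state dict are illustrative, not part of this diff):

import tempfile

import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

# Build the module skeleton without allocating real storage for its weights.
with init_empty_weights():
    toy_model = torch.nn.Linear(4, 4)

# Serialize a matching state dict, then materialize the weights from disk and
# place the module on an available device.
state_dict = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
with tempfile.NamedTemporaryFile() as checkpoint_file:
    torch.save(state_dict, checkpoint_file.name)
    load_checkpoint_and_dispatch(toy_model, checkpoint_file.name, device_map="auto")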
src/diffusers/pipelines/vq_diffusion/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -1 +1,5 @@
-from .pipeline_vq_diffusion import VQDiffusionPipeline
+from ...utils import is_torch_available, is_transformers_available
+
+
+if is_transformers_available() and is_torch_available():
+    from .pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings, VQDiffusionPipeline

src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py

Lines changed: 112 additions & 30 deletions
@@ -20,13 +20,37 @@
 from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
 from transformers import CLIPTextModel, CLIPTokenizer
 
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...modeling_utils import ModelMixin
 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...utils import logging
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin):
+    """
+    Utility class for storing learned text embeddings for classifier free sampling
+    """
+
+    @register_to_config
+    def __init__(self, learnable: bool, hidden_size: Optional[int] = None, length: Optional[int] = None):
+        super().__init__()
+
+        self.learnable = learnable
+
+        if self.learnable:
+            assert hidden_size is not None, "learnable=True requires `hidden_size` to be set"
+            assert length is not None, "learnable=True requires `length` to be set"
+
+            embeddings = torch.zeros(length, hidden_size)
+        else:
+            embeddings = None
+
+        self.embeddings = torch.nn.Parameter(embeddings)
+
+
 class VQDiffusionPipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using VQ Diffusion
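For orientation, the new class simply registers a zero-initialized (length, hidden_size) parameter when learnable=True; when learnable=False the pipeline later falls back to encoding an empty prompt instead. A hypothetical stand-alone instantiation (the sizes 77 and 512 are illustrative, roughly matching a CLIP ViT-B/32 text encoder, and are not taken from this diff):

from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings

# Illustrative values only: 77 tokens, hidden size 512.
null_embeds = LearnedClassifierFreeSamplingEmbeddings(learnable=True, hidden_size=512, length=77)
print(null_embeds.embeddings.shape)  # torch.Size([77, 512]), zero-initialized
print(null_embeds.learnable)         # True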
@@ -55,6 +79,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
     text_encoder: CLIPTextModel
     tokenizer: CLIPTokenizer
     transformer: Transformer2DModel
+    learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings
     scheduler: VQDiffusionScheduler
 
     def __init__(
@@ -64,6 +89,7 @@ def __init__(
         tokenizer: CLIPTokenizer,
         transformer: Transformer2DModel,
         scheduler: VQDiffusionScheduler,
+        learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings,
     ):
         super().__init__()
 
@@ -73,13 +99,78 @@ def __init__(
             text_encoder=text_encoder,
             tokenizer=tokenizer,
             scheduler=scheduler,
+            learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings,
+        )
+
+    def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+        # get prompt text embeddings
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
         )
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
+        # While CLIP does normalize the pooled output of the text transformer when combining
+        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
+        #
+        # CLIP normalizing the pooled output.
+        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
+        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
+
+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            if self.learned_classifier_free_sampling_embeddings.learnable:
+                uncond_embeddings = self.learned_classifier_free_sampling_embeddings.embeddings
+                uncond_embeddings = uncond_embeddings.unsqueeze(0).repeat(batch_size, 1, 1)
+            else:
+                uncond_tokens = [""] * batch_size
+
+                max_length = text_input_ids.shape[-1]
+                uncond_input = self.tokenizer(
+                    uncond_tokens,
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+                uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+                # See comment for normalizing text embeddings
+                uncond_embeddings = uncond_embeddings / uncond_embeddings.norm(dim=-1, keepdim=True)
+
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = uncond_embeddings.shape[1]
+            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
 
     @torch.no_grad()
     def __call__(
         self,
         prompt: Union[str, List[str]],
         num_inference_steps: int = 100,
+        guidance_scale: float = 5.0,
         truncation_rate: float = 1.0,
         num_images_per_prompt: int = 1,
         generator: Optional[torch.Generator] = None,
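The _encode_prompt helper above follows the usual classifier-free guidance batching trick: the unconditional and conditional embeddings are concatenated into one batch so the transformer runs once per step, and the output is later split back apart. A small sketch with dummy tensors (shapes are arbitrary, not taken from this diff):

import torch

batch_size, seq_len, hidden = 2, 77, 512

# Stand-ins for the prompt embeddings and the "null prompt" embeddings.
cond = torch.randn(batch_size, seq_len, hidden)
uncond = torch.zeros(batch_size, seq_len, hidden)

# One batch, two conditions: rows [0:batch_size] are unconditional,
# rows [batch_size:] are conditional.
encoder_hidden_states = torch.cat([uncond, cond])

# After a single forward pass, the outputs are split back in the same order.
model_output = torch.randn(2 * batch_size, seq_len, hidden)  # stand-in for the model call
out_uncond, out_cond = model_output.chunk(2)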
@@ -98,6 +189,12 @@ def __call__(
             num_inference_steps (`int`, *optional*, defaults to 100):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                text `prompt`, usually at the expense of lower image quality.
             truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)):
                 Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at
                 most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above
@@ -137,6 +234,10 @@ def __call__(
 
         batch_size = batch_size * num_images_per_prompt
 
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        text_embeddings = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance)
+
         if (callback_steps is None) or (
             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
         ):
@@ -145,35 +246,6 @@ def __call__(
                 f" {type(callback_steps)}."
             )
 
-        # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
-
-        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
-        # While CLIP does normalize the pooled output of the text transformer when combining
-        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
-        #
-        # CLIP normalizing the pooled output.
-        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
-        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
-
-        # duplicate text embeddings for each generation per prompt
-        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
-
         # get the initial completely masked latents unless the user supplied it
 
         latents_shape = (batch_size, self.transformer.num_latent_pixels)
@@ -198,9 +270,19 @@ def __call__(
         sample = latents
 
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+            # expand the sample if we are doing classifier free guidance
+            latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample
+
             # predict the un-noised image
             # model_output == `log_p_x_0`
-            model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample
+            model_output = self.transformer(
+                latent_model_input, encoder_hidden_states=text_embeddings, timestep=t
+            ).sample
+
+            if do_classifier_free_guidance:
+                model_output_uncond, model_output_text = model_output.chunk(2)
+                model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond)
+                model_output -= torch.logsumexp(model_output, dim=1, keepdim=True)
 
             model_output = self.truncate(model_output, truncation_rate)
 
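A note on the guidance step above: the transformer outputs per-pixel log-probabilities over the VQ codebook, so after the usual uncond + w * (cond - uncond) combination the result is renormalized with logsumexp so that each pixel's distribution sums to one again. A small numerical sketch of that renormalization (toy shapes, not from this diff):

import torch

guidance_scale = 5.0
num_classes, num_pixels = 8, 4

# Toy log-probabilities for the unconditional and conditional passes.
log_p_uncond = torch.log_softmax(torch.randn(1, num_classes, num_pixels), dim=1)
log_p_text = torch.log_softmax(torch.randn(1, num_classes, num_pixels), dim=1)

# Classifier-free guidance in log space ...
guided = log_p_uncond + guidance_scale * (log_p_text - log_p_uncond)

# ... is generally no longer normalized, so renormalize over the class dimension.
guided = guided - torch.logsumexp(guided, dim=1, keepdim=True)

print(guided.exp().sum(dim=1))  # each pixel now sums to ~1.0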