
Commit 08984ab

vq diffusion classifier free sampling
1 parent 8a73064 commit 08984ab

7 files changed (+275 / -37 lines)

scripts/convert_vq_diffusion_to_diffusers.py

Lines changed: 47 additions & 2 deletions

@@ -39,8 +39,13 @@

 import yaml
 from accelerate import init_empty_weights, load_checkpoint_and_dispatch
-from diffusers import VQDiffusionPipeline, VQDiffusionScheduler, VQModel
-from diffusers.models.attention import Transformer2DModel
+from diffusers import (
+    LearnedClassifierFreeSamplingEmbeddings,
+    Transformer2DModel,
+    VQDiffusionPipeline,
+    VQDiffusionScheduler,
+    VQModel,
+)
 from transformers import CLIPTextModel, CLIPTokenizer
 from yaml.loader import FullLoader

@@ -826,6 +831,20 @@ def read_config_file(filename):
         transformer_model, checkpoint
     )

+    # classifier free sampling embeddings interlude
+
+    # The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them on a separate
+    # model, so we pull them off the checkpoint before the checkpoint is deleted.
+
+    learnable_classifier_free_sampling_embeddings = diffusion_config.params.learnable_cf
+
+    if learnable_classifier_free_sampling_embeddings:
+        learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"]
+    else:
+        learned_classifier_free_sampling_embeddings_embeddings = None
+
+    # done classifier free sampling embeddings interlude
+
     with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file:
         torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name)
         del diffusers_transformer_checkpoint

@@ -871,13 +890,39 @@ def read_config_file(filename):

     # done scheduler

+    # learned classifier free sampling embeddings
+
+    with init_empty_weights():
+        learned_classifier_free_sampling_embeddings_model = LearnedClassifierFreeSamplingEmbeddings(
+            learnable_classifier_free_sampling_embeddings,
+            hidden_size=text_encoder_model.config.hidden_size,
+            length=tokenizer_model.model_max_length,
+        )
+
+    learned_classifier_free_sampling_checkpoint = {
+        "embeddings": learned_classifier_free_sampling_embeddings_embeddings.float()
+    }
+
+    with tempfile.NamedTemporaryFile() as learned_classifier_free_sampling_checkpoint_file:
+        torch.save(learned_classifier_free_sampling_checkpoint, learned_classifier_free_sampling_checkpoint_file.name)
+        del learned_classifier_free_sampling_checkpoint
+        del learned_classifier_free_sampling_embeddings_embeddings
+        load_checkpoint_and_dispatch(
+            learned_classifier_free_sampling_embeddings_model,
+            learned_classifier_free_sampling_checkpoint_file.name,
+            device_map="auto",
+        )
+
+    # done learned classifier free sampling embeddings
+
     print(f"saving VQ diffusion model, path: {args.dump_path}")

     pipe = VQDiffusionPipeline(
         vqvae=vqvae_model,
         transformer=transformer_model,
         tokenizer=tokenizer_model,
         text_encoder=text_encoder_model,
+        learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings_model,
         scheduler=scheduler_model,
     )
     pipe.save_pretrained(args.dump_path)
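
Note (not part of the diff): a quick post-conversion sanity check is to reload the new component and compare it against the tensor pulled off the original checkpoint. The sketch below is illustrative only; the paths are placeholders, it assumes the loaded state dict is the same flat `checkpoint` the script indexes with `transformer.empty_text_embed`, and it assumes `save_pretrained` writes the component under a `learned_classifier_free_sampling_embeddings` subfolder (the keyword used when constructing the pipeline above).

import torch

from diffusers import LearnedClassifierFreeSamplingEmbeddings

# Placeholder paths: the original VQ-diffusion checkpoint and the conversion dump path.
checkpoint = torch.load("original_vq_diffusion_checkpoint.pth", map_location="cpu")
converted = LearnedClassifierFreeSamplingEmbeddings.from_pretrained(
    "converted-vq-diffusion", subfolder="learned_classifier_free_sampling_embeddings"
)

# The script stores `transformer.empty_text_embed`, cast to float32, as the `embeddings` parameter.
expected = checkpoint["transformer.empty_text_embed"].float()
assert torch.allclose(expected, converted.embeddings.detach().cpu())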

src/diffusers/__init__.py

Lines changed: 9 additions & 1 deletion

@@ -18,7 +18,15 @@

 if is_torch_available():
     from .modeling_utils import ModelMixin
-    from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
+    from .models import (
+        AutoencoderKL,
+        LearnedClassifierFreeSamplingEmbeddings,
+        Transformer2DModel,
+        UNet1DModel,
+        UNet2DConditionModel,
+        UNet2DModel,
+        VQModel,
+    )
     from .optimization import (
         get_constant_schedule,
         get_constant_schedule_with_warmup,

src/diffusers/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@

 if is_torch_available():
     from .attention import Transformer2DModel
+    from .embeddings import LearnedClassifierFreeSamplingEmbeddings
     from .unet_1d import UNet1DModel
     from .unet_2d import UNet2DModel
     from .unet_2d_condition import UNet2DConditionModel

src/diffusers/models/embeddings.py

Lines changed: 26 additions & 0 deletions

@@ -12,11 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
+from typing import Optional

 import numpy as np
 import torch
 from torch import nn

+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.modeling_utils import ModelMixin
+

 def get_timestep_embedding(
     timesteps: torch.Tensor,

@@ -198,3 +202,25 @@ def forward(self, index):
         emb = emb + pos_emb[:, : emb.shape[1], :]

         return emb
+
+
+class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin):
+    """
+    Utility class for storing learned text embeddings for classifier free sampling
+    """
+
+    @register_to_config
+    def __init__(self, learnable: bool, hidden_size: Optional[int] = None, length: Optional[int] = None):
+        super().__init__()
+
+        self.learnable = learnable
+
+        if self.learnable:
+            assert hidden_size is not None, "learnable=True requires `hidden_size` to be set"
+            assert length is not None, "learnable=True requires `length` to be set"
+
+            embeddings = torch.zeros(length, hidden_size)
+        else:
+            embeddings = None
+
+        self.embeddings = nn.Parameter(embeddings)
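
Note (not part of the diff): the new class has two modes. With `learnable=True` it allocates a zero-initialized `(length, hidden_size)` parameter, which the conversion script then overwrites with the checkpoint's `transformer.empty_text_embed`; with `learnable=False` it holds no data and the pipeline instead encodes empty `""` prompts for the unconditional branch. A minimal sketch, with example sizes chosen to mirror the CLIP text encoder defaults (77 tokens, hidden size 512):

from diffusers import LearnedClassifierFreeSamplingEmbeddings

# Learnable: a (length, hidden_size) parameter is allocated.
learned = LearnedClassifierFreeSamplingEmbeddings(learnable=True, hidden_size=512, length=77)
print(learned.embeddings.shape)  # torch.Size([77, 512])

# Not learnable: the parameter wraps no data (nn.Parameter(None) is an empty tensor),
# and VQDiffusionPipeline falls back to encoding "" prompts.
unconditional = LearnedClassifierFreeSamplingEmbeddings(learnable=False)
print(unconditional.embeddings.numel())  # 0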

src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py

Lines changed: 89 additions & 31 deletions

@@ -16,7 +16,7 @@

 import torch

-from diffusers import Transformer2DModel, VQModel
+from diffusers import LearnedClassifierFreeSamplingEmbeddings, Transformer2DModel, VQModel
 from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
 from transformers import CLIPTextModel, CLIPTokenizer

@@ -55,6 +55,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
     text_encoder: CLIPTextModel
     tokenizer: CLIPTokenizer
     transformer: Transformer2DModel
+    learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings
     scheduler: VQDiffusionScheduler

     def __init__(
@@ -64,6 +65,7 @@ def __init__(
         tokenizer: CLIPTokenizer,
         transformer: Transformer2DModel,
         scheduler: VQDiffusionScheduler,
+        learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings,
     ):
         super().__init__()

@@ -73,13 +75,78 @@ def __init__(
             text_encoder=text_encoder,
             tokenizer=tokenizer,
             scheduler=scheduler,
+            learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings,
         )

+    def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+        # get prompt text embeddings
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
+        # While CLIP does normalize the pooled output of the text transformer when combining
+        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
+        #
+        # CLIP normalizing the pooled output.
+        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
+        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
+
+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            if self.learned_classifier_free_sampling_embeddings.learnable:
+                uncond_embeddings = self.learned_classifier_free_sampling_embeddings.embeddings
+                uncond_embeddings = uncond_embeddings.unsqueeze(0).repeat(batch_size, 1, 1)
+            else:
+                uncond_tokens = [""] * batch_size
+
+                max_length = text_input_ids.shape[-1]
+                uncond_input = self.tokenizer(
+                    uncond_tokens,
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+                uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+                # See comment for normalizing text embeddings
+                uncond_embeddings = uncond_embeddings / uncond_embeddings.norm(dim=-1, keepdim=True)
+
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = uncond_embeddings.shape[1]
+            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
     @torch.no_grad()
     def __call__(
         self,
         prompt: Union[str, List[str]],
         num_inference_steps: int = 100,
+        guidance_scale: float = 5.0,
         truncation_rate: float = 1.0,
         num_images_per_prompt: int = 1,
         generator: Optional[torch.Generator] = None,
@@ -98,6 +165,12 @@ def __call__(
             num_inference_steps (`int`, *optional*, defaults to 100):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
             truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)):
                 Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at
                 most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above
@@ -137,6 +210,10 @@ def __call__(

         batch_size = batch_size * num_images_per_prompt

+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        text_embeddings = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance)
+
         if (callback_steps is None) or (
             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
         ):
@@ -145,35 +222,6 @@ def __call__(
                 f" {type(callback_steps)}."
             )

-        # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
-
-        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
-        # While CLIP does normalize the pooled output of the text transformer when combining
-        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
-        #
-        # CLIP normalizing the pooled output.
-        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
-        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
-
-        # duplicate text embeddings for each generation per prompt
-        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
-
         # get the initial completely masked latents unless the user supplied it

         latents_shape = (batch_size, self.transformer.num_latent_pixels)
@@ -198,9 +246,19 @@ def __call__(
         sample = latents

         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+            # expand the sample if we are doing classifier free guidance
+            latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample
+
             # predict the un-noised image
             # model_output == `log_p_x_0`
-            model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample
+            model_output = self.transformer(
+                latent_model_input, encoder_hidden_states=text_embeddings, timestep=t
+            ).sample
+
+            if do_classifier_free_guidance:
+                model_output_uncond, model_output_text = model_output.chunk(2)
+                model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond)
+                model_output -= torch.logsumexp(model_output, dim=1, keepdim=True)

             model_output = self.truncate(model_output, truncation_rate)
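
Note (not part of the diff): unlike the epsilon-predicting pipelines, the transformer here outputs `log_p_x_0`, so the guidance combination runs in log space and the `torch.logsumexp` subtraction renormalizes each pixel's class distribution (dim=1) back into valid log-probabilities. A standalone sketch of just that step, with arbitrary toy shapes:

import torch

guidance_scale = 5.0
batch, num_classes, num_latent_pixels = 2, 10, 16  # arbitrary toy sizes

# Pretend transformer output for the doubled batch: first half unconditional, second half text-conditioned.
model_output = torch.log_softmax(torch.randn(2 * batch, num_classes, num_latent_pixels), dim=1)

model_output_uncond, model_output_text = model_output.chunk(2)
model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond)
# Without this renormalization the guided scores would no longer be log-probabilities.
model_output -= torch.logsumexp(model_output, dim=1, keepdim=True)

# Each pixel's class distribution sums to one again.
assert torch.allclose(model_output.exp().sum(dim=1), torch.ones(batch, num_latent_pixels), atol=1e-5)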

src/diffusers/utils/dummy_pt_objects.py

Lines changed: 15 additions & 0 deletions

@@ -34,6 +34,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])


+class LearnedClassifierFreeSamplingEmbeddings(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class Transformer2DModel(metaclass=DummyObject):
     _backends = ["torch"]
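
Note (not part of the diff): once a checkpoint has been converted with the script above, guidance is controlled from the pipeline call; `guidance_scale > 1.0` enables the classifier free path. A hedged usage sketch, with a placeholder model path (any converted VQ-Diffusion pipeline directory or hub id would do):

import torch

from diffusers import VQDiffusionPipeline

pipe = VQDiffusionPipeline.from_pretrained("path/to/converted-vq-diffusion")
if torch.cuda.is_available():
    pipe = pipe.to("cuda")

# guidance_scale > 1.0 triggers the doubled forward pass added in pipeline_vq_diffusion.py.
image = pipe(
    "a painting of a teddy bear",
    num_inference_steps=100,
    guidance_scale=5.0,
    truncation_rate=1.0,
).images[0]
image.save("teddy_bear.png")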
