Commit 64cbd8e

Support LCM in ControlNet and Adapter pipelines. (#5822)
* support lcm
* fix tests
* fix tests
1 parent 038b42d commit 64cbd8e
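With this change, the ControlNet pipelines accept latent-consistency models: a guidance-distilled LCM UNet (one whose config sets `time_cond_proj_dim`) receives its guidance scale through the new `timestep_cond` embedding instead of classifier-free guidance, and LCM-LoRA setups work through the same call path. A minimal usage sketch, assuming the public `lllyasviel/sd-controlnet-canny` ControlNet and `latent-consistency/lcm-lora-sdv1-5` LoRA checkpoints (the conditioning-image URL is a placeholder):

import torch
from diffusers import ControlNetModel, LCMScheduler, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

# Assumed checkpoints: any SD 1.5 ControlNet plus a matching LCM-LoRA should work.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# LCM sampling: swap in the LCM scheduler and load the LCM-LoRA weights.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")

canny_image = load_image("https://example.com/canny.png")  # placeholder conditioning image

# Few steps, low guidance; with guidance_scale <= 1 the two-pass CFG batch is skipped.
image = pipe(
    "a futuristic city at dusk",
    image=canny_image,
    num_inference_steps=4,
    guidance_scale=1.0,
).images[0]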

8 files changed: +362 -63 lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 59 additions & 12 deletions
@@ -726,6 +726,46 @@ def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
 
+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            timesteps (`torch.Tensor`):
+                generate embedding vectors at these timesteps
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -863,6 +903,8 @@ def __call__(
             control_guidance_end,
         )
 
+        self._guidance_scale = guidance_scale
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -872,10 +914,6 @@ def __call__(
             batch_size = prompt_embeds.shape[0]
 
         device = self._execution_device
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
 
         if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
             controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
@@ -895,7 +933,7 @@ def __call__(
             prompt,
             device,
             num_images_per_prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
@@ -905,7 +943,7 @@ def __call__(
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
         # to avoid doing two forward passes
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
         # 4. Prepare image
@@ -918,7 +956,7 @@ def __call__(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=controlnet.dtype,
-                do_classifier_free_guidance=do_classifier_free_guidance,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
             height, width = image.shape[-2:]
@@ -934,7 +972,7 @@ def __call__(
                     num_images_per_prompt=num_images_per_prompt,
                     device=device,
                     dtype=controlnet.dtype,
-                    do_classifier_free_guidance=do_classifier_free_guidance,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
                     guess_mode=guess_mode,
                 )
 
@@ -962,6 +1000,14 @@ def __call__(
             latents,
         )
 
+        # 6.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
@@ -986,11 +1032,11 @@ def __call__(
                 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
                     torch._inductor.cudagraph_mark_step_begin()
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 # controlnet(s) inference
-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                     # Infer ControlNet only for the conditional batch.
                     control_model_input = latents
                     control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -1017,7 +1063,7 @@ def __call__(
                     return_dict=False,
                 )
 
-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                     # Infered ControlNet only for the conditional batch.
                     # To apply the output of ControlNet to both the unconditional and conditional batches,
                     # add 0 to the unconditional batch to keep it unchanged.
@@ -1029,14 +1075,15 @@ def __call__(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                     cross_attention_kwargs=cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
                     return_dict=False,
                 )[0]
 
                 # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

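The `get_guidance_scale_embedding` helper added above is a sinusoidal (Fourier) embedding of the scaled guidance weight `w * 1000`, the same construction used for timestep embeddings. A standalone sketch of the same computation and of how step 6.5 feeds it (function name and sizes here are illustrative, not part of the diff):

import torch

def guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 256) -> torch.Tensor:
    # Sinusoidal embedding of the guidance weight, mirroring the helper above.
    w = w * 1000.0
    half_dim = embedding_dim // 2
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32)
        * -(torch.log(torch.tensor(10000.0)) / (half_dim - 1))
    )
    emb = w.to(torch.float32)[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero-pad odd embedding dims
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb

# As in step 6.5: `guidance_scale - 1`, repeated once per generated image.
w = torch.tensor(8.0 - 1).repeat(2 * 1)  # batch_size * num_images_per_prompt
print(guidance_scale_embedding(w).shape)  # torch.Size([2, 256])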
src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

Lines changed: 59 additions & 12 deletions
@@ -791,6 +791,46 @@ def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
 
+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            timesteps (`torch.Tensor`):
+                generate embedding vectors at these timesteps
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -986,6 +1026,8 @@ def __call__(
             control_guidance_end,
         )
 
+        self._guidance_scale = guidance_scale
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -995,10 +1037,6 @@ def __call__(
            batch_size = prompt_embeds.shape[0]
 
         device = self._execution_device
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
 
         if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
             controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
@@ -1024,7 +1062,7 @@ def __call__(
             prompt_2,
             device,
             num_images_per_prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
             negative_prompt,
             negative_prompt_2,
             prompt_embeds=prompt_embeds,
@@ -1045,7 +1083,7 @@ def __call__(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=controlnet.dtype,
-                do_classifier_free_guidance=do_classifier_free_guidance,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
             height, width = image.shape[-2:]
@@ -1061,7 +1099,7 @@ def __call__(
                     num_images_per_prompt=num_images_per_prompt,
                     device=device,
                     dtype=controlnet.dtype,
-                    do_classifier_free_guidance=do_classifier_free_guidance,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
                     guess_mode=guess_mode,
                 )
 
@@ -1089,6 +1127,14 @@ def __call__(
             latents,
         )
 
+        # 6.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
@@ -1133,7 +1179,7 @@ def __call__(
         else:
            negative_add_time_ids = add_time_ids
 
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
             add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
@@ -1154,13 +1200,13 @@ def __call__(
                 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
                     torch._inductor.cudagraph_mark_step_begin()
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
 
                 # controlnet(s) inference
-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                     # Infer ControlNet only for the conditional batch.
                     control_model_input = latents
                     control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -1193,7 +1239,7 @@ def __call__(
                     return_dict=False,
                 )
 
-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                     # Infered ControlNet only for the conditional batch.
                     # To apply the output of ControlNet to both the unconditional and conditional batches,
                     # add 0 to the unconditional batch to keep it unchanged.
@@ -1205,6 +1251,7 @@ def __call__(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                     cross_attention_kwargs=cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
@@ -1213,7 +1260,7 @@ def __call__(
                 )[0]
 
                 # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

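Both files now gate classifier-free guidance on the same property: guidance runs only when `guidance_scale > 1` and the UNet is not guidance-distilled. A minimal sketch of that predicate (a standalone function for illustration, not the pipeline property itself):

def do_classifier_free_guidance(guidance_scale: float, time_cond_proj_dim) -> bool:
    # LCM-distilled UNets set `time_cond_proj_dim` and consume the guidance
    # scale through `timestep_cond`, so the two-pass CFG batch is skipped.
    return guidance_scale > 1 and time_cond_proj_dim is None

print(do_classifier_free_guidance(7.5, None))  # True: regular UNet, CFG runs
print(do_classifier_free_guidance(7.5, 256))   # False: LCM UNet, guidance is embedded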