Commit c74fe8c
LCM Add Tests (huggingface#5707)
* lcm add tests * uP * Fix all * uP * Add * all * uP * uP * uP * uP * uP * uP * uP
1 parent 9b3c1fd commit c74fe8c

9 files changed: +316 −8 lines

loaders.py (5 additions, 0 deletions)

@@ -1411,6 +1411,11 @@ def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_ext
             filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
         )
 
+        if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
+            targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
+        elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
+            targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
+
         if len(targeted_files) > 1:
             raise ValueError(
                 f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."

pipelines/alt_diffusion/pipeline_alt_diffusion.py (38 additions, 1 deletion)

@@ -588,6 +588,34 @@ def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
 
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            timesteps (`torch.Tensor`):
+                generate embedding vectors at these timesteps
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
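Since this helper is copied verbatim into every pipeline touched by this commit, a standalone restatement may help. Note that the docstring names the argument `timesteps` while the parameter is actually `w`, the guidance weight: the method is a standard sinusoidal embedding of w * 1000 (cf. the VDM code linked above). The sketch below is illustrative, not code from this commit.

import torch

def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    # Sinusoidal embedding of the guidance weight, mirroring the method above.
    assert w.ndim == 1
    w = w * 1000.0
    half_dim = embedding_dim // 2
    freqs = torch.exp(-torch.log(torch.tensor(10000.0)) / (half_dim - 1) * torch.arange(half_dim, dtype=dtype))
    emb = w.to(dtype)[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1))  # zero-pad odd widths
    return emb

w = torch.tensor([8.0 - 1.0])  # guidance_scale - 1, as the pipelines pass it
print(guidance_scale_embedding(w, embedding_dim=256).shape)  # torch.Size([1, 256])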
@@ -605,7 +633,7 @@ def clip_skip(self):
     # corresponds to doing no classifier free guidance.
     @property
     def do_classifier_free_guidance(self):
-        return self._guidance_scale > 1
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
 
     @property
     def cross_attention_kwargs(self):
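This one-line change is the behavioral core of the commit: a UNet whose config sets time_cond_proj_dim (as LCM-distilled UNets do) never runs the doubled unconditional/conditional forward pass, because the guidance weight is injected through timestep_cond instead. A minimal sketch of the gate, using hypothetical config objects:

from types import SimpleNamespace

def do_classifier_free_guidance(guidance_scale, unet_config):
    # Mirrors the new property: CFG runs only when the UNet lacks a guidance-embedding input.
    return guidance_scale > 1 and unet_config.time_cond_proj_dim is None

regular_unet = SimpleNamespace(time_cond_proj_dim=None)
lcm_unet = SimpleNamespace(time_cond_proj_dim=256)  # hypothetical LCM value

print(do_classifier_free_guidance(7.5, regular_unet))  # True
print(do_classifier_free_guidance(7.5, lcm_unet))      # False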
@@ -804,6 +832,14 @@ def __call__(
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
+        # 6.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
         # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
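Step 6.5 expands the scalar guidance_scale into one conditioning entry per generated image before embedding it to the width the UNet's time-conditioning projection expects. A shape-only sketch with made-up batch sizes:

import torch

batch_size, num_images_per_prompt = 2, 1
guidance_scale = 8.0

# One (guidance_scale - 1) entry per generated image:
guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
print(guidance_scale_tensor.shape)  # torch.Size([2])

# The pipeline then embeds this to (2, unet.config.time_cond_proj_dim) via
# get_guidance_scale_embedding and casts it to the latents' device and dtype.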
@@ -818,6 +854,7 @@ def __call__(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                     cross_attention_kwargs=self.cross_attention_kwargs,
                     return_dict=False,
                 )[0]

pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py (38 additions, 1 deletion)

@@ -646,6 +646,34 @@ def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
 
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            timesteps (`torch.Tensor`):
+                generate embedding vectors at these timesteps
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
     @property
     def guidance_scale(self):
         return self._guidance_scale

@@ -659,7 +687,7 @@ def clip_skip(self):
     # corresponds to doing no classifier free guidance.
     @property
     def do_classifier_free_guidance(self):
-        return self._guidance_scale > 1
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
 
     @property
     def cross_attention_kwargs(self):

@@ -849,6 +877,14 @@ def __call__(
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
+        # 7.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)

@@ -863,6 +899,7 @@ def __call__(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                     cross_attention_kwargs=self.cross_attention_kwargs,
                     return_dict=False,
                 )[0]

pipelines/stable_diffusion/pipeline_stable_diffusion.py (39 additions, 1 deletion)

@@ -576,6 +576,35 @@ def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
 
+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            timesteps (`torch.Tensor`):
+                generate embedding vectors at these timesteps
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
     @property
     def guidance_scale(self):
         return self._guidance_scale

@@ -593,7 +622,7 @@ def clip_skip(self):
     # corresponds to doing no classifier free guidance.
     @property
     def do_classifier_free_guidance(self):
-        return self._guidance_scale > 1
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
 
     @property
     def cross_attention_kwargs(self):

@@ -790,6 +819,14 @@ def __call__(
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
+        # 6.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
         # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)

@@ -804,6 +841,7 @@ def __call__(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                     cross_attention_kwargs=self.cross_attention_kwargs,
                     return_dict=False,
                 )[0]

pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py (39 additions, 1 deletion)

@@ -640,6 +640,35 @@ def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
 
+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            timesteps (`torch.Tensor`):
+                generate embedding vectors at these timesteps
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
     @property
     def guidance_scale(self):
         return self._guidance_scale

@@ -653,7 +682,7 @@ def clip_skip(self):
     # corresponds to doing no classifier free guidance.
     @property
     def do_classifier_free_guidance(self):
-        return self._guidance_scale > 1
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
 
     @property
     def cross_attention_kwargs(self):

@@ -841,6 +870,14 @@ def __call__(
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
+        # 7.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)

@@ -855,6 +892,7 @@ def __call__(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                     cross_attention_kwargs=self.cross_attention_kwargs,
                     return_dict=False,
                 )[0]

pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py (39 additions, 1 deletion)

@@ -765,6 +765,35 @@ def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
 
+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            timesteps (`torch.Tensor`):
+                generate embedding vectors at these timesteps
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
     @property
     def guidance_scale(self):
         return self._guidance_scale

@@ -778,7 +807,7 @@ def clip_skip(self):
     # corresponds to doing no classifier free guidance.
     @property
     def do_classifier_free_guidance(self):
-        return self._guidance_scale > 1
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
 
     @property
     def cross_attention_kwargs(self):

@@ -1087,6 +1116,14 @@ def __call__(
         # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
+        # 9.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
         # 10. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)

@@ -1106,6 +1143,7 @@ def __call__(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                     cross_attention_kwargs=self.cross_attention_kwargs,
                     return_dict=False,
                 )[0]
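Taken together, these changes let the stock pipelines drive LCM-distilled UNets through the regular __call__ path. A hedged end-to-end sketch; the checkpoint name, scheduler swap, and step count below are illustrative usage, not part of this commit:

import torch
from diffusers import DiffusionPipeline, LCMScheduler

# An LCM-distilled checkpoint whose UNet config sets time_cond_proj_dim, so
# do_classifier_free_guidance is False and guidance flows via timestep_cond.
pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float16)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")  # or stay on CPU with torch_dtype=torch.float32

image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=4,  # LCMs need only a few steps
    guidance_scale=8.0,     # consumed as an embedding, not as CFG
).images[0]
image.save("lcm.png")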
