@@ -19,11 +19,11 @@
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
-from ...image_processor import VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import LCMScheduler
 from ...utils import (
@@ -107,7 +107,7 @@ def retrieve_timesteps(
 
 
 class LatentConsistencyModelPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
 ):
     r"""
     Pipeline for text-to-image generation using a latent consistency model.
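Adding `IPAdapterMixin` to the base classes is what exposes `load_ip_adapter` on this pipeline. A minimal usage sketch, assuming the public `h94/IP-Adapter` weights and the `SimianLuo/LCM_Dreamshaper_v7` checkpoint (neither is part of this diff):

```python
# Sketch only: repo ids and weight names are assumptions, not from this diff.
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7")
# Provided by IPAdapterMixin, newly inherited above:
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
```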
@@ -120,6 +120,7 @@ class LatentConsistencyModelPipeline(
         - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     Args:
         vae ([`AutoencoderKL`]):
@@ -144,7 +145,7 @@ class LatentConsistencyModelPipeline(
     """
 
     model_cpu_offload_seq = "text_encoder->unet->vae"
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
     _callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"]
 
@@ -157,6 +158,7 @@ def __init__(
         scheduler: LCMScheduler,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
@@ -185,6 +187,7 @@ def __init__(
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
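Because `image_encoder` defaults to `None` and is registered as an optional component, checkpoints saved before this change keep loading unchanged; when IP-Adapter conditioning is wanted, the encoder can be supplied at load time. A hedged sketch, assuming the subfolder layout of the public `h94/IP-Adapter` repo:

```python
# Sketch: model ids and the subfolder layout are assumptions, not from this diff.
import torch
from transformers import CLIPVisionModelWithProjection
from diffusers import DiffusionPipeline

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
)
pipe = DiffusionPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7", image_encoder=image_encoder, torch_dtype=torch.float16
)
```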
@@ -433,6 +436,31 @@ def encode_prompt(
 
         return prompt_embeds, negative_prompt_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
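`encode_image` has two modes: with `output_hidden_states=True` it returns the penultimate hidden states of the vision tower (what resampler-style projection heads consume), otherwise the pooled `image_embeds`; in both modes the unconditional branch is zeros (a zeroed image, or `zeros_like` of the embeds). An illustrative call, assuming a `pipe` with an image encoder attached, a CUDA device, and a hypothetical `reference.png`:

```python
from PIL import Image

ref = Image.open("reference.png")  # hypothetical reference image
image_embeds, uncond_image_embeds = pipe.encode_image(
    ref, device="cuda", num_images_per_prompt=1, output_hidden_states=False
)
# The uncond branch is torch.zeros_like(image_embeds) by construction:
assert not uncond_image_embeds.any()
```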
@@ -581,6 +609,7 @@ def __call__(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -629,6 +658,8 @@ def __call__(
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                 provided, text embeddings are generated from the `prompt` input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -697,6 +728,12 @@ def __call__(
         device = self._execution_device
         # do_classifier_free_guidance = guidance_scale > 1.0
 
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
+
         # 3. Encode input prompt
         lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
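The `output_hidden_state` flag mirrors the UNet's projection head: a plain `ImageProjection` consumes the pooled embeds, while other heads (e.g. the resampler used by IP-Adapter Plus) expect hidden states. The ternary in the hunk above is equivalent to a simple negation; a restatement for clarity, with `pipe` assumed:

```python
from diffusers.models import ImageProjection

# Equivalent to `False if isinstance(...) else True` above:
output_hidden_state = not isinstance(pipe.unet.encoder_hid_proj, ImageProjection)
```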
@@ -748,6 +785,9 @@ def __call__(
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)
 
+        # 7.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+
         # 8. LCM MultiStep Sampling Loop:
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
@@ -762,6 +802,7 @@ def __call__(
                 timestep_cond=w_embedding,
                 encoder_hidden_states=prompt_embeds,
                 cross_attention_kwargs=self.cross_attention_kwargs,
+                added_cond_kwargs=added_cond_kwargs,
                 return_dict=False,
             )[0]
 
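Putting the pieces together, an end-to-end sketch of what this diff enables. Note that LCM skips classifier-free guidance (`do_classifier_free_guidance` stays commented out), so `negative_image_embeds` is computed but only the positive embeds reach `added_cond_kwargs`. Model ids, step count, and the reference URL below are assumptions:

```python
# Sketch only: model ids and the image URL are placeholders, not from this diff.
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

pipe = DiffusionPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

ref = load_image("https://example.com/reference.png")  # placeholder URL
image = pipe(
    prompt="a photo of a robot, highly detailed",
    ip_adapter_image=ref,
    num_inference_steps=4,  # LCM needs only a few steps
    guidance_scale=8.0,     # folded into w_embedding, not CFG
).images[0]
```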