 import PIL.Image
 import torch
 import torch.nn.functional as F
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
@@ -130,7 +130,7 @@ def prepare_image(image):


 class StableDiffusionControlNetImg2ImgPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
 ):
     r"""
     Pipeline for image-to-image generation using Stable Diffusion with ControlNet guidance.
@@ -140,7 +140,7 @@ class StableDiffusionControlNetImg2ImgPipeline(

     The pipeline also inherits the following loading methods:
         - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
-
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
@@ -166,7 +166,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
     """

     model_cpu_offload_seq = "text_encoder->unet->vae"
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

@@ -180,6 +180,7 @@ def __init__(
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        image_encoder: CLIPVisionModelWithProjection = None,
        requires_safety_checker: bool = True,
     ):
         super().__init__()
@@ -212,6 +213,7 @@ def __init__(
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
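Because `image_encoder` is also listed in `_optional_components`, checkpoints saved without one still load, and `from_pretrained` accepts one explicitly. A minimal sketch of supplying it by hand; the model ids and subfolder here are assumptions, not part of this diff, and `load_ip_adapter` (shown further below) should be able to populate the encoder on its own:

```python
import torch
from transformers import CLIPVisionModelWithProjection

from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline

# Hypothetical: any CLIP vision checkpoint with a projection head could serve here.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
)
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
)
```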
@@ -468,6 +470,31 @@ def encode_prompt(

         return prompt_embeds, negative_prompt_embeds

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
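The copied-in `encode_image` has two return modes: plain projected embeddings with an all-zeros tensor as the negative branch, or, when `output_hidden_states=True`, the penultimate hidden states with an encoded zero image as the negative branch. A hypothetical sanity check, assuming `pipe` is an already-loaded pipeline with an image encoder attached and `ref_image` is a `PIL.Image`:

```python
import torch

# Projected-embedding mode: one row per generated image; the uncond branch is zeros.
image_embeds, uncond_embeds = pipe.encode_image(ref_image, "cuda", num_images_per_prompt=2)
assert image_embeds.shape[0] == 2
assert torch.all(uncond_embeds == 0)

# Hidden-state mode: penultimate-layer features; the uncond branch encodes a zero image.
hidden, uncond_hidden = pipe.encode_image(ref_image, "cuda", 2, output_hidden_states=True)
assert hidden.shape == uncond_hidden.shape
```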
@@ -861,6 +888,7 @@ def __call__(
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -922,6 +950,7 @@ def __call__(
             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -1053,6 +1082,11 @@ def __call__(
         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

+        if ip_adapter_image is not None:
+            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            if self.do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
         # 4. Prepare image
         image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)

@@ -1111,7 +1145,10 @@ def __call__(
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-        # 7.1 Create tensor stating which controlnets to keep
+        # 7.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+
+        # 7.2 Create tensor stating which controlnets to keep
         controlnet_keep = []
         for i in range(len(timesteps)):
             keeps = [
@@ -1171,6 +1208,7 @@ def __call__(
                     cross_attention_kwargs=self.cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
+                    added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
                 )[0]

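End to end, the new `ip_adapter_image` argument slots into the usual ControlNet img2img call. A usage sketch under the same assumptions as above; the model ids and image URLs are placeholders, not part of this diff:

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# Loads the adapter weights; if no image_encoder was passed at init, this should
# also pull one from the adapter repo.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

init_image = load_image("https://example.com/init.png")      # placeholder URLs
control_image = load_image("https://example.com/canny.png")
style_image = load_image("https://example.com/style.png")

image = pipe(
    prompt="a cat, best quality",
    image=init_image,
    control_image=control_image,
    ip_adapter_image=style_image,  # the argument added in this diff
    num_inference_steps=30,
).images[0]
```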