
Commit 88bdd97

a-r-r-o-w, yiyixuxu, and sayakpaul authored
IP adapter support for most pipelines (#5900)
* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
* update tests
* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
* support ip-adapter in src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py
* revert changes to sd_attend_and_excite and sd_upscale
* make style
* fix broken tests
* update ip-adapter implementation to latest
* apply suggestions from review

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent 08b453e commit 88bdd97

13 files changed (+380 −43 lines)
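
Every touched pipeline follows the same recipe shown in the two diffs below: mix `IPAdapterMixin` into the class, accept an optional `image_encoder`, and thread a new `ip_adapter_image` argument through `__call__` into `added_cond_kwargs`. Here is a minimal usage sketch, not part of the commit — the checkpoint ids follow the common public LCM and IP-Adapter repos, and the image URL is a placeholder:

```python
# Minimal usage sketch, assuming the public SimianLuo/LCM_Dreamshaper_v7 and
# h94/IP-Adapter checkpoints; the reference-image URL is a placeholder.
import torch
from diffusers import LatentConsistencyModelPipeline
from diffusers.utils import load_image

pipe = LatentConsistencyModelPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float16
).to("cuda")

# load_ip_adapter comes from the IPAdapterMixin this commit adds to the class
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

style_ref = load_image("https://example.com/style_reference.png")  # placeholder URL

image = pipe(
    prompt="a cozy cabin in snowy mountains, highly detailed",
    ip_adapter_image=style_ref,  # new argument wired through to added_cond_kwargs
    num_inference_steps=4,
    guidance_scale=8.0,
).images[0]
```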

src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

Lines changed: 46 additions & 5 deletions
```diff
@@ -20,11 +20,11 @@
 
 import PIL.Image
 import torch
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import LCMScheduler
 from ...utils import (
@@ -129,7 +129,7 @@ def retrieve_timesteps(
 
 
 class LatentConsistencyModelImg2ImgPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
 ):
     r"""
     Pipeline for image-to-image generation using a latent consistency model.
@@ -142,6 +142,7 @@ class LatentConsistencyModelImg2ImgPipeline(
         - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     Args:
         vae ([`AutoencoderKL`]):
@@ -166,7 +167,7 @@
     """
 
     model_cpu_offload_seq = "text_encoder->unet->vae"
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
     _callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"]
 
@@ -179,6 +180,7 @@ def __init__(
         scheduler: LCMScheduler,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
@@ -191,6 +193,7 @@ def __init__(
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
 
         if safety_checker is None and requires_safety_checker:
@@ -449,6 +452,31 @@ def encode_prompt(
 
         return prompt_embeds, negative_prompt_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
@@ -647,6 +675,7 @@ def __call__(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -695,6 +724,8 @@ def __call__(
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                 provided, text embeddings are generated from the `prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -758,6 +789,12 @@ def __call__(
         device = self._execution_device
         # do_classifier_free_guidance = guidance_scale > 1.0
 
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
+
         # 3. Encode input prompt
         lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
@@ -815,6 +852,9 @@ def __call__(
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)
 
+        # 7.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+
         # 8. LCM Multistep Sampling Loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
@@ -829,6 +869,7 @@ def __call__(
                     timestep_cond=w_embedding,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=self.cross_attention_kwargs,
+                    added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
                 )[0]
 
```

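The image-to-image variant combines the IP-Adapter reference with the usual init image. A hedged sketch under the same assumptions as above (public checkpoints, placeholder URLs):

```python
# Sketch for the img2img pipeline; model ids and image URLs are assumptions.
import torch
from diffusers import LatentConsistencyModelImg2ImgPipeline
from diffusers.utils import load_image

pipe = LatentConsistencyModelImg2ImgPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

init_image = load_image("https://example.com/init.png")    # image to transform
adapter_img = load_image("https://example.com/style.png")  # IP-Adapter reference

image = pipe(
    prompt="a watercolor landscape",
    image=init_image,             # existing img2img input
    ip_adapter_image=adapter_img, # new argument from this commit
    num_inference_steps=4,
    strength=0.6,
).images[0]
```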
src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py

Lines changed: 47 additions & 6 deletions
```diff
@@ -19,11 +19,11 @@
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
-from ...image_processor import VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import LCMScheduler
 from ...utils import (
@@ -107,7 +107,7 @@ def retrieve_timesteps(
 
 
 class LatentConsistencyModelPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
 ):
     r"""
     Pipeline for text-to-image generation using a latent consistency model.
@@ -120,6 +120,7 @@ class LatentConsistencyModelPipeline(
         - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     Args:
         vae ([`AutoencoderKL`]):
@@ -144,7 +145,7 @@
     """
 
     model_cpu_offload_seq = "text_encoder->unet->vae"
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
     _callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"]
 
@@ -157,6 +158,7 @@ def __init__(
         scheduler: LCMScheduler,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
@@ -185,6 +187,7 @@ def __init__(
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -433,6 +436,31 @@ def encode_prompt(
 
         return prompt_embeds, negative_prompt_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
@@ -581,6 +609,7 @@ def __call__(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -629,6 +658,8 @@ def __call__(
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                 provided, text embeddings are generated from the `prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -697,6 +728,12 @@ def __call__(
         device = self._execution_device
         # do_classifier_free_guidance = guidance_scale > 1.0
 
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
+
         # 3. Encode input prompt
         lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
@@ -748,6 +785,9 @@ def __call__(
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)
 
+        # 7.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+
         # 8. LCM MultiStep Sampling Loop:
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
@@ -762,6 +802,7 @@ def __call__(
                     timestep_cond=w_embedding,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=self.cross_attention_kwargs,
+                    added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
                 )[0]
 
```

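One detail worth calling out: both `__call__` bodies decide what `encode_image` should return by inspecting `self.unet.encoder_hid_proj`. Plain IP-Adapter checkpoints install an `ImageProjection`, which consumes pooled `image_embeds`; the IP-Adapter Plus variants instead use a resampler fed by the penultimate CLIP-vision hidden states, hence `hidden_states[-2]` in `encode_image`. A standalone restatement of that branch — `wants_hidden_states` is a hypothetical helper name, not part of the commit:

```python
# Restatement of the dispatch used in both pipelines above; the helper name
# is ours, only the isinstance check comes from the commit.
from diffusers.models import ImageProjection

def wants_hidden_states(unet) -> bool:
    # Plain IP-Adapter -> ImageProjection -> pooled image_embeds suffice.
    # IP-Adapter Plus -> resampler -> needs hidden_states[-2] from the image encoder.
    return not isinstance(unet.encoder_hid_proj, ImageProjection)
```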