[feat] IP Adapters (author @okotaku) #5713

@@ -307,3 +307,140 @@ prompt = "a house by william eggleston, sunrays, beautiful, sunlight, sunrays, b
image = pipeline(prompt=prompt).images[0]
image
```

### IP-Adapter

[IP-Adapter](https://ip-adapter.github.io/) is an effective and lightweight adapter that adds image prompting capabilities to pretrained text-to-image diffusion models. It is now available for use with most of our Stable Diffusion and Stable Diffusion XL pipelines. You can also use IP-Adapter with custom models fine-tuned from the same base model, as well as with ControlNet and T2I adapters. Moreover, the image prompt works well together with a text prompt for multimodal image generation.

You can find the officially available IP-Adapter checkpoints in [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter).

IP-Adapter was contributed by [okotaku](https://github.com/okotaku).

Let's look at an example where we use IP-Adapter with the Stable Diffusion text-to-image pipeline. Start by loading the image encoder and passing it to the pipeline:

```py
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
from transformers import CLIPVisionModelWithProjection
import torch

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="models/image_encoder",
    torch_dtype=torch.float16,
).to("cuda")

pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, torch_dtype=torch.float16).to("cuda")
```

Now you can load the IP-Adapter with the [`~loaders.IPAdapterMixin.load_ip_adapter`] method:

```py
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
```

IP-Adapter allows you to use both an image and a text prompt to condition the image generation process. In this example, let's take the cute bear eating pizza that we generated with Textual Inversion and create a new bear that is similarly cute but wears sunglasses. We pass the bear image as `ip_adapter_image`, along with a text prompt that mentions "sunglasses":

```py
pipeline.set_ip_adapter_scale(0.6)
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")
generator = torch.Generator(device="cpu").manual_seed(33)
images = pipeline(
    prompt="best quality, high quality, wearing sunglasses",
    ip_adapter_image=image,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=50,
    generator=generator,
).images
images[0]
```

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip-bear.png" />
</div>

<Tip>

You can use the `pipeline.set_ip_adapter_scale()` method to adjust the balance between the text prompt and the image prompt conditioning. If you only use the image prompt, set the scale to `1.0`. Lowering the scale yields more diverse generations, at the cost of weaker alignment with the image prompt.

`scale=0.5` achieves good results in most cases when you use both text and image prompts.

</Tip>
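
For example, to lean almost entirely on the image prompt, raise the scale to `1.0` and keep the text prompt generic. This is a minimal sketch that reuses the `pipeline`, `image`, and `generator` from the example above; the generic prompt text is only an assumption:

```py
# Sketch: image-prompt-driven generation by setting the IP-Adapter scale to 1.0.
# Reuses the pipeline and bear image loaded above; the generic prompt is an assumption.
pipeline.set_ip_adapter_scale(1.0)
image_only = pipeline(
    prompt="best quality, high quality",
    ip_adapter_image=image,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=50,
    generator=generator,
).images[0]
```
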
IP-Adapter also works great with the image-to-image and inpainting pipelines. Here is an example of how you can use it with image-to-image:

```py
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image
from transformers import CLIPVisionModelWithProjection
import torch

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="models/image_encoder",
    torch_dtype=torch.float16,
).to("cuda")

pipeline = AutoPipelineForImage2Image.from_pretrained("runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, torch_dtype=torch.float16).to("cuda")

image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/vermeer.jpg")
ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/river.png")

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
generator = torch.Generator(device="cpu").manual_seed(33)
images = pipeline(
    prompt="best quality, high quality",
    image=image,
    ip_adapter_image=ip_image,
    num_inference_steps=50,
    generator=generator,
    strength=0.6,
).images
images[0]
```
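
Inpainting follows the same pattern. Below is a minimal, hedged sketch: the `runwayml/stable-diffusion-inpainting` checkpoint is an assumption, it reuses the `image_encoder` from above, and the `"..."` URLs are placeholders for your own base image, mask, and IP-Adapter reference image:

```py
from diffusers import AutoPipelineForInpainting

# Sketch only: the inpainting checkpoint is an assumed example, and the "..." URLs
# are placeholders for your own base image, mask, and IP-Adapter reference image.
pipeline = AutoPipelineForInpainting.from_pretrained(
    "runwayml/stable-diffusion-inpainting", image_encoder=image_encoder, torch_dtype=torch.float16
).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

init_image = load_image("...")  # base image to inpaint
mask_image = load_image("...")  # white pixels are repainted, black pixels are kept
ip_image = load_image("...")    # reference image for the image prompt

generator = torch.Generator(device="cpu").manual_seed(33)
images = pipeline(
    prompt="best quality, high quality",
    image=init_image,
    mask_image=mask_image,
    ip_adapter_image=ip_image,
    num_inference_steps=50,
    generator=generator,
).images
images[0]
```
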
IP-Adapters can be used with [Stable Diffusion XL](../api/pipelines/stable_diffusion/stable_diffusion_xl.md) (SDXL) for text-to-image, image-to-image, and inpainting pipelines. Below is an example for SDXL text-to-image:

```python
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
import torch

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="sdxl_models/image_encoder",
    torch_dtype=torch.float16,
).to("cuda")
feature_extractor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    image_encoder=image_encoder,
    feature_extractor=feature_extractor,
    torch_dtype=torch.float16,
).to("cuda")

image = load_image("https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/watercolor_painting.jpeg")

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

generator = torch.Generator(device="cpu").manual_seed(33)
image = pipeline(
    prompt="best quality, high quality",
    ip_adapter_image=image,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=25,
    generator=generator,
).images[0]
image.save("sdxl_t2i.png")
```

<div class="flex justify-center">
    <table border="1">
        <tr>
            <th>Input Image</th>
            <th>Adapted Image</th>
        </tr>
        <tr>
            <td><img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/watercolor_painting.jpeg" alt="Input Image"></td>
            <td><img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/sdxl_t2i.png" alt="Adapted Image"></td>
        </tr>
    </table>
</div>
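
As noted at the start of this section, IP-Adapter can also be combined with ControlNet. Below is a minimal sketch of that combination; the `lllyasviel/sd-controlnet-canny` checkpoint is an assumed example, and the `"..."` URLs are placeholders for your own control image and reference image:

```py
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.utils import load_image
from transformers import CLIPVisionModelWithProjection
import torch

# Sketch only: the ControlNet checkpoint is an assumed example, and the "..." URLs are placeholders.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
)

pipeline = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

control_image = load_image("...")  # structural conditioning image, e.g. a Canny edge map
ip_image = load_image("...")       # reference image for the image prompt

generator = torch.Generator(device="cpu").manual_seed(33)
image = pipeline(
    prompt="best quality, high quality",
    image=control_image,
    ip_adapter_image=ip_image,
    num_inference_steps=50,
    generator=generator,
).images[0]
```
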
@@ -22,11 +22,21 @@
import requests
import safetensors
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download, model_info
from packaging import version
from torch import nn

from . import __version__
from .models.attention_processor import (
    AttnProcessor,
    AttnProcessor2_0,
    IPAdapterAttnProcessor,
    IPAdapterAttnProcessor2_0,
    IPAdapterControlNetAttnProcessor,
    IPAdapterControlNetAttnProcessor2_0,
)
from .models.embeddings import ImageProjection
from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from .utils import (
    DIFFUSERS_CACHE,

@@ -3329,3 +3339,170 @@ def _remove_text_encoder_monkey_patch(self):
        else:
            self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
            self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)

class IPAdapterMixin:
    """Mixin for handling IP Adapters."""

    def set_ip_adapter(self):
        unet = self.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                # Self-attention layers keep the default attention processors.
                attn_processor_class = (
                    AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
                )
                attn_procs[name] = attn_processor_class()
            else:
                # Cross-attention layers get IP-Adapter processors that add the image-prompt branch.
                attn_processor_class = (
                    IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
                )
                attn_procs[name] = attn_processor_class(
                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
                ).to(dtype=unet.dtype, device=unet.device)

        unet.set_attn_processor(attn_procs)

        if hasattr(self, "controlnet"):
            attn_processor_class = (
                IPAdapterControlNetAttnProcessor2_0
                if hasattr(F, "scaled_dot_product_attention")
                else IPAdapterControlNetAttnProcessor
            )
            self.controlnet.set_attn_processor(attn_processor_class())

    def load_ip_adapter(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        """
        Load IP-Adapter weights into the UNet of the pipeline.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            weight_name (`str`):
                The name of the weight file to load (for example `ip-adapter_sd15.bin`).
        """
        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
            raise ValueError("`image_encoder` cannot be None when using IP Adapters.")

        self.set_ip_adapter()

        # Load the main state dict first.
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
            model_file = _get_model_file(
                pretrained_model_name_or_path_or_dict,
                weights_name=weight_name,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                subfolder=subfolder,
                user_agent=user_agent,
            )
            if weight_name.endswith(".safetensors"):
                state_dict = safetensors.torch.load_file(model_file, device="cpu")
            else:
                state_dict = torch.load(model_file, map_location="cpu")
        else:
            state_dict = pretrained_model_name_or_path_or_dict

        keys = list(state_dict.keys())
        if keys != ["image_proj", "ip_adapter"]:
            raise ValueError("Required keys (`image_proj` and `ip_adapter`) are missing from the state dict.")

        # Handle image projection layers. The projection maps CLIP image embeddings to 4 image tokens,
        # so `proj.weight` has shape (cross_attention_dim * 4, clip_embeddings_dim).
        clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
        cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4

        image_projection = ImageProjection(
            cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4
        )
        image_projection.to(dtype=self.unet.dtype, device=self.unet.device)

        diffusers_state_dict = {}

        diffusers_state_dict.update(
            {
                "image_embeds.weight": state_dict["image_proj"]["proj.weight"],
                "image_embeds.bias": state_dict["image_proj"]["proj.bias"],
                "norm.weight": state_dict["image_proj"]["norm.weight"],
                "norm.bias": state_dict["image_proj"]["norm.bias"],
            }
        )

        image_projection.load_state_dict(diffusers_state_dict)

        self.unet.encoder_hid_proj = image_projection.to(device=self.unet.device, dtype=self.unet.dtype)
        self.unet.config.encoder_hid_dim_type = "image_proj"
        self.unet.config.encoder_hid_dim = clip_embeddings_dim

        # Handle IP-Adapter cross-attention layers.
        ip_layers = torch.nn.ModuleList(
            [
                module if isinstance(module, nn.Module) else nn.Identity()
                for module in self.unet.attn_processors.values()
            ]
        )
        ip_layers.load_state_dict(state_dict["ip_adapter"])

    def set_ip_adapter_scale(self, scale):
        for attn_processor in self.unet.attn_processors.values():
            if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
                attn_processor.scale = scale
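
Because `load_ip_adapter` also accepts an in-memory state dict, a local checkpoint can be loaded directly. A minimal sketch, assuming a pipeline that already has an `image_encoder` (as in the documentation examples above) and a hypothetical local file `./ip-adapter_sd15.bin` containing the expected `image_proj` and `ip_adapter` entries:

```py
import torch

# Hypothetical local path; the checkpoint must contain exactly the
# "image_proj" and "ip_adapter" keys that load_ip_adapter checks for.
state_dict = torch.load("./ip-adapter_sd15.bin", map_location="cpu")

pipeline.load_ip_adapter(state_dict)
pipeline.set_ip_adapter_scale(0.6)
```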