Commit b978334

[@cene555][Kandinsky 3.0] Add Kandinsky 3.0 (#5913)
* finalize
* finalize
* finalize
* add slow test
* add slow test
* add slow test
* Fix more
* add slow test
* fix more
* fix more
* fix more
* fix more
* fix more
* fix more
* fix more
* fix more
* fix more
* Better
* Fix more
* Fix more
* add slow test
* Add auto pipelines
* add slow test
* Add all
* add slow test
* add slow test
* add slow test
* add slow test
* add slow test
* Apply suggestions from code review
* add slow test
* add slow test
1 parent e5f232f commit b978334

17 files changed: 2,110 additions, 1 deletion

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -278,6 +278,8 @@
       title: Kandinsky 2.1
     - local: api/pipelines/kandinsky_v22
       title: Kandinsky 2.2
+    - local: api/pipelines/kandinsky3
+      title: Kandinsky 3
     - local: api/pipelines/latent_consistency_models
       title: Latent Consistency Models
     - local: api/pipelines/latent_diffusion
docs/source/en/api/pipelines/kandinsky3.md

Lines changed: 24 additions & 0 deletions

@@ -0,0 +1,24 @@
+<!--Copyright 2023 The HuggingFace Team. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# Kandinsky 3
+
+TODO
+
+## Kandinsky3Pipeline
+
+[[autodoc]] Kandinsky3Pipeline
+  - all
+  - __call__
+
+## Kandinsky3Img2ImgPipeline
+
+[[autodoc]] Kandinsky3Img2ImgPipeline
+  - all
+  - __call__
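The usage docs are still a TODO in this commit. As a rough sketch, Kandinsky3Pipeline follows the standard DiffusionPipeline text-to-image API; the checkpoint ID below is an assumption, not something this commit pins down:

import torch
from diffusers import Kandinsky3Pipeline

# Hypothetical checkpoint ID: this commit does not reference a hub repository.
pipe = Kandinsky3Pipeline.from_pretrained(
    "kandinsky-community/kandinsky-3", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

image = pipe("A photograph of the inside of a subway train.").images[0]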

scripts/convert_kandinsky3_unet.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+#!/usr/bin/env python3
+import argparse
+import fnmatch
+
+from safetensors.torch import load_file
+
+from diffusers import Kandinsky3UNet
+
+
+MAPPING = {
+    "to_time_embed.1": "time_embedding.linear_1",
+    "to_time_embed.3": "time_embedding.linear_2",
+    "in_layer": "conv_in",
+    "out_layer.0": "conv_norm_out",
+    "out_layer.2": "conv_out",
+    "down_samples": "down_blocks",
+    "up_samples": "up_blocks",
+    "projection_lin": "encoder_hid_proj.projection_linear",
+    "projection_ln": "encoder_hid_proj.projection_norm",
+    "feature_pooling": "add_time_condition",
+    "to_query": "to_q",
+    "to_key": "to_k",
+    "to_value": "to_v",
+    "output_layer": "to_out.0",
+    "self_attention_block": "attentions.0",
+}
+
+DYNAMIC_MAP = {
+    "resnet_attn_blocks.*.0": "resnets_in.*",
+    "resnet_attn_blocks.*.1": ("attentions.*", 1),
+    "resnet_attn_blocks.*.2": "resnets_out.*",
+}
+
+
+def convert_state_dict(unet_state_dict):
+    """
+    Convert the state dict of a U-Net model to the key format expected by the Kandinsky3UNet model.
+
+    Args:
+        unet_state_dict (dict): The state dict of the original U-Net model.
+
+    Returns:
+        dict: The converted state dict.
+    """
+    converted_state_dict = {}
+    for key in unet_state_dict:
+        new_key = key
+        # Apply the static key renames first.
+        for pattern, new_pattern in MAPPING.items():
+            new_key = new_key.replace(pattern, new_pattern)
+
+        # Then apply the index-shifting renames for the resnet/attention blocks.
+        for dyn_pattern, dyn_new_pattern in DYNAMIC_MAP.items():
+            has_matched = False
+            if fnmatch.fnmatch(new_key, f"*.{dyn_pattern}.*") and not has_matched:
+                star = int(new_key.split(dyn_pattern.split(".")[0])[-1].split(".")[1])
+
+                if isinstance(dyn_new_pattern, tuple):
+                    new_star = star + dyn_new_pattern[-1]
+                    dyn_new_pattern = dyn_new_pattern[0]
+                else:
+                    new_star = star
+
+                pattern = dyn_pattern.replace("*", str(star))
+                new_pattern = dyn_new_pattern.replace("*", str(new_star))
+
+                new_key = new_key.replace(pattern, new_pattern)
+                has_matched = True
+
+        converted_state_dict[new_key] = unet_state_dict[key]
+
+    return converted_state_dict
+
+
+def main(model_path, output_path):
+    # Load the original U-Net state dict.
+    unet_state_dict = load_file(model_path)
+
+    # Initialize the Kandinsky3UNet model.
+    config = {}
+
+    # Convert the state dict.
+    converted_state_dict = convert_state_dict(unet_state_dict)
+
+    unet = Kandinsky3UNet(config)
+    unet.load_state_dict(converted_state_dict)
+
+    unet.save_pretrained(output_path)
+    print(f"Converted model saved to {output_path}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Convert U-Net PyTorch model to Kandinsky3UNet format")
+    parser.add_argument("--model_path", type=str, required=True, help="Path to the original U-Net PyTorch model")
+    parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")
+
+    args = parser.parse_args()
+    main(args.model_path, args.output_path)
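Going by the argparse interface above, a conversion run looks like this (both paths are placeholders):

python scripts/convert_kandinsky3_unet.py \
    --model_path /path/to/kandinsky3_unet.safetensors \
    --output_path ./kandinsky3_unet

Since the weights are loaded with safetensors.torch.load_file, --model_path must point at a .safetensors file.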

src/diffusers/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -79,6 +79,7 @@
         "AutoencoderTiny",
         "ConsistencyDecoderVAE",
         "ControlNetModel",
+        "Kandinsky3UNet",
         "ModelMixin",
         "MotionAdapter",
         "MultiAdapter",
@@ -214,6 +215,8 @@
         "IFPipeline",
         "IFSuperResolutionPipeline",
         "ImageTextPipelineOutput",
+        "Kandinsky3Img2ImgPipeline",
+        "Kandinsky3Pipeline",
         "KandinskyCombinedPipeline",
         "KandinskyImg2ImgCombinedPipeline",
         "KandinskyImg2ImgPipeline",
@@ -446,6 +449,7 @@
         AutoencoderTiny,
         ConsistencyDecoderVAE,
         ControlNetModel,
+        Kandinsky3UNet,
         ModelMixin,
         MotionAdapter,
         MultiAdapter,
@@ -560,6 +564,8 @@
         IFPipeline,
         IFSuperResolutionPipeline,
         ImageTextPipelineOutput,
+        Kandinsky3Img2ImgPipeline,
+        Kandinsky3Pipeline,
         KandinskyCombinedPipeline,
         KandinskyImg2ImgCombinedPipeline,
         KandinskyImg2ImgPipeline,
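With these entries in the import structure, the new classes become importable from the top-level diffusers namespace:

from diffusers import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline, Kandinsky3UNet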

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,7 @@
     _import_structure["unet_2d"] = ["UNet2DModel"]
     _import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
     _import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
+    _import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
     _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
     _import_structure["vq_model"] = ["VQModel"]

@@ -63,6 +64,7 @@
     from .unet_2d import UNet2DModel
     from .unet_2d_condition import UNet2DConditionModel
     from .unet_3d_condition import UNet3DConditionModel
+    from .unet_kandi3 import Kandinsky3UNet
     from .unet_motion_model import MotionAdapter, UNetMotionModel
     from .vq_model import VQModel

src/diffusers/models/attention_processor.py

Lines changed: 40 additions & 1 deletion
@@ -16,7 +16,7 @@

 import torch
 import torch.nn.functional as F
-from torch import nn
+from torch import einsum, nn

 from ..utils import USE_PEFT_BACKEND, deprecate, logging
 from ..utils.import_utils import is_xformers_available
@@ -2219,6 +2219,44 @@ def __call__(
         return hidden_states


+# TODO(Yiyi): This class should not exist; we can replace it with a normal attention processor,
+# and that way torch.compile and co. will work as well.
+class Kandi3AttnProcessor:
+    r"""
+    Default Kandinsky 3 processor for performing attention-related computations.
+    """
+
+    @staticmethod
+    def _reshape(hid_states, h):
+        # Split the feature dimension into heads: (b, n, h * d) -> (b, h, n, d).
+        b, n, f = hid_states.shape
+        d = f // h
+        return hid_states.unsqueeze(-1).reshape(b, n, h, d).permute(0, 2, 1, 3)
+
+    def __call__(
+        self,
+        attn,
+        x,
+        context,
+        context_mask=None,
+    ):
+        query = self._reshape(attn.to_q(x), h=attn.num_heads)
+        key = self._reshape(attn.to_k(context), h=attn.num_heads)
+        value = self._reshape(attn.to_v(context), h=attn.num_heads)
+
+        attention_matrix = einsum("b h i d, b h j d -> b h i j", query, key)
+
+        if context_mask is not None:
+            max_neg_value = -torch.finfo(attention_matrix.dtype).max
+            context_mask = context_mask.unsqueeze(1).unsqueeze(1)
+            attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value)
+        attention_matrix = (attention_matrix * attn.scale).softmax(dim=-1)
+
+        out = einsum("b h i j, b h j d -> b h i d", attention_matrix, value)
+        out = out.permute(0, 2, 1, 3).reshape(out.shape[0], out.shape[2], -1)
+        out = attn.to_out[0](out)
+        return out
+
+
 LORA_ATTENTION_PROCESSORS = (
     LoRAAttnProcessor,
     LoRAAttnProcessor2_0,
@@ -2244,6 +2282,7 @@ def __call__(
     LoRAXFormersAttnProcessor,
     IPAdapterAttnProcessor,
     IPAdapterAttnProcessor2_0,
+    Kandi3AttnProcessor,
 )

 AttentionProcessor = Union[
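For intuition, here is a self-contained sketch of the masked einsum attention that the new processor performs, run on dummy tensors (all shapes and the head count are arbitrary illustration, not taken from the model config):

import torch
from torch import einsum

b, heads, n_q, n_kv, d = 2, 4, 16, 77, 64
scale = d ** -0.5

query = torch.randn(b, heads, n_q, d)
key = torch.randn(b, heads, n_kv, d)
value = torch.randn(b, heads, n_kv, d)
context_mask = torch.ones(b, n_kv, dtype=torch.bool)  # e.g. a text attention mask

# Pairwise query/key dot products per head: (b, h, n_q, n_kv).
attention_matrix = einsum("b h i d, b h j d -> b h i j", query, key)

# Masked positions get the most negative representable value,
# so they vanish after the softmax (as in the processor above).
max_neg_value = -torch.finfo(attention_matrix.dtype).max
attention_matrix = attention_matrix.masked_fill(
    ~context_mask[:, None, None, :], max_neg_value
)
attention_matrix = (attention_matrix * scale).softmax(dim=-1)

# Weighted sum of values, then merge heads back: (b, n_q, h * d).
out = einsum("b h i j, b h j d -> b h i d", attention_matrix, value)
out = out.permute(0, 2, 1, 3).reshape(b, n_q, heads * d)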
