Skip to content

Commit 62154bd

Browse files
[@cene555][Kandinsky 3.0] Add Kandinsky 3.0 (huggingface#5913)
* finalize * finalize * finalize * add slow test * add slow test * add slow test * Fix more * add slow test * fix more * fix more * fix more * fix more * fix more * fix more * fix more * fix more * fix more * Better * Fix more * Fix more * add slow test * Add auto pipelines * add slow test * Add all * add slow test * add slow test * add slow test * add slow test * add slow test * Apply suggestions from code review * add slow test * add slow test
1 parent f3aa940 commit 62154bd

File tree

11 files changed

+1651
-1
lines changed

11 files changed

+1651
-1
lines changed

__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
"AutoencoderTiny",
8080
"ConsistencyDecoderVAE",
8181
"ControlNetModel",
82+
"Kandinsky3UNet",
8283
"ModelMixin",
8384
"MotionAdapter",
8485
"MultiAdapter",
@@ -214,6 +215,8 @@
214215
"IFPipeline",
215216
"IFSuperResolutionPipeline",
216217
"ImageTextPipelineOutput",
218+
"Kandinsky3Img2ImgPipeline",
219+
"Kandinsky3Pipeline",
217220
"KandinskyCombinedPipeline",
218221
"KandinskyImg2ImgCombinedPipeline",
219222
"KandinskyImg2ImgPipeline",
@@ -446,6 +449,7 @@
446449
AutoencoderTiny,
447450
ConsistencyDecoderVAE,
448451
ControlNetModel,
452+
Kandinsky3UNet,
449453
ModelMixin,
450454
MotionAdapter,
451455
MultiAdapter,
@@ -560,6 +564,8 @@
560564
IFPipeline,
561565
IFSuperResolutionPipeline,
562566
ImageTextPipelineOutput,
567+
Kandinsky3Img2ImgPipeline,
568+
Kandinsky3Pipeline,
563569
KandinskyCombinedPipeline,
564570
KandinskyImg2ImgCombinedPipeline,
565571
KandinskyImg2ImgPipeline,

models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
_import_structure["unet_2d"] = ["UNet2DModel"]
3737
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
3838
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
39+
_import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
3940
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
4041
_import_structure["vq_model"] = ["VQModel"]
4142

@@ -63,6 +64,7 @@
6364
from .unet_2d import UNet2DModel
6465
from .unet_2d_condition import UNet2DConditionModel
6566
from .unet_3d_condition import UNet3DConditionModel
67+
from .unet_kandi3 import Kandinsky3UNet
6668
from .unet_motion_model import MotionAdapter, UNetMotionModel
6769
from .vq_model import VQModel
6870

models/attention_processor.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import torch
1818
import torch.nn.functional as F
19-
from torch import nn
19+
from torch import einsum, nn
2020

2121
from ..utils import USE_PEFT_BACKEND, deprecate, logging
2222
from ..utils.import_utils import is_xformers_available
@@ -2219,6 +2219,44 @@ def __call__(
22192219
return hidden_states
22202220

22212221

2222+
# TODO(Yiyi): This class should not exist, we can replace it with a normal attention processor I believe
# this way torch.compile and co. will work as well
class Kandi3AttnProcessor:
    r"""
    Default Kandinsky 3 processor for performing attention-related computations.
    """

    @staticmethod
    def _reshape(hid_states, h):
        # Split the last (feature) dimension into `h` heads:
        # (batch, tokens, h * head_dim) -> (batch, h, tokens, head_dim).
        batch, tokens, features = hid_states.shape
        head_dim = features // h
        return hid_states.unsqueeze(-1).reshape(batch, tokens, h, head_dim).permute(0, 2, 1, 3)

    def __call__(
        self,
        attn,
        x,
        context,
        context_mask=None,
    ):
        # Project inputs and split them into per-head tensors of shape (b, h, n, d).
        q = self._reshape(attn.to_q(x), h=attn.num_heads)
        k = self._reshape(attn.to_k(context), h=attn.num_heads)
        v = self._reshape(attn.to_v(context), h=attn.num_heads)

        # Raw (not yet scaled) query/key dot-product similarities.
        scores = einsum("b h i d, b h j d -> b h i j", q, k)

        if context_mask is not None:
            # Broadcast the (b, j) context mask over heads and query positions,
            # pushing masked-out key positions to the most negative finite value
            # so they contribute ~0 after the softmax.
            neg_inf = -torch.finfo(scores.dtype).max
            scores = scores.masked_fill(context_mask.unsqueeze(1).unsqueeze(1) == 0, neg_inf)

        # NOTE(review): the scale is applied after masking here — kept as-is to
        # match the original implementation exactly.
        probs = (scores * attn.scale).softmax(dim=-1)

        # Weighted sum of values, then merge heads back: (b, h, i, d) -> (b, i, h*d).
        hidden = einsum("b h i j, b h j d -> b h i d", probs, v)
        hidden = hidden.permute(0, 2, 1, 3).reshape(hidden.shape[0], hidden.shape[2], -1)
        # Only the linear projection of `to_out` is applied (no dropout module here).
        return attn.to_out[0](hidden)
2258+
2259+
22222260
LORA_ATTENTION_PROCESSORS = (
22232261
LoRAAttnProcessor,
22242262
LoRAAttnProcessor2_0,
@@ -2244,6 +2282,7 @@ def __call__(
22442282
LoRAXFormersAttnProcessor,
22452283
IPAdapterAttnProcessor,
22462284
IPAdapterAttnProcessor2_0,
2285+
Kandi3AttnProcessor,
22472286
)
22482287

22492288
AttentionProcessor = Union[

0 commit comments

Comments
 (0)