Merged
Changes from 5 commits
3 changes: 2 additions & 1 deletion src/diffusers/models/attention.py
@@ -113,14 +113,15 @@ def __init__(
d_head: int,
depth: int = 1,
dropout: float = 0.0,
num_groups: int = 32,
context_dim: Optional[int] = None,
):
super().__init__()
self.n_heads = n_heads
self.d_head = d_head
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)

self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

2 changes: 2 additions & 0 deletions src/diffusers/models/unet_2d.py
@@ -114,6 +114,7 @@ def __init__(
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding,
)
@@ -151,6 +152,7 @@ def __init__(
add_upsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=attention_head_dim,
)
self.up_blocks.append(up_block)
2 changes: 2 additions & 0 deletions src/diffusers/models/unet_2d_condition.py
@@ -114,6 +114,7 @@ def __init__(
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding,
@@ -153,6 +154,7 @@ def __init__(
add_upsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim,
)
15 changes: 15 additions & 0 deletions src/diffusers/models/unet_blocks.py
@@ -31,6 +31,7 @@ def get_down_block(
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
resnet_groups=None,
cross_attention_dim=None,
downsample_padding=None,
):
@@ -44,6 +45,7 @@ def get_down_block(
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
)
elif down_block_type == "AttnDownBlock2D":
@@ -55,6 +57,7 @@ def get_down_block(
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
@@ -69,6 +72,7 @@ def get_down_block(
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
@@ -104,6 +108,7 @@ def get_down_block(
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
)

@@ -119,6 +124,7 @@ def get_up_block(
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
resnet_groups=None,
cross_attention_dim=None,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
@@ -132,6 +138,7 @@ def get_up_block(
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
)
elif up_block_type == "CrossAttnUpBlock2D":
if cross_attention_dim is None:
@@ -145,6 +152,7 @@ def get_up_block(
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
)
@@ -158,6 +166,7 @@ def get_up_block(
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
attn_num_head_channels=attn_num_head_channels,
)
elif up_block_type == "SkipUpBlock2D":
@@ -191,6 +200,7 @@ def get_up_block(
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
)
raise ValueError(f"{up_block_type} does not exist.")

@@ -323,6 +333,7 @@ def __init__(
in_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
num_groups=resnet_groups,
)
)
resnets.append(
@@ -414,6 +425,7 @@ def __init__(
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
num_groups=resnet_groups,
)
)

@@ -498,6 +510,7 @@ def __init__(
out_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
num_groups=resnet_groups,
)
)
self.attentions = nn.ModuleList(attentions)
@@ -966,6 +979,7 @@ def __init__(
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
num_groups=resnet_groups,
)
)

@@ -1047,6 +1061,7 @@ def __init__(
out_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
num_groups=resnet_groups,
)
)
self.attentions = nn.ModuleList(attentions)
20 changes: 14 additions & 6 deletions src/diffusers/models/vae.py
@@ -59,6 +59,7 @@ def __init__(
down_block_types=("DownEncoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
double_z=True,
):
@@ -86,6 +87,7 @@ def __init__(
resnet_eps=1e-6,
downsample_padding=0,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=None,
temb_channels=None,
)
@@ -99,13 +101,12 @@ def __init__(
output_scale_factor=1,
resnet_time_scale_shift="default",
attn_num_head_channels=None,
resnet_groups=32,
resnet_groups=norm_num_groups,
temb_channels=None,
)

# out
num_groups_out = 32
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6)
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()

conv_out_channels = 2 * out_channels if double_z else out_channels
@@ -138,6 +139,7 @@ def __init__(
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
):
super().__init__()
@@ -156,7 +158,7 @@ def __init__(
output_scale_factor=1,
resnet_time_scale_shift="default",
attn_num_head_channels=None,
resnet_groups=32,
resnet_groups=norm_num_groups,
temb_channels=None,
)

@@ -178,15 +180,15 @@ def __init__(
add_upsample=not is_final_block,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=None,
temb_channels=None,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel

# out
num_groups_out = 32
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=1e-6)
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

@@ -405,6 +407,7 @@ def __init__(
latent_channels: int = 3,
sample_size: int = 32,
num_vq_embeddings: int = 256,
norm_num_groups: int = 32,
):
super().__init__()

@@ -416,6 +419,7 @@ def __init__(
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=False,
)

@@ -433,6 +437,7 @@ def __init__(
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
)

def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
@@ -509,6 +514,7 @@ def __init__(
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
norm_num_groups: int = 32,
sample_size: int = 32,
):
super().__init__()
@@ -521,6 +527,7 @@ def __init__(
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=True,
)

@@ -531,6 +538,7 @@ def __init__(
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
)

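Taken together, the model-side changes above forward a single `norm_num_groups` value from the top-level configs (`UNet2DModel`, `UNet2DConditionModel`, `AutoencoderKL`, `VQModel`) into the blocks and attention layers, instead of the previously hard-coded 32 groups. Below is a minimal usage sketch of what this enables, mirroring the init args used in the updated tests further down; treat the exact constructor signature as an assumption based on this diff rather than a documented API:

```python
import torch
from diffusers import UNet2DModel

# Channel widths (16, 32) are not both divisible by the old hard-coded 32 groups,
# so a matching norm_num_groups is passed explicitly (it must divide every width).
model = UNet2DModel(
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(16, 32),
    norm_num_groups=16,
    down_block_types=("DownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "UpBlock2D"),
)

sample = torch.randn(1, 4, 32, 32)   # (batch, channels, height, width)
timestep = torch.tensor([1])
output = model(sample, timestep).sample
print(output.shape)  # expected: torch.Size([1, 4, 32, 32])
```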
20 changes: 20 additions & 0 deletions tests/test_modeling_common.py
@@ -99,6 +99,26 @@ def test_output(self):
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

def test_forward_with_norm_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

init_dict["norm_num_groups"] = 16
init_dict["block_out_channels"] = (16, 32)

model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()

with torch.no_grad():
output = model(**inputs_dict)

if isinstance(output, dict):
output = output.sample

self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

def test_forward_signature(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()

7 changes: 6 additions & 1 deletion tests/test_models_unet.py
@@ -114,7 +114,8 @@ def prepare_init_args_and_inputs_for_common(self):
"in_channels": 4,
"out_channels": 4,
"layers_per_block": 2,
"block_out_channels": (32, 64),
"block_out_channels": (16, 32),
"norm_num_groups": 16,
"attention_head_dim": 32,
"down_block_types": ("DownBlock2D", "DownBlock2D"),
"up_block_types": ("UpBlock2D", "UpBlock2D"),
@@ -291,3 +292,7 @@ def test_output_pretrained_ve_large(self):
# fmt: on

self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))

def test_forward_with_norm_groups(self):
# not required for this model
pass

Review thread on test_forward_with_norm_groups:

Contributor:
Why?

Contributor Author:
It seems to adjust the norm groups inside the blocks, cf. https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_blocks.py#L780
So I'm not sure it can work with any value because of the division by 4.
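The concern in the thread above comes down to GroupNorm's divisibility requirement: whatever group count ends up being used must evenly divide the number of channels being normalized, and the linked block derives a default group count via a division by 4, so an arbitrary `norm_num_groups` (or channel width) can violate that constraint. A minimal, illustrative sketch of the constraint, not code from this PR:

```python
import torch

# 16 groups over 32 channels: valid, since 32 % 16 == 0.
ok = torch.nn.GroupNorm(num_groups=16, num_channels=32)
print(ok(torch.randn(1, 32, 8, 8)).shape)  # torch.Size([1, 32, 8, 8])

# A group count that does not divide the channel width is rejected outright,
# which is why a block that derives its own group count (e.g. a default along
# the lines of min(channels // 4, 32) in the linked code) cannot be expected
# to work with every externally supplied value.
try:
    torch.nn.GroupNorm(num_groups=24, num_channels=32)
except ValueError as err:
    print(err)  # "num_channels must be divisible by num_groups"
```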