3 changes: 2 additions & 1 deletion src/diffusers/models/attention.py
@@ -113,14 +113,15 @@ def __init__(
d_head: int,
depth: int = 1,
dropout: float = 0.0,
+ num_groups: int = 32,
context_dim: Optional[int] = None,
):
super().__init__()
self.n_heads = n_heads
self.d_head = d_head
self.in_channels = in_channels
inner_dim = n_heads * d_head
- self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)

self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

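For reference, a minimal sketch of what the new `num_groups` argument controls, with illustrative shapes that are not taken from the PR: the group count of the input normalization is now configurable instead of being fixed at 32, and it has to divide the channel count.

```python
import torch
from torch import nn

in_channels, num_groups = 64, 16  # num_groups must evenly divide in_channels
norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)

x = torch.randn(2, in_channels, 8, 8)
print(norm(x).shape)  # torch.Size([2, 64, 8, 8]) -- GroupNorm keeps the shape unchanged
```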
2 changes: 2 additions & 0 deletions src/diffusers/models/unet_2d.py
@@ -114,6 +114,7 @@ def __init__(
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding,
)
@@ -151,6 +152,7 @@ def __init__(
add_upsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=attention_head_dim,
)
self.up_blocks.append(up_block)
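Since `UNet2DModel` already exposes `norm_num_groups`, forwarding it as `resnet_groups` lets the whole model be built with a non-default group count. A hedged usage sketch (the block types and sizes below are illustrative, not taken from this PR; every entry of `block_out_channels` has to be divisible by `norm_num_groups`):

```python
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(16, 32),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    norm_num_groups=16,  # previously the blocks always used 32 groups internally
)
```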
2 changes: 2 additions & 0 deletions src/diffusers/models/unet_2d_condition.py
@@ -114,6 +114,7 @@ def __init__(
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding,
@@ -153,6 +154,7 @@ def __init__(
add_upsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim,
)
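The same plumbing applies to the conditional UNet, including its cross-attention blocks. Another hedged construction sketch (all values are illustrative, not from this PR):

```python
from diffusers import UNet2DConditionModel

model = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(16, 32),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
    norm_num_groups=16,
)
```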
15 changes: 15 additions & 0 deletions src/diffusers/models/unet_blocks.py
@@ -31,6 +31,7 @@ def get_down_block(
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
resnet_groups=None,
cross_attention_dim=None,
downsample_padding=None,
):
@@ -44,6 +45,7 @@ def get_down_block(
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
)
elif down_block_type == "AttnDownBlock2D":
@@ -55,6 +57,7 @@ def get_down_block(
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
@@ -69,6 +72,7 @@ def get_down_block(
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
@@ -104,6 +108,7 @@ def get_down_block(
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
)

@@ -119,6 +124,7 @@ def get_up_block(
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
resnet_groups=None,
cross_attention_dim=None,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
@@ -132,6 +138,7 @@ def get_up_block(
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
)
elif up_block_type == "CrossAttnUpBlock2D":
if cross_attention_dim is None:
@@ -145,6 +152,7 @@ def get_up_block(
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
)
@@ -158,6 +166,7 @@ def get_up_block(
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
attn_num_head_channels=attn_num_head_channels,
)
elif up_block_type == "SkipUpBlock2D":
@@ -191,6 +200,7 @@ def get_up_block(
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
)
raise ValueError(f"{up_block_type} does not exist.")

@@ -323,6 +333,7 @@ def __init__(
in_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
num_groups=resnet_groups,
)
)
resnets.append(
@@ -414,6 +425,7 @@ def __init__(
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
num_groups=resnet_groups,
)
)

@@ -498,6 +510,7 @@ def __init__(
out_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
num_groups=resnet_groups,
)
)
self.attentions = nn.ModuleList(attentions)
@@ -966,6 +979,7 @@ def __init__(
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
num_groups=resnet_groups,
)
)

@@ -1047,6 +1061,7 @@ def __init__(
out_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
num_groups=resnet_groups,
)
)
self.attentions = nn.ModuleList(attentions)
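The block factories take `resnet_groups=None` so existing callers keep their current behavior; inside the blocks the group count then appears to fall back to a channel-derived value (this is the fallback the review comment at the bottom of the page links to). A sketch of that assumed logic, using a hypothetical helper name:

```python
def resolve_resnet_groups(resnet_groups, in_channels):
    # Assumed default, based on the unet_blocks.py line referenced in the review thread:
    # derive the group count from the channel count and cap it at 32.
    return resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

print(resolve_resnet_groups(None, 64))  # 16
print(resolve_resnet_groups(16, 64))    # 16 (an explicit value is passed through unchanged)
```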
20 changes: 14 additions & 6 deletions src/diffusers/models/vae.py
@@ -59,6 +59,7 @@ def __init__(
down_block_types=("DownEncoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
double_z=True,
):
@@ -86,6 +87,7 @@ def __init__(
resnet_eps=1e-6,
downsample_padding=0,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=None,
temb_channels=None,
)
@@ -99,13 +101,12 @@ def __init__(
output_scale_factor=1,
resnet_time_scale_shift="default",
attn_num_head_channels=None,
- resnet_groups=32,
+ resnet_groups=norm_num_groups,
temb_channels=None,
)

# out
- num_groups_out = 32
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6)
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()

conv_out_channels = 2 * out_channels if double_z else out_channels
@@ -138,6 +139,7 @@ def __init__(
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
):
super().__init__()
@@ -156,7 +158,7 @@ def __init__(
output_scale_factor=1,
resnet_time_scale_shift="default",
attn_num_head_channels=None,
- resnet_groups=32,
+ resnet_groups=norm_num_groups,
temb_channels=None,
)

@@ -178,15 +180,15 @@ def __init__(
add_upsample=not is_final_block,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
attn_num_head_channels=None,
temb_channels=None,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel

# out
- num_groups_out = 32
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=1e-6)
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

@@ -405,6 +407,7 @@ def __init__(
latent_channels: int = 3,
sample_size: int = 32,
num_vq_embeddings: int = 256,
norm_num_groups: int = 32,
):
super().__init__()

@@ -416,6 +419,7 @@ def __init__(
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=False,
)

@@ -433,6 +437,7 @@ def __init__(
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
)

def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
@@ -509,6 +514,7 @@ def __init__(
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
norm_num_groups: int = 32,
sample_size: int = 32,
):
super().__init__()
@@ -521,6 +527,7 @@ def __init__(
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=True,
)

@@ -531,6 +538,7 @@ def __init__(
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
)

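With the VAE changes, `norm_num_groups` reaches the encoder, the decoder, and both `conv_norm_out` layers, so the hard-coded 32 survives only as the default. A hedged construction sketch (sizes are illustrative, not from this PR; `norm_num_groups` has to divide every entry of `block_out_channels`):

```python
from diffusers import AutoencoderKL

vae = AutoencoderKL(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D",),
    up_block_types=("UpDecoderBlock2D",),
    block_out_channels=(32,),
    layers_per_block=1,
    latent_channels=4,
    sample_size=32,
    norm_num_groups=16,
)
```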
20 changes: 20 additions & 0 deletions tests/test_modeling_common.py
@@ -99,6 +99,26 @@ def test_output(self):
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

def test_forward_with_norm_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

init_dict["norm_num_groups"] = 16
init_dict["block_out_channels"] = (16, 32)

model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()

with torch.no_grad():
output = model(**inputs_dict)

if isinstance(output, dict):
output = output.sample

self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

def test_forward_signature(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()

4 changes: 4 additions & 0 deletions tests/test_models_unet.py
@@ -291,3 +291,7 @@ def test_output_pretrained_ve_large(self):
# fmt: on

self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))

def test_forward_with_norm_groups(self):
Contributor

Why?

Contributor Author

It seems the norm groups are adjusted inside the blocks, cf. https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_blocks.py#L780.
So I'm not sure it will work with any value, because of the division by 4.

# not required for this model
pass
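The caveat in the thread above comes down to the usual GroupNorm constraint: `num_channels` must be divisible by `num_groups`, so an arbitrary `norm_num_groups` only works when every block's channel count is a multiple of it. A quick illustration in plain PyTorch (not code from this PR):

```python
from torch import nn

nn.GroupNorm(num_groups=16, num_channels=32)  # fine: 32 % 16 == 0

try:
    nn.GroupNorm(num_groups=12, num_channels=32)  # invalid: 32 % 12 != 0
except ValueError as err:
    print(err)  # "num_channels must be divisible by num_groups"
```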