Commit d144c46

[UNet2DConditionModel, UNet2DModel] pass norm_num_groups to all the blocks (#442)
* pass norm_num_groups to unet blocks and attention
* fix UNet2DConditionModel
* add norm_num_groups arg in vae
* add tests
* remove comment
* Apply suggestions from code review
1 parent b34be03 commit d144c46

7 files changed: +59 -7 lines changed
src/diffusers/models/attention.py
Lines changed: 2 additions & 1 deletion

@@ -113,14 +113,15 @@ def __init__(
         d_head: int,
         depth: int = 1,
         dropout: float = 0.0,
+        num_groups: int = 32,
         context_dim: Optional[int] = None,
     ):
         super().__init__()
         self.n_heads = n_heads
         self.d_head = d_head
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
-        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+        self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)

         self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

src/diffusers/models/unet_2d.py
Lines changed: 2 additions & 0 deletions

@@ -114,6 +114,7 @@ def __init__(
                 add_downsample=not is_final_block,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
                 attn_num_head_channels=attention_head_dim,
                 downsample_padding=downsample_padding,
             )
@@ -151,6 +152,7 @@ def __init__(
                 add_upsample=not is_final_block,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
                 attn_num_head_channels=attention_head_dim,
             )
             self.up_blocks.append(up_block)

src/diffusers/models/unet_2d_condition.py
Lines changed: 2 additions & 0 deletions

@@ -114,6 +114,7 @@ def __init__(
                 add_downsample=not is_final_block,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
                 attn_num_head_channels=attention_head_dim,
                 downsample_padding=downsample_padding,
@@ -153,6 +154,7 @@ def __init__(
                 add_upsample=not is_final_block,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
                 attn_num_head_channels=attention_head_dim,
             )
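Analogously, a hedged sketch (not from the diff; argument names follow the diffusers API around this commit) of a scaled-down UNet2DConditionModel, where the group count now also reaches the cross-attention blocks:

import torch
from diffusers import UNet2DConditionModel

# Hypothetical small config; block widths must be divisible by norm_num_groups.
model = UNet2DConditionModel(
    sample_size=16,
    in_channels=4,
    out_channels=4,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
    norm_num_groups=16,
)

sample = torch.randn(1, 4, 16, 16)
encoder_hidden_states = torch.randn(1, 77, 32)  # (batch, sequence, cross_attention_dim)
with torch.no_grad():
    output = model(sample, torch.tensor([1]), encoder_hidden_states).sample
print(output.shape)  # torch.Size([1, 4, 16, 16])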

src/diffusers/models/unet_blocks.py
Lines changed: 15 additions & 0 deletions

@@ -31,6 +31,7 @@ def get_down_block(
     resnet_eps,
     resnet_act_fn,
     attn_num_head_channels,
+    resnet_groups=None,
     cross_attention_dim=None,
     downsample_padding=None,
 ):
@@ -44,6 +45,7 @@ def get_down_block(
             add_downsample=add_downsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
         )
     elif down_block_type == "AttnDownBlock2D":
@@ -55,6 +57,7 @@ def get_down_block(
             add_downsample=add_downsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
             attn_num_head_channels=attn_num_head_channels,
         )
@@ -69,6 +72,7 @@ def get_down_block(
             add_downsample=add_downsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
@@ -104,6 +108,7 @@ def get_down_block(
             add_downsample=add_downsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
         )

@@ -119,6 +124,7 @@ def get_up_block(
     resnet_eps,
     resnet_act_fn,
     attn_num_head_channels,
+    resnet_groups=None,
     cross_attention_dim=None,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
@@ -132,6 +138,7 @@ def get_up_block(
             add_upsample=add_upsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
         )
     elif up_block_type == "CrossAttnUpBlock2D":
         if cross_attention_dim is None:
@@ -145,6 +152,7 @@ def get_up_block(
             add_upsample=add_upsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
         )
@@ -158,6 +166,7 @@ def get_up_block(
             add_upsample=add_upsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
             attn_num_head_channels=attn_num_head_channels,
         )
     elif up_block_type == "SkipUpBlock2D":
@@ -191,6 +200,7 @@ def get_up_block(
             add_upsample=add_upsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
         )
     raise ValueError(f"{up_block_type} does not exist.")

@@ -323,6 +333,7 @@ def __init__(
                     in_channels // attn_num_head_channels,
                     depth=1,
                     context_dim=cross_attention_dim,
+                    num_groups=resnet_groups,
                 )
             )
         resnets.append(
@@ -414,6 +425,7 @@ def __init__(
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
                     eps=resnet_eps,
+                    num_groups=resnet_groups,
                 )
             )

@@ -498,6 +510,7 @@ def __init__(
                     out_channels // attn_num_head_channels,
                     depth=1,
                     context_dim=cross_attention_dim,
+                    num_groups=resnet_groups,
                 )
             )
         self.attentions = nn.ModuleList(attentions)
@@ -966,6 +979,7 @@ def __init__(
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
                     eps=resnet_eps,
+                    num_groups=resnet_groups,
                 )
             )

@@ -1047,6 +1061,7 @@ def __init__(
                     out_channels // attn_num_head_channels,
                     depth=1,
                     context_dim=cross_attention_dim,
+                    num_groups=resnet_groups,
                 )
             )
         self.attentions = nn.ModuleList(attentions)

src/diffusers/models/vae.py
Lines changed: 14 additions & 6 deletions

@@ -59,6 +59,7 @@ def __init__(
         down_block_types=("DownEncoderBlock2D",),
         block_out_channels=(64,),
         layers_per_block=2,
+        norm_num_groups=32,
         act_fn="silu",
         double_z=True,
     ):
@@ -86,6 +87,7 @@ def __init__(
                 resnet_eps=1e-6,
                 downsample_padding=0,
                 resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
                 attn_num_head_channels=None,
                 temb_channels=None,
             )
@@ -99,13 +101,12 @@ def __init__(
             output_scale_factor=1,
             resnet_time_scale_shift="default",
             attn_num_head_channels=None,
-            resnet_groups=32,
+            resnet_groups=norm_num_groups,
             temb_channels=None,
         )

         # out
-        num_groups_out = 32
-        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6)
+        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
         self.conv_act = nn.SiLU()

         conv_out_channels = 2 * out_channels if double_z else out_channels
@@ -138,6 +139,7 @@ def __init__(
         up_block_types=("UpDecoderBlock2D",),
         block_out_channels=(64,),
         layers_per_block=2,
+        norm_num_groups=32,
         act_fn="silu",
     ):
         super().__init__()
@@ -156,7 +158,7 @@ def __init__(
             output_scale_factor=1,
             resnet_time_scale_shift="default",
             attn_num_head_channels=None,
-            resnet_groups=32,
+            resnet_groups=norm_num_groups,
             temb_channels=None,
         )

@@ -178,15 +180,15 @@ def __init__(
                 add_upsample=not is_final_block,
                 resnet_eps=1e-6,
                 resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
                 attn_num_head_channels=None,
                 temb_channels=None,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel

         # out
-        num_groups_out = 32
-        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=1e-6)
+        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

@@ -405,6 +407,7 @@ def __init__(
         latent_channels: int = 3,
         sample_size: int = 32,
         num_vq_embeddings: int = 256,
+        norm_num_groups: int = 32,
     ):
         super().__init__()

@@ -416,6 +419,7 @@ def __init__(
             block_out_channels=block_out_channels,
             layers_per_block=layers_per_block,
             act_fn=act_fn,
+            norm_num_groups=norm_num_groups,
             double_z=False,
         )

@@ -433,6 +437,7 @@ def __init__(
             block_out_channels=block_out_channels,
             layers_per_block=layers_per_block,
             act_fn=act_fn,
+            norm_num_groups=norm_num_groups,
         )

     def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
@@ -509,6 +514,7 @@ def __init__(
         layers_per_block: int = 1,
         act_fn: str = "silu",
         latent_channels: int = 4,
+        norm_num_groups: int = 32,
         sample_size: int = 32,
     ):
         super().__init__()
@@ -521,6 +527,7 @@ def __init__(
             block_out_channels=block_out_channels,
             layers_per_block=layers_per_block,
             act_fn=act_fn,
+            norm_num_groups=norm_num_groups,
             double_z=True,
         )

@@ -531,6 +538,7 @@ def __init__(
             up_block_types=up_block_types,
             block_out_channels=block_out_channels,
             layers_per_block=layers_per_block,
+            norm_num_groups=norm_num_groups,
             act_fn=act_fn,
         )
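On the VAE side, a hedged usage sketch (not part of the diff; the encode/decode output field names are assumed from the diffusers API around this commit). AutoencoderKL and VQModel now accept norm_num_groups and thread it through the Encoder, the Decoder, and the output GroupNorm layers:

import torch
from diffusers import AutoencoderKL

# Hypothetical small config; block_out_channels must be divisible by norm_num_groups.
vae = AutoencoderKL(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D",),
    up_block_types=("UpDecoderBlock2D",),
    block_out_channels=(32,),
    latent_channels=4,
    norm_num_groups=16,
    sample_size=32,
)

x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    posterior = vae.encode(x).latent_dist            # assumed output field name
    reconstruction = vae.decode(posterior.sample()).sample
print(reconstruction.shape)  # torch.Size([1, 3, 32, 32])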

tests/test_modeling_common.py
Lines changed: 20 additions & 0 deletions

@@ -99,6 +99,26 @@ def test_output(self):
         expected_shape = inputs_dict["sample"].shape
         self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

+    def test_forward_with_norm_groups(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["norm_num_groups"] = 16
+        init_dict["block_out_channels"] = (16, 32)
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+            if isinstance(output, dict):
+                output = output.sample
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["sample"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
     def test_forward_signature(self):
         init_dict, _ = self.prepare_init_args_and_inputs_for_common()
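The new common test pairs norm_num_groups=16 with block_out_channels=(16, 32) because torch.nn.GroupNorm requires each channel count to be divisible by the group count. A standalone sketch of that constraint (not part of the diff):

import torch.nn as nn

# Divisible channel count: construction succeeds.
norm = nn.GroupNorm(num_groups=16, num_channels=32)  # 32 % 16 == 0

# Non-divisible channel count: GroupNorm raises at construction time.
try:
    nn.GroupNorm(num_groups=16, num_channels=24)  # 24 % 16 != 0
except ValueError as err:
    print(err)  # num_channels must be divisible by num_groups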

tests/test_models_unet.py
Lines changed: 4 additions & 0 deletions

@@ -293,3 +293,7 @@ def test_output_pretrained_ve_large(self):
         # fmt: on

         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
+
+    def test_forward_with_norm_groups(self):
+        # not required for this model
+        pass
