@@ -31,6 +31,7 @@ def get_down_block(
     resnet_eps,
     resnet_act_fn,
     attn_num_head_channels,
+    resnet_groups=None,
     cross_attention_dim=None,
     downsample_padding=None,
 ):
@@ -44,6 +45,7 @@ def get_down_block(
             add_downsample=add_downsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
         )
     elif down_block_type == "AttnDownBlock2D":
@@ -55,6 +57,7 @@ def get_down_block(
             add_downsample=add_downsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
             attn_num_head_channels=attn_num_head_channels,
         )
@@ -69,6 +72,7 @@ def get_down_block(
             add_downsample=add_downsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
@@ -104,6 +108,7 @@ def get_down_block(
             add_downsample=add_downsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
         )

@@ -119,6 +124,7 @@ def get_up_block(
     resnet_eps,
     resnet_act_fn,
     attn_num_head_channels,
+    resnet_groups=None,
     cross_attention_dim=None,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
@@ -132,6 +138,7 @@ def get_up_block(
             add_upsample=add_upsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
         )
     elif up_block_type == "CrossAttnUpBlock2D":
         if cross_attention_dim is None:
@@ -145,6 +152,7 @@ def get_up_block(
             add_upsample=add_upsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
         )
@@ -158,6 +166,7 @@ def get_up_block(
             add_upsample=add_upsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
             attn_num_head_channels=attn_num_head_channels,
         )
     elif up_block_type == "SkipUpBlock2D":
@@ -191,6 +200,7 @@ def get_up_block(
             add_upsample=add_upsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
         )
     raise ValueError(f"{up_block_type} does not exist.")

@@ -323,6 +333,7 @@ def __init__(
                     in_channels // attn_num_head_channels,
                     depth=1,
                     context_dim=cross_attention_dim,
+                    num_groups=resnet_groups,
                 )
             )
             resnets.append(
@@ -414,6 +425,7 @@ def __init__(
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
                     eps=resnet_eps,
+                    num_groups=resnet_groups,
                 )
             )

@@ -498,6 +510,7 @@ def __init__(
                     out_channels // attn_num_head_channels,
                     depth=1,
                     context_dim=cross_attention_dim,
+                    num_groups=resnet_groups,
                 )
             )
         self.attentions = nn.ModuleList(attentions)
@@ -966,6 +979,7 @@ def __init__(
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
                     eps=resnet_eps,
+                    num_groups=resnet_groups,
                 )
             )

@@ -1047,6 +1061,7 @@ def __init__(
                     out_channels // attn_num_head_channels,
                     depth=1,
                     context_dim=cross_attention_dim,
+                    num_groups=resnet_groups,
                 )
             )
         self.attentions = nn.ModuleList(attentions)
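
Every hunk above follows the same pattern: an optional `resnet_groups` keyword is added to the block factories and forwarded, unchanged, to each block constructor and to the `num_groups` argument of the attention modules, so a caller can override the group-normalization group count in one place. A minimal, self-contained sketch of that pattern follows; `make_block` and `norm_num_groups` are illustrative names for this sketch, not the diffusers API.

```python
import torch
from torch import nn


def make_block(in_channels, out_channels, norm_num_groups=None):
    # Hypothetical factory mirroring the diff above: an optional group count is
    # threaded through to the GroupNorm layer, falling back to a default of 32
    # when the caller does not override it.
    num_groups = norm_num_groups if norm_num_groups is not None else 32
    return nn.Sequential(
        nn.GroupNorm(num_groups=num_groups, num_channels=in_channels),
        nn.SiLU(),
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
    )


# num_channels must be divisible by num_groups, which is why models with
# narrow channel widths need to override the default group count.
block = make_block(in_channels=64, out_channels=64, norm_num_groups=16)
print(block(torch.randn(1, 64, 8, 8)).shape)  # torch.Size([1, 64, 8, 8])
```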