@@ -249,6 +249,81 @@ def get_down_block(
249249 raise ValueError (f"{ down_block_type } does not exist." )
250250
251251
def get_mid_block(
    mid_block_type: Optional[str],
    temb_channels: int,
    in_channels: int,
    resnet_eps: float,
    resnet_act_fn: str,
    resnet_groups: int,
    output_scale_factor: float = 1.0,
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    mid_block_only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    attention_type: str = "default",
    resnet_skip_time_act: bool = False,
    cross_attention_norm: Optional[str] = None,
    attention_head_dim: Optional[int] = 1,
    dropout: float = 0.0,
):
    """Build the UNet mid block selected by `mid_block_type`.

    Each supported type forwards only a subset of the arguments to the
    corresponding block constructor; arguments a given block does not take
    are silently ignored for that type.

    Args:
        mid_block_type: One of `"UNetMidBlock2DCrossAttn"`,
            `"UNetMidBlock2DSimpleCrossAttn"`, `"UNetMidBlock2D"`, or `None`
            to request no mid block at all.
        temb_channels: Forwarded to every block type.
        in_channels: Forwarded to every block type.
        resnet_eps: Forwarded to every block type.
        resnet_act_fn: Forwarded to every block type.
        resnet_groups: Forwarded to every block type.
        output_scale_factor: Forwarded to every block type.
        transformer_layers_per_block: `UNetMidBlock2DCrossAttn` only.
        num_attention_heads: `UNetMidBlock2DCrossAttn` only.
        cross_attention_dim: Both cross-attention block types.
        dual_cross_attention: `UNetMidBlock2DCrossAttn` only.
        use_linear_projection: `UNetMidBlock2DCrossAttn` only.
        mid_block_only_cross_attention: `UNetMidBlock2DSimpleCrossAttn` only,
            passed as `only_cross_attention`.
        upcast_attention: `UNetMidBlock2DCrossAttn` only.
        resnet_time_scale_shift: Forwarded to every block type.
        attention_type: `UNetMidBlock2DCrossAttn` only.
        resnet_skip_time_act: `UNetMidBlock2DSimpleCrossAttn` only, passed as
            `skip_time_act`.
        cross_attention_norm: `UNetMidBlock2DSimpleCrossAttn` only.
        attention_head_dim: `UNetMidBlock2DSimpleCrossAttn` only.
        dropout: Forwarded to every block type.

    Returns:
        The constructed mid block, or `None` when `mid_block_type` is `None`.

    Raises:
        ValueError: If `mid_block_type` is not `None` and not one of the
            supported names.
    """
    # `None` is a valid request meaning "no mid block"; handle it up front so
    # the remaining checks only deal with string names.
    if mid_block_type is None:
        return None

    if mid_block_type == "UNetMidBlock2DCrossAttn":
        return UNetMidBlock2DCrossAttn(
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            temb_channels=temb_channels,
            dropout=dropout,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            resnet_groups=resnet_groups,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
            attention_type=attention_type,
        )

    if mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
        return UNetMidBlock2DSimpleCrossAttn(
            in_channels=in_channels,
            temb_channels=temb_channels,
            dropout=dropout,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            only_cross_attention=mid_block_only_cross_attention,
            cross_attention_norm=cross_attention_norm,
        )

    if mid_block_type == "UNetMidBlock2D":
        # `num_layers=0` and `add_attention=False` are fixed here rather than
        # taken from the caller.
        # NOTE(review): presumably this yields an attention-free mid block;
        # confirm against the UNetMidBlock2D constructor.
        return UNetMidBlock2D(
            in_channels=in_channels,
            temb_channels=temb_channels,
            dropout=dropout,
            num_layers=0,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            add_attention=False,
        )

    raise ValueError(f"unknown mid_block_type : {mid_block_type}")
325+
326+
252327def get_up_block (
253328 up_block_type : str ,
254329 num_layers : int ,
0 commit comments