@@ -382,6 +382,20 @@ def forward(
382382 If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is
383383 returned, otherwise a `tuple` is returned where the first element is the sample tensor.
384384 """
385+ # By default samples have to be AT least a multiple of the overall upsampling factor.
386+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
387+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
388+ # on the fly if necessary.
389+ default_overall_up_factor = 2 ** self .num_upsamplers
390+
391+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
392+ forward_upsample_size = False
393+ upsample_size = None
394+
395+ if any (s % default_overall_up_factor != 0 for s in sample .shape [- 2 :]):
396+ logger .info ("Forward upsample size to force interpolation output size." )
397+ forward_upsample_size = True
398+
385399 # 1. time
386400 timesteps = timestep
387401 if not torch .is_tensor (timesteps ):
@@ -457,22 +471,31 @@ def forward(
457471
458472 # 5. up
459473 for i , upsample_block in enumerate (self .up_blocks ):
474+ is_final_block = i == len (self .up_blocks ) - 1
475+
460476 res_samples = down_block_res_samples [- len (upsample_block .resnets ) :]
461477 down_block_res_samples = down_block_res_samples [: - len (upsample_block .resnets )]
462478
479+ # if we have not reached the final block and need to forward the
480+ # upsample size, we do it here
481+ if not is_final_block and forward_upsample_size :
482+ upsample_size = down_block_res_samples [- 1 ].shape [2 :]
483+
463484 if hasattr (upsample_block , "has_cross_attention" ) and upsample_block .has_cross_attention :
464485 sample = upsample_block (
465486 hidden_states = sample ,
466487 temb = emb ,
467488 res_hidden_states_tuple = res_samples ,
468489 encoder_hidden_states = encoder_hidden_states ,
490+ upsample_size = upsample_size ,
469491 image_only_indicator = image_only_indicator ,
470492 )
471493 else :
472494 sample = upsample_block (
473495 hidden_states = sample ,
474496 temb = emb ,
475497 res_hidden_states_tuple = res_samples ,
498+ upsample_size = upsample_size ,
476499 image_only_indicator = image_only_indicator ,
477500 )
478501
0 commit comments