@@ -560,7 +560,8 @@ def forward(
560560 hidden_states ,
561561 encoder_hidden_states = encoder_hidden_states ,
562562 cross_attention_kwargs = cross_attention_kwargs ,
563- ).sample
563+ return_dict = False ,
564+ )[0 ]
564565 hidden_states = resnet (hidden_states , temb )
565566
566567 return hidden_states
@@ -868,15 +869,16 @@ def custom_forward(*inputs):
868869 hidden_states ,
869870 encoder_hidden_states = encoder_hidden_states ,
870871 cross_attention_kwargs = cross_attention_kwargs ,
871- ).sample
872+ return_dict = False ,
873+ )[0 ]
872874
873- output_states += (hidden_states ,)
875+ output_states = output_states + (hidden_states ,)
874876
875877 if self .downsamplers is not None :
876878 for downsampler in self .downsamplers :
877879 hidden_states = downsampler (hidden_states )
878880
879- output_states += (hidden_states ,)
881+ output_states = output_states + (hidden_states ,)
880882
881883 return hidden_states , output_states
882884
@@ -949,13 +951,13 @@ def custom_forward(*inputs):
949951 else :
950952 hidden_states = resnet (hidden_states , temb )
951953
952- output_states += (hidden_states ,)
954+ output_states = output_states + (hidden_states ,)
953955
954956 if self .downsamplers is not None :
955957 for downsampler in self .downsamplers :
956958 hidden_states = downsampler (hidden_states )
957959
958- output_states += (hidden_states ,)
960+ output_states = output_states + (hidden_states ,)
959961
960962 return hidden_states , output_states
961963
@@ -1342,13 +1344,13 @@ def custom_forward(*inputs):
13421344 else :
13431345 hidden_states = resnet (hidden_states , temb )
13441346
1345- output_states += (hidden_states ,)
1347+ output_states = output_states + (hidden_states ,)
13461348
13471349 if self .downsamplers is not None :
13481350 for downsampler in self .downsamplers :
13491351 hidden_states = downsampler (hidden_states , temb )
13501352
1351- output_states += (hidden_states ,)
1353+ output_states = output_states + (hidden_states ,)
13521354
13531355 return hidden_states , output_states
13541356
@@ -1466,13 +1468,13 @@ def forward(
14661468 ** cross_attention_kwargs ,
14671469 )
14681470
1469- output_states += (hidden_states ,)
1471+ output_states = output_states + (hidden_states ,)
14701472
14711473 if self .downsamplers is not None :
14721474 for downsampler in self .downsamplers :
14731475 hidden_states = downsampler (hidden_states , temb )
14741476
1475- output_states += (hidden_states ,)
1477+ output_states = output_states + (hidden_states ,)
14761478
14771479 return hidden_states , output_states
14781480
@@ -1859,7 +1861,8 @@ def custom_forward(*inputs):
18591861 hidden_states ,
18601862 encoder_hidden_states = encoder_hidden_states ,
18611863 cross_attention_kwargs = cross_attention_kwargs ,
1862- ).sample
1864+ return_dict = False ,
1865+ )[0 ]
18631866
18641867 if self .upsamplers is not None :
18651868 for upsampler in self .upsamplers :
0 commit comments