@@ -879,6 +879,9 @@ def __call__(
         scale: float = 1.0,
     ) -> torch.Tensor:
         residual = hidden_states
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         batch_size, sequence_length, _ = hidden_states.shape
 
@@ -891,17 +894,17 @@ def __call__(
 
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        query = attn.to_q(hidden_states, scale=scale)
+        query = attn.to_q(hidden_states, *args)
         query = attn.head_to_batch_dim(query)
 
-        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, scale=scale)
-        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, scale=scale)
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
         encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
 
         if not attn.only_cross_attention:
-            key = attn.to_k(hidden_states, scale=scale)
-            value = attn.to_v(hidden_states, scale=scale)
+            key = attn.to_k(hidden_states, *args)
+            value = attn.to_v(hidden_states, *args)
             key = attn.head_to_batch_dim(key)
             value = attn.head_to_batch_dim(value)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
@@ -915,7 +918,7 @@ def __call__(
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, *args)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
 
@@ -946,6 +949,9 @@ def __call__(
         scale: float = 1.0,
     ) -> torch.Tensor:
         residual = hidden_states
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         batch_size, sequence_length, _ = hidden_states.shape
 
@@ -958,7 +964,7 @@ def __call__(
 
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        query = attn.to_q(hidden_states, scale=scale)
+        query = attn.to_q(hidden_states, *args)
         query = attn.head_to_batch_dim(query, out_dim=4)
 
         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
@@ -967,8 +973,8 @@ def __call__(
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
 
         if not attn.only_cross_attention:
-            key = attn.to_k(hidden_states, scale=scale)
-            value = attn.to_v(hidden_states, scale=scale)
+            key = attn.to_k(hidden_states, *args)
+            value = attn.to_v(hidden_states, *args)
             key = attn.head_to_batch_dim(key, out_dim=4)
             value = attn.head_to_batch_dim(value, out_dim=4)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
@@ -985,7 +991,7 @@ def __call__(
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, *args)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
 
@@ -1177,6 +1183,8 @@ def __call__(
     ) -> torch.FloatTensor:
         residual = hidden_states
 
+        args = () if USE_PEFT_BACKEND else (scale,)
+
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
 
@@ -1207,12 +1215,8 @@ def __call__(
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        key = (
-            attn.to_k(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_k(encoder_hidden_states)
-        )
-        value = (
-            attn.to_v(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_v(encoder_hidden_states)
-        )
+        key = attn.to_k(encoder_hidden_states, *args)
+        value = attn.to_v(encoder_hidden_states, *args)
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -1232,9 +1236,7 @@ def __call__(
         hidden_states = hidden_states.to(query.dtype)
 
         # linear proj
-        hidden_states = (
-            attn.to_out[0](hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_out[0](hidden_states)
-        )
+        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
         hidden_states = attn.to_out[1](hidden_states)
 