Commit 84616b5

Fix CrossAttention._sliced_attention (#563)

* Fix CrossAttention._sliced_attention

Co-authored-by: ydshieh <[email protected]>

1 parent 8d36d5a commit 84616b5

File tree

1 file changed: +3, -1 lines changed

src/diffusers/models/attention.py

Lines changed: 3 additions & 1 deletion
@@ -249,13 +249,15 @@ def reshape_batch_dim_to_heads(self, tensor):
         return tensor
 
     def forward(self, hidden_states, context=None, mask=None):
-        batch_size, sequence_length, dim = hidden_states.shape
+        batch_size, sequence_length, _ = hidden_states.shape
 
         query = self.to_q(hidden_states)
         context = context if context is not None else hidden_states
         key = self.to_k(context)
         value = self.to_v(context)
 
+        dim = query.shape[-1]
+
         query = self.reshape_heads_to_batch_dim(query)
         key = self.reshape_heads_to_batch_dim(key)
         value = self.reshape_heads_to_batch_dim(value)
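
Why the one-line change matters: to_q projects hidden_states from its channel dimension (query_dim) to the attention's inner dimension (heads * dim_head), and _sliced_attention sizes its output from dim, so dim has to reflect the projected width. When the two dimensions differ, reading dim from hidden_states gave a mismatched buffer; reading it from the projected query keeps it consistent. Below is a minimal sketch of the difference, not the diffusers source; the hyperparameter values (query_dim, heads, dim_head) are illustrative assumptions, not taken from the commit.

import torch
from torch import nn

# Illustrative hyperparameters (assumed for this sketch only).
query_dim = 320           # channel dim of the incoming hidden_states
heads, dim_head = 8, 64   # projected inner dim = heads * dim_head = 512

to_q = nn.Linear(query_dim, heads * dim_head, bias=False)

hidden_states = torch.randn(2, 77, query_dim)  # (batch, sequence_length, query_dim)
query = to_q(hidden_states)                    # (batch, sequence_length, inner dim)

dim_before = hidden_states.shape[-1]  # old code: 320
dim_after = query.shape[-1]           # fixed code: 512

# Sliced attention accumulates the output head by head, slice by slice, so the
# per-head width it derives from dim must match the projected query/value width.
print(dim_before // heads)  # 40 -> wrong per-head width when inner dim != query_dim
print(dim_after // heads)   # 64 -> equals dim_head, matching each attention slice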
