
Commit a5f35ee

entrpn, jfacevedo-google, sayakpaul, and a-r-r-o-w authored

add reshape to fix use_memory_efficient_attention in flax (#7918)

Co-authored-by: Juan Acevedo <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Aryan <[email protected]>
1 parent 6324340 commit a5f35ee

File tree

1 file changed: +1 −1 lines changed


src/diffusers/models/attention_flax.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -216,8 +216,8 @@ def __call__(self, hidden_states, context=None, deterministic=True):
             hidden_states = jax_memory_efficient_attention(
                 query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
             )
-
             hidden_states = hidden_states.transpose(1, 0, 2)
+            hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         else:
             # compute attentions
             if self.split_head_dim:
```
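For context, the added call merges the per-head batch dimension back into the feature dimension after the memory-efficient attention path, so the output shape matches the other attention branches. Below is a minimal NumPy sketch of the kind of reshape that `reshape_batch_dim_to_heads` performs; the standalone function signature and explicit `heads` argument are illustrative assumptions (in diffusers it is a module method that reads the head count from the layer's config), not the library's exact code.

```python
import numpy as np

def reshape_batch_dim_to_heads(tensor: np.ndarray, heads: int) -> np.ndarray:
    # Illustrative sketch: fold a (batch * heads, seq_len, head_dim) tensor
    # back into (batch, seq_len, heads * head_dim).
    batch_times_heads, seq_len, head_dim = tensor.shape
    batch = batch_times_heads // heads
    tensor = tensor.reshape(batch, heads, seq_len, head_dim)
    tensor = tensor.transpose(0, 2, 1, 3)  # -> (batch, seq_len, heads, head_dim)
    return tensor.reshape(batch, seq_len, heads * head_dim)

# Example: 2 samples, 8 heads, 16 tokens, head_dim 40
x = np.zeros((2 * 8, 16, 40))
print(reshape_batch_dim_to_heads(x, heads=8).shape)  # (2, 16, 320)
```

Without this merge, the memory-efficient branch would return a tensor still carrying the flattened batch-by-heads leading dimension, which downstream projection layers do not expect.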

0 commit comments