5 changes: 4 additions & 1 deletion library/flux_models.py
@@ -18,6 +18,7 @@
 from einops import rearrange
 from torch import Tensor, nn
 from torch.utils.checkpoint import checkpoint
+from torch.nn.attention import SDPBackend, sdpa_kernel
 
 from library import custom_offloading_utils

@@ -445,11 +446,13 @@ class ModelSpec:

 # region math
 
+kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
 
 def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
     q, k = apply_rope(q, k, pe)
 
-    x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+    with sdpa_kernel(kernels):
+        x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
     x = rearrange(x, "B H L D -> B L (H D)")
 
     return x
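
For context, a minimal standalone sketch of the pattern this diff adopts: torch.nn.attention.sdpa_kernel (PyTorch 2.3+) is a context manager that restricts which backends scaled_dot_product_attention may dispatch to. The tensor shapes, device selection, and the shorter backend list below are illustrative assumptions, not taken from the PR (SDPBackend.CUDNN_ATTENTION, used in the diff, additionally requires a recent CUDA build of PyTorch).

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Illustrative shapes: (batch, heads, sequence, head_dim).
q = torch.randn(2, 8, 128, 64, device=device)
k = torch.randn(2, 8, 128, 64, device=device)
v = torch.randn(2, 8, 128, 64, device=device)

# Inside the context, SDPA may only use the listed backends. MATH is the
# portable fallback, so this also runs on CPU; if no listed backend supports
# the inputs, the call raises instead of silently using an unlisted one.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([2, 8, 128, 64])

Without the context manager, PyTorch picks among all available backends on its own; pinning the candidate set, as the diff does with kernels, makes the eligible backends explicit and surfaces an error when none of them supports the inputs.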