diff --git a/vit_pytorch/simple_flash_attn_vit_3d.py b/vit_pytorch/simple_flash_attn_vit_3d.py
index 8381c4a9..6f91ef54 100644
--- a/vit_pytorch/simple_flash_attn_vit_3d.py
+++ b/vit_pytorch/simple_flash_attn_vit_3d.py
@@ -1,17 +1,20 @@
 from packaging import version
-from collections import namedtuple
 
 import torch
 from torch import nn
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList
 
-from einops import rearrange
+from einops import einsum, rearrange
 from einops.layers.torch import Rearrange
 
 # constants
 
-Config = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
+Config = [
+    torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+    torch.nn.attention.SDPBackend.MATH,
+    torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+] # Flash Attention Config
 
 # helpers
 
@@ -44,7 +47,7 @@ def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
 # main class
 
 class Attend(Module):
-    def __init__(self, use_flash = False, config: Config = Config(True, True, True)):
+    def __init__(self, use_flash = False, config: list = Config):
         super().__init__()
         self.config = config
         self.use_flash = use_flash
@@ -53,7 +56,7 @@ def __init__(self, use_flash = False, config: Config = Config(True, True, True))
     def flash_attn(self, q, k, v):
         # flash attention - https://arxiv.org/abs/2205.14135
 
-        with torch.backends.cuda.sdp_kernel(**self.config._asdict()):
+        with torch.nn.attention.sdpa_kernel(self.config):
             out = F.scaled_dot_product_attention(q, k, v)
 
         return out
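
For reference, a minimal standalone sketch of the replacement API this patch moves to, assuming PyTorch >= 2.3 (where torch.nn.attention.sdpa_kernel accepts a list of SDPBackend members). The tensor shapes below are illustrative only and not taken from the patch.

    # illustrative sketch, not part of the patch
    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # (batch, heads, seq_len, dim_head) -- assumed shapes for the example
    q = k = v = torch.randn(1, 8, 64, 32)

    # allow SDPA to dispatch to any of the listed backends, mirroring the
    # Config list defined in the patch
    backends = [
        SDPBackend.FLASH_ATTENTION,
        SDPBackend.MATH,
        SDPBackend.EFFICIENT_ATTENTION,
    ]

    with sdpa_kernel(backends):
        out = F.scaled_dot_product_attention(q, k, v)  # same shape as q

Unlike the removed torch.backends.cuda.sdp_kernel(**config._asdict()) call, which enabled each backend through a boolean flag on a namedtuple, sdpa_kernel takes the allowed backends directly, so the FlashAttentionConfig namedtuple is no longer needed.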