Commit 4d81a5c

commit wip
Signed-off-by: Lucas Wilkinson <[email protected]>
Parent: baa8525

File tree: 1 file changed (+145, -0 lines)
@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional

import torch

from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.utils import is_flash_attn_mla_supported
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                   MLACommonDecodeMetadata,
                                                   MLACommonImpl,
                                                   MLACommonMetadata,
                                                   MLACommonMetadataBuilder)
from vllm.vllm_flash_attn import flash_attn_varlen_func

if TYPE_CHECKING:
    pass

logger = init_logger(__name__)


class FlashAttnMLABackend(MLACommonBackend):

    @staticmethod
    def get_name() -> str:
        return "FLASHATTN_MLA_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> type["FlashAttnMLAMetadata"]:
        return FlashAttnMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]:
        return FlashAttnMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashAttnMLAImpl"]:
        return FlashAttnMLAImpl


@dataclass
class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
    pass


@dataclass
class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
    pass


class FlashAttnMLAMetadataBuilder(
        MLACommonMetadataBuilder[FlashAttnMLAMetadata]):

    def __init__(self, runner):
        super().__init__(runner)

        self.num_q_heads = self.runner.model_config.get_num_attention_heads(
            self.runner.parallel_config)

    def _build_decode(self, input_positions: torch.Tensor,
                      block_table: torch.Tensor,
                      seq_lens: torch.Tensor) -> FlashAttnMLADecodeMetadata:
        # No FlashAttention-specific decode fields are needed yet; reuse the
        # common decode metadata as-is.
        return FlashAttnMLADecodeMetadata(
            input_positions=input_positions,
            block_table=block_table,
            seq_lens=seq_lens,
        )


class FlashAttnMLAImpl(MLACommonImpl[MLACommonMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         **mla_args)

        assert is_flash_attn_mla_supported(), \
            "FlashAttnMLA is not supported on this device"

        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashAttnMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttnMLAImpl")

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 FlashAttnMLA not yet supported")

        decode_meta = attn_metadata.decode
        assert decode_meta is not None

        # Split the cache along the last dim into the compressed latent
        # (kv_c, first kv_lora_rank elements) and the rope part (k_pe).
        kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
        kv_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:]

        o = flash_attn_varlen_func(
            q=q_pe,
            k=kv_pe_cache.unsqueeze(-2),  # Add head dim of 1
            v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
            q_v=q_nope,
            max_seqlen_q=decode_meta.max_decode_query_len,
            cu_seqlens_q=decode_meta.query_start_loc,
            max_seqlen_k=decode_meta.max_decode_seq_len,
            seqused_k=decode_meta.seq_lens_tensor,
            block_table=decode_meta.block_tables,
            softmax_scale=self.scale,
            causal=True,
            fa_version=3,  # only version 3 is supported
        )

        # Map the latent-space attention output back through the value
        # up-projection and the output projection.
        return self._v_up_proj_and_o_proj(o)
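
A minimal shape sketch of the decode-side cache handling above, with made-up sizes (num_blocks, block_size, kv_lora_rank, rope_dim are illustrative assumptions, not values from this commit): the paged cache stores the compressed latent and the rope key concatenated along the last dimension, so slicing at kv_lora_rank and adding a singleton head dimension yields the two views passed as v and k, while q_nope is supplied separately via q_v.

# Illustrative-only sketch of the cache split in _forward_decode;
# the sizes below are assumptions, not values used by vLLM.
import torch

num_blocks, block_size = 4, 16     # hypothetical paged-cache layout
kv_lora_rank, rope_dim = 512, 64   # hypothetical latent and rope widths

kv_c_and_k_pe_cache = torch.randn(num_blocks, block_size,
                                  kv_lora_rank + rope_dim)

# Same slicing as the backend: latent part first, rope part second.
kv_c_cache = kv_c_and_k_pe_cache[..., :kv_lora_rank]
kv_pe_cache = kv_c_and_k_pe_cache[..., kv_lora_rank:]

# unsqueeze(-2) inserts the single shared KV head that MLA stores.
print(kv_c_cache.unsqueeze(-2).shape)   # torch.Size([4, 16, 1, 512])
print(kv_pe_cache.unsqueeze(-2).shape)  # torch.Size([4, 16, 1, 64])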
