Commit b6ce762

LucasWilkinson authored and Akshat-Tripathi committed

[Kernel] FlashMLA integration (vllm-project#13747)

Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>

1 parent 5eb0d63 commit b6ce762

11 files changed: +733 -86 lines

CMakeLists.txt

Lines changed: 4 additions & 73 deletions
@@ -575,77 +575,8 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
     WITH_SOABI)
 endif()
 
-# vllm-flash-attn currently only supported on CUDA
-if (NOT VLLM_GPU_LANG STREQUAL "CUDA")
-  return()
+# For CUDA we also build and ship some external projects.
+if (VLLM_GPU_LANG STREQUAL "CUDA")
+  include(cmake/external_projects/flashmla.cmake)
+  include(cmake/external_projects/vllm_flash_attn.cmake)
 endif ()
-
-# vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target
-# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the
-# arches in the CUDA case (and instead set the gencodes on a per file basis)
-# we need to manually set VLLM_GPU_ARCHES here.
-if(VLLM_GPU_LANG STREQUAL "CUDA")
-  foreach(_ARCH ${CUDA_ARCHS})
-    string(REPLACE "." "" _ARCH "${_ARCH}")
-    list(APPEND VLLM_GPU_ARCHES "${_ARCH}-real")
-  endforeach()
-endif()
-
-#
-# Build vLLM flash attention from source
-#
-# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM.
-# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs.
-# They should be identical but if they aren't, this is a massive footgun.
-#
-# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
-# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
-# If no component is specified, vllm-flash-attn is still installed.
-
-# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
-# This is to enable local development of vllm-flash-attn within vLLM.
-# It can be set as an environment variable or passed as a cmake argument.
-# The environment variable takes precedence.
-if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
-  set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR})
-endif()
-
-if(VLLM_FLASH_ATTN_SRC_DIR)
-  FetchContent_Declare(
-    vllm-flash-attn SOURCE_DIR
-    ${VLLM_FLASH_ATTN_SRC_DIR}
-    BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
-  )
-else()
-  FetchContent_Declare(
-    vllm-flash-attn
-    GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-    GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade
-    GIT_PROGRESS TRUE
-    # Don't share the vllm-flash-attn build between build types
-    BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
-  )
-endif()
-
-
-# Fetch the vllm-flash-attn library
-FetchContent_MakeAvailable(vllm-flash-attn)
-message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
-
-# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
-# case only one is built, in the case both are built redundant work is done)
-install(
-  DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
-  DESTINATION vllm_flash_attn
-  COMPONENT _vllm_fa2_C
-  FILES_MATCHING PATTERN "*.py"
-)
-
-install(
-  DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
-  DESTINATION vllm_flash_attn
-  COMPONENT _vllm_fa3_C
-  FILES_MATCHING PATTERN "*.py"
-)
-
-# Nothing after vllm-flash-attn, see comment about macros above
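
Both external projects can be pointed at a local checkout instead of the pinned Git revisions via VLLM_FLASH_ATTN_SRC_DIR and FLASH_MLA_SRC_DIR; the environment variables take precedence over the equivalent CMake arguments. A minimal sketch of such a build, assuming placeholder checkout paths and an editable pip install (the exact install command is not part of this commit):

# Illustrative build driver: point the build at local checkouts of
# vllm-flash-attn and FlashMLA via the environment variables read by the
# cmake files below. The paths are placeholders.
import os
import subprocess

env = dict(os.environ)
env["VLLM_FLASH_ATTN_SRC_DIR"] = "/path/to/flash-attention"  # placeholder
env["FLASH_MLA_SRC_DIR"] = "/path/to/FlashMLA"               # placeholder

# Editable install of vLLM; CMake uses the source dirs instead of fetching.
subprocess.run(["pip", "install", "-e", ".", "--no-build-isolation"],
               env=env, check=True)
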
cmake/external_projects/flashmla.cmake

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+include(FetchContent)
+
+# If FLASH_MLA_SRC_DIR is set, flash-mla is installed from that directory
+# instead of downloading.
+# It can be set as an environment variable or passed as a cmake argument.
+# The environment variable takes precedence.
+if (DEFINED ENV{FLASH_MLA_SRC_DIR})
+  set(FLASH_MLA_SRC_DIR $ENV{FLASH_MLA_SRC_DIR})
+endif()
+
+if(FLASH_MLA_SRC_DIR)
+  FetchContent_Declare(
+    flashmla
+    SOURCE_DIR ${FLASH_MLA_SRC_DIR}
+    CONFIGURE_COMMAND ""
+    BUILD_COMMAND ""
+  )
+else()
+  FetchContent_Declare(
+    flashmla
+    GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
+    GIT_TAG 575f7724b9762f265bbee5889df9c7d630801845
+    GIT_PROGRESS TRUE
+    CONFIGURE_COMMAND ""
+    BUILD_COMMAND ""
+  )
+endif()
+
+
+FetchContent_MakeAvailable(flashmla)
+message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
+
+# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
+# Only build FlashMLA kernels if we are building for something compatible with
+# sm90a
+cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
+if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
+  set(FlashMLA_SOURCES
+    ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
+    ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu
+    ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_fp16_sm90.cu
+    ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_metadata.cu)
+
+  set(FlashMLA_INCLUDES
+    ${flashmla_SOURCE_DIR}/csrc/cutlass/include
+    ${flashmla_SOURCE_DIR}/csrc/include)
+
+  set_gencode_flags_for_srcs(
+    SRCS "${FlashMLA_SOURCES}"
+    CUDA_ARCHS "${FLASH_MLA_ARCHS}")
+
+  define_gpu_extension_target(
+    _flashmla_C
+    DESTINATION vllm
+    LANGUAGE ${VLLM_GPU_LANG}
+    SOURCES ${FlashMLA_SOURCES}
+    COMPILE_FLAGS ${VLLM_GPU_FLAGS}
+    ARCHITECTURES ${VLLM_GPU_ARCHES}
+    INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES}
+    USE_SABI 3
+    WITH_SOABI)
+else()
+  # Create an empty target for setup.py when not targeting sm90a systems
+  add_custom_target(_flashmla_C)
+endif()
+
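
At runtime, the same constraint (an sm90a-capable GPU and CUDA 12.3 or later) is surfaced through vllm.attention.ops.flashmla.is_flashmla_supported(), which the new test below uses as its skip condition. A minimal caller-side guard, assuming the (supported, reason) return shape that the test indexes:

# Minimal runtime guard mirroring the build-time gate above; assumes
# is_flashmla_supported() returns (supported, reason) as indexed in
# tests/kernels/test_flashmla.py.
from vllm.attention.ops.flashmla import is_flashmla_supported

supported, reason = is_flashmla_supported()
if not supported:
    # _flashmla_C is only an empty target on non-Hopper / pre-12.3 builds,
    # so fall back to another attention backend instead of calling into it.
    raise RuntimeError(f"FlashMLA unavailable: {reason}")
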
cmake/external_projects/vllm_flash_attn.cmake

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+# vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target
+# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the
+# arches in the CUDA case (and instead set the gencodes on a per file basis)
+# we need to manually set VLLM_GPU_ARCHES here.
+if(VLLM_GPU_LANG STREQUAL "CUDA")
+  foreach(_ARCH ${CUDA_ARCHS})
+    string(REPLACE "." "" _ARCH "${_ARCH}")
+    list(APPEND VLLM_GPU_ARCHES "${_ARCH}-real")
+  endforeach()
+endif()
+
+#
+# Build vLLM flash attention from source
+#
+# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM.
+# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs.
+# They should be identical but if they aren't, this is a massive footgun.
+#
+# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
+# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
+# If no component is specified, vllm-flash-attn is still installed.
+
+# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
+# This is to enable local development of vllm-flash-attn within vLLM.
+# It can be set as an environment variable or passed as a cmake argument.
+# The environment variable takes precedence.
+if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
+  set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR})
+endif()
+
+if(VLLM_FLASH_ATTN_SRC_DIR)
+  FetchContent_Declare(
+    vllm-flash-attn SOURCE_DIR
+    ${VLLM_FLASH_ATTN_SRC_DIR}
+    BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
+  )
+else()
+  FetchContent_Declare(
+    vllm-flash-attn
+    GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
+    GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade
+    GIT_PROGRESS TRUE
+    # Don't share the vllm-flash-attn build between build types
+    BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
+  )
+endif()
+
+
+# Fetch the vllm-flash-attn library
+FetchContent_MakeAvailable(vllm-flash-attn)
+message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
+
+# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
+# case only one is built, in the case both are built redundant work is done)
+install(
+  DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
+  DESTINATION vllm_flash_attn
+  COMPONENT _vllm_fa2_C
+  FILES_MATCHING PATTERN "*.py"
+)
+
+install(
+  DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
+  DESTINATION vllm_flash_attn
+  COMPONENT _vllm_fa3_C
+  FILES_MATCHING PATTERN "*.py"
+)

setup.py

Lines changed: 6 additions & 0 deletions
@@ -328,6 +328,7 @@ def run(self) -> None:
         files_to_copy = [
             "vllm/_C.abi3.so",
             "vllm/_moe_C.abi3.so",
+            "vllm/_flashmla_C.abi3.so",
             "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
             "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
             "vllm/vllm_flash_attn/flash_attn_interface.py",
@@ -612,6 +613,11 @@ def _read_requirements(filename: str) -> List[str]:
         # FA3 requires CUDA 12.0 or later
         ext_modules.append(
             CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
+    if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"):
+        # Optional since this doesn't get built (produce an .so file) when
+        # not targeting a hopper system
+        ext_modules.append(
+            CMakeExtension(name="vllm._flashmla_C", optional=True))
     ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
 
 if _build_custom_ops():
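
The new extension is registered only when a precompiled wheel is used or NVCC reports CUDA 12.3 or newer, matching the gate in flashmla.cmake. A small sketch of that version comparison with packaging.version, assuming get_nvcc_cuda_version() returns a comparable Version object as it does elsewhere in setup.py:

# Illustrative version gate: FlashMLA kernels need CUDA 12.3 or later.
from packaging.version import Version


def cuda_supports_flashmla(nvcc_version: Version) -> bool:
    return nvcc_version >= Version("12.3")


assert cuda_supports_flashmla(Version("12.4"))
assert not cuda_supports_flashmla(Version("12.1"))
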

tests/kernels/test_flashmla.py

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
+# Adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/tests/test_flash_mla.py
+# SPDX-License-Identifier: Apache-2.0
+import math
+import random
+
+import pytest
+import torch
+import triton
+
+from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
+                                          get_mla_metadata,
+                                          is_flashmla_supported)
+
+
+def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
+    x, y = x.double(), y.double()
+    cos_diff = 1 - 2 * (x * y).sum().item() / max(
+        (x * x + y * y).sum().item(), 1e-12)
+    assert cos_diff < 1e-5
+
+FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
+    if not is_flashmla_supported()[0] else "FlashMLA is supported"
+
+
+@pytest.mark.skipif(not is_flashmla_supported()[0],
+                    reason=FLASH_MLA_UNSUPPORTED_REASON)
+@pytest.mark.parametrize("b", [128])
+@pytest.mark.parametrize("s_q", [1, 2])
+@pytest.mark.parametrize("mean_sk", [4096, 8192])
+@pytest.mark.parametrize("h_q", [16, 32, 64, 128])
+@pytest.mark.parametrize("h_kv", [1])
+@pytest.mark.parametrize("d", [576])
+@pytest.mark.parametrize("dv", [512])
+@pytest.mark.parametrize("block_size", [64])
+@pytest.mark.parametrize("causal", [True])
+@pytest.mark.parametrize("varlen", [False, True])
+@torch.inference_mode()
+def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
+                   varlen):
+    # TODO: parametrize using pytest
+    dtype = torch.bfloat16
+    device = torch.device("cuda:0")
+    torch.set_default_dtype(dtype)
+    torch.set_default_device(device)
+    torch.cuda.set_device(device)
+    torch.manual_seed(0)
+    random.seed(0)
+
+    print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
+          f"{d=}, {dv=}, {causal=}, {varlen=}")
+
+    cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
+    if varlen:
+        for i in range(b):
+            cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2),
+                                   s_q)
+    total_seqlens = cache_seqlens.sum().item()
+    max_seqlen = cache_seqlens.max().item()
+    max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
+
+    q = torch.randn(b, s_q, h_q, d)
+    block_table = torch.arange(b * max_seqlen_pad // block_size,
+                               dtype=torch.int32).view(
+                                   b, max_seqlen_pad // block_size)
+    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
+    for i in range(b):
+        blocked_k.view(b, max_seqlen_pad, h_kv,
+                       d)[i, cache_seqlens[i].item():] = float("nan")
+    blocked_v = blocked_k[..., :dv]
+
+    tile_scheduler_metadata, num_splits = get_mla_metadata(
+        cache_seqlens, s_q * h_q // h_kv, h_kv)
+
+    def flash_mla():
+        return flash_mla_with_kvcache(
+            q,
+            blocked_k,
+            block_table,
+            cache_seqlens,
+            dv,
+            tile_scheduler_metadata,
+            num_splits,
+            causal=causal,
+        )
+
+    def scaled_dot_product_attention(query, key, value, is_causal=False):
+        query = query.float()
+        key = key.float()
+        value = value.float()
+        key = key.repeat_interleave(h_q // h_kv, dim=0)
+        value = value.repeat_interleave(h_q // h_kv, dim=0)
+        attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
+        if is_causal:
+            s_q = query.shape[-2]
+            s_k = key.shape[-2]
+            attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
+            temp_mask = torch.ones(s_q, s_k,
+                                   dtype=torch.bool).tril(diagonal=s_k - s_q)
+            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+            attn_bias.to(query.dtype)
+            attn_weight += attn_bias
+        lse = attn_weight.logsumexp(dim=-1)
+        attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
+        return attn_weight @ value, lse
+
+    def ref_mla():
+        out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
+        lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
+        for i in range(b):
+            begin = i * max_seqlen_pad
+            end = begin + cache_seqlens[i]
+            ref_O, LSE = scaled_dot_product_attention(
+                q[i].transpose(0, 1),
+                blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
+                blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
+                is_causal=causal,
+            )
+            out[i] = ref_O.transpose(0, 1)
+            lse[i] = LSE
+        return out, lse
+
+    out_flash, lse_flash = flash_mla()
+    out_torch, lse_torch = ref_mla()
+    cal_diff(out_flash, out_torch, "out")
+    cal_diff(lse_flash, lse_torch, "lse")
+
+    t = triton.testing.do_bench(flash_mla, fast_flush=False)
+    FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
+    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d +
+             b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
+    print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} "
+          f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s")
