1 change: 1 addition & 0 deletions benchmarks/README.md
@@ -16,6 +16,7 @@ Currently supports testing most attention, gemm, and fused MOE APIs:
- `BatchPrefillWithPagedKVCacheWrapper` - Prefill attention with paged KV cache.
- Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_batch_context_with_kv_cache`.
- `BatchPrefillWithRaggedKVCacheWrapper` - Prefill attention with ragged KV cache.
- Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_ragged_attention_deepseek`.
- `BatchMLAPagedAttentionWrapper` - MLA attention proposed in DeepSeek series of models.
- Also supports computationally similar `trtllm_batch_decode_with_kv_cache_mla`.
- GEMM:
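For context on the wrapper APIs listed in the README hunk above, here is a minimal, illustrative sketch of driving `BatchPrefillWithPagedKVCacheWrapper`. The tensor shapes, argument names, and the `plan`/`run` flow are assumptions based on the public flashinfer Python API and may differ slightly across versions; treat it as a sketch, not part of this change.

```python
# Minimal sketch: paged-KV prefill with flashinfer (argument names assumed;
# verify against the installed flashinfer version before use).
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16
batch_size, q_len, kv_pages_per_req = 4, 128, 32

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, kv_layout="HND")

# Ragged query layout: qo_indptr[i] marks where request i's tokens start.
qo_indptr = torch.arange(0, (batch_size + 1) * q_len, q_len,
                         dtype=torch.int32, device="cuda")
kv_indptr = torch.arange(0, (batch_size + 1) * kv_pages_per_req, kv_pages_per_req,
                         dtype=torch.int32, device="cuda")
kv_indices = torch.arange(batch_size * kv_pages_per_req, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")

q = torch.randn(batch_size * q_len, num_qo_heads, head_dim,
                dtype=torch.bfloat16, device="cuda")
# HND paged KV cache layout: (num_pages, 2, num_kv_heads, page_size, head_dim).
kv_cache = torch.randn(batch_size * kv_pages_per_req, 2, num_kv_heads, page_size, head_dim,
                       dtype=torch.bfloat16, device="cuda")

wrapper.plan(qo_indptr, kv_indptr, kv_indices, kv_last_page_len,
             num_qo_heads, num_kv_heads, head_dim, page_size, causal=True)
out = wrapper.run(q, kv_cache)  # (total_q_tokens, num_qo_heads, head_dim)
```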
8 changes: 0 additions & 8 deletions benchmarks/routines/attention.py
@@ -696,10 +696,6 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
backends.remove("trtllm-gen")
if "trtllm-gen-native" in backends:
remove_trtllm_native = False
if batch_size == 1:
# TO-DO: trtllm-gen-native hits IMA on batch size 1. Investigate and fix.
print("[INFO] trtllm-gen-native backend currently requires batch size > 1")
remove_trtllm_native = True
if not causal:
print("[INFO] trtllm-gen-native backend currently requires causal = True")
remove_trtllm_native = True
@@ -1184,10 +1180,6 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
]:
print("[INFO] trtllm-gen-native backend does not support FP8. Skipping.")
remove_trtllm_native = True
if batch_size == 1:
# TO-DO: trtllm-gen-native hits IMA on batch size 1. Investigate and fix.
print("[INFO] trtllm-gen-native backend currently requires batch size > 1")
remove_trtllm_native = True
if not (head_dim_qk == 192 and head_dim_vo == 128):
print(
"[INFO] trtllm-gen-native backend requires head_dim_qk == 192 and head_dim_vo == 128"
6 changes: 3 additions & 3 deletions flashinfer/artifacts.py
@@ -80,7 +80,7 @@ def get_available_cubin_files(


class ArtifactPath:
TRTLLM_GEN_FMHA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/trtllm-gen/"
TRTLLM_GEN_FMHA: str = "463def7494c9fc6792b5aa5b5beef34025e247ac/fmha/trtllm-gen/"
TRTLLM_GEN_BMM: str = (
"56fea80cb22f8b2ef2a2c6a822a075fb20b36803/batched_gemm-074aec4-cc00b23"
)
@@ -95,7 +95,7 @@ class ArtifactPath:
class MetaInfoHash:
DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
TRTLLM_GEN_FMHA: str = (
"d26dbf837f40ff2dcd964094ab6e1b3f2424edda5979c313f5262655161fce98"
"2b8a485f2af84768bc769e678eb6014a8181ad95a7ea9e699de5efca4b18ec6a"
)
TRTLLM_GEN_BMM: str = (
"4a8ceeb356fc5339021acf884061e97e49e01da5c75dbf0f7cf4932c37a70152"
@@ -107,7 +107,7 @@ class MetaInfoHash:

class CheckSumHash:
TRTLLM_GEN_FMHA: str = (
"b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4"
"639c534614e9fdf5a9cfa91f7ea8f53989613019c0e1f8b755f461e1fcc7546f"
)
TRTLLM_GEN_BMM: str = (
"8df2aae8f3aa39d64d2c723e775640beb4ac602a6cbb02e497c2a7316e349934"
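The artifacts.py hunks above only bump the pinned trtllm-gen FMHA artifact path and its SHA-256 digests. For readers new to these constants, the following hedged sketch shows the kind of download-and-verify check such digests enable; the base URL, file layout, and the exact logic in `flashinfer/artifacts.py` are assumptions for illustration, not taken from this diff.

```python
# Hedged sketch: stream a remote kernel artifact and compare it against a pinned
# SHA-256 digest. URL layout and file names are illustrative placeholders; the
# real download/verification logic lives in flashinfer/artifacts.py.
import hashlib
import urllib.request


def sha256_of_url(url: str, chunk_size: int = 1 << 20) -> str:
    """Return the SHA-256 hex digest of the file at `url`, streamed in chunks."""
    digest = hashlib.sha256()
    with urllib.request.urlopen(url) as resp:
        for chunk in iter(lambda: resp.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_artifact(base_url: str, artifact_path: str, file_name: str,
                    expected_sha256: str) -> None:
    """Raise if the downloaded artifact does not match the pinned digest."""
    actual = sha256_of_url(f"{base_url}/{artifact_path}{file_name}")
    if actual != expected_sha256:
        raise RuntimeError(
            f"Checksum mismatch for {file_name}: got {actual}, expected {expected_sha256}"
        )
```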
4 changes: 4 additions & 0 deletions include/flashinfer/trtllm/fmha/kernelParams.h
@@ -152,6 +152,8 @@ struct KernelParams {
int32_t mStartTokenIdxSfO;
// The sum of sequence lengths for Q and K/V.
int32_t mSumOfSeqLensQ, mSumOfSeqLensKv;
// The flag to use block sparse attention.
bool mUseBlockSparseAttention;

// Create the TMA shape/stride for Q.
template <class FmhaOptions>
@@ -699,6 +701,8 @@ struct KernelParams {
params.mStartTokenIdxSfO = options.mSfStartTokenIdx;
params.mScaleSfKv = options.mScaleSfKv;
params.ptrSoftmaxStats = options.softmaxStatsPtr;
// TODO: Integrate trtllm block-sparse attention kernels when needed.
params.mUseBlockSparseAttention = false;
return params;
}
};
59 changes: 55 additions & 4 deletions tests/attention/test_trtllm_gen_attention.py
@@ -331,6 +331,8 @@ def unpack_compare_nvfp4(
)
@pytest.mark.parametrize("enable_pdl", [True, False, None])
@pytest.mark.parametrize("enable_sink", [True, False])
@pytest.mark.parametrize("max_q_len", [511])
@pytest.mark.parametrize("max_kv_len", [2047])
def test_trtllm_batch_prefill(
kv_layout,
batch_size,
@@ -343,20 +345,20 @@ def test_trtllm_batch_prefill(
kv_dtype,
enable_pdl,
enable_sink,
max_q_len,
max_kv_len,
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
if compute_capability[0] in [11, 12]:
pytest.skip("trtllm-gen does not support SM110/SM120/SM121 GPUs.")
# Set up test parameters
torch.manual_seed(0)
head_dim = 128
MAX_Q_LEN = 511
MAX_IN_KV_LEN = 2047

# Generate random sequence lengths
num_qo_heads = num_kv_heads * head_grp_size
q_lens, in_kv_lens, seq_lens = generate_seq_lens_prefill(
batch_size, MAX_Q_LEN, MAX_IN_KV_LEN
batch_size, max_q_len, max_kv_len
)

# Create query tensor and related data
@@ -525,6 +527,56 @@ def test_trtllm_batch_prefill(
assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()


@pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND
@pytest.mark.parametrize(
"batch_size,page_size,num_kv_heads,head_grp_size",
[
(1, 16, 8, 8),
],
)
@pytest.mark.parametrize("window_left", [-1]) # todo(Siyuan): add 127 window_left
@pytest.mark.parametrize(
"q_dtype,kv_dtype,o_dtype",
[
("bf16", "bf16", "bf16"),
],
)
@pytest.mark.parametrize("enable_pdl", [None])
@pytest.mark.parametrize("enable_sink", [False])
@pytest.mark.parametrize("max_q_len", [8192])
@pytest.mark.parametrize("max_kv_len", [8192])
def test_trtllm_batch_prefill_bs1(
kv_layout,
batch_size,
page_size,
num_kv_heads,
head_grp_size,
window_left,
q_dtype,
o_dtype,
kv_dtype,
enable_pdl,
enable_sink,
max_q_len,
max_kv_len,
):
test_trtllm_batch_prefill(
kv_layout,
batch_size,
page_size,
num_kv_heads,
head_grp_size,
window_left,
q_dtype,
o_dtype,
kv_dtype,
enable_pdl,
enable_sink,
max_q_len,
max_kv_len,
)


@pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND
@pytest.mark.parametrize(
"batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
@@ -998,7 +1050,6 @@ def test_trtllm_gen_prefill_deepseek(
def test_trtllm_gen_prefill_deepseek_bs1(
batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
):
pytest.xfail("trtllm-gen prefill triggers an IMA with bs1")
test_trtllm_gen_prefill_deepseek(
batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
)