From 2b2b422ebda88ea10bc2ee3142ee637ab625fb88 Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Thu, 9 Oct 2025 22:04:54 +0000
Subject: [PATCH 01/12] Add print statements

---
 csrc/trtllm_fmha_kernel_launcher.cu            | 3 +++
 flashinfer/prefill.py                          | 2 ++
 include/flashinfer/trtllm/fmha/fmhaKernels.cuh | 2 ++
 3 files changed, 7 insertions(+)

diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu
index 05e92a1721..089a51acfe 100644
--- a/csrc/trtllm_fmha_kernel_launcher.cu
+++ b/csrc/trtllm_fmha_kernel_launcher.cu
@@ -375,6 +375,9 @@ void trtllm_ragged_attention_launcher(
   runner_params.ptrAttentionSinks = attention_sinks;
   runner_params.enable_pdl = enable_pdl;
 
+  std::cout << "mMaxSeqLenQ, mMaxSeqLenKv:" << runner_params.mMaxSeqLenQ << ", " << runner_params.mMaxSeqLenKv << std::endl;
+  std::cout << "runner_params.cumSeqLensQPtr, runner_params.cumSeqLensKvPtr:" << runner_params.cumSeqLensQPtr << ", " << runner_params.cumSeqLensKvPtr << std::endl;
+
   runner_params.kStrideKeysValues = k_stride_keys_values;
   runner_params.kStrideHeads = k_stride_heads;
   runner_params.kStrideBatch = k_stride_batch;
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index 7399bd4268..9e860eb8b0 100755
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -3282,6 +3282,8 @@ def trtllm_ragged_attention_deepseek(
     )
 
     workspace_size = workspace_buffer.numel() * workspace_buffer.element_size()
+    print(f"cum_seq_lens_q: {cum_seq_lens_q}")
+    print(f"cum_seq_lens_kv: {cum_seq_lens_kv}")
     run_func(
         out,
         query,
diff --git a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh
index d3e2b89c85..8b9b95c177 100644
--- a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh
+++ b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh
@@ -191,6 +191,8 @@ class TllmGenFmhaKernel {
     // Prepare the kernel parameters.
     auto kernelParams =
         KernelParams::setKernelParams(params, kernelMeta, maxNumCtasQ, maxNumCtasKv);
+    std::cout << "params.ptrCumSeqLensQ:" << kernelParams.ptrCumSeqLensQ << std::endl;
+    std::cout << "params.ptrCumSeqLensKv:" << kernelParams.ptrCumSeqLensKv << std::endl;
 
     // Prepare kernel parameters list for cuLaunchKernelEx.
     void* kernelParamsList[] = {&kernelParams};

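The prints above expose the two inputs most relevant to the batch-size-1 failure: the maximum sequence lengths and the cumulative-sequence-length pointers consumed by the trtllm-gen kernel. The same quantities can be inspected from the host before the kernel is invoked; the sketch below is illustrative only (the helper name is made up, and the qo_indptr/kv_indptr naming follows the repro script added in the next patch).

import torch

# Illustrative host-side counterpart of the C++ debug prints: log what the launcher
# will receive before calling the trtllm-gen ragged kernel. Not part of the library.
def log_ragged_kernel_inputs(qo_indptr: torch.Tensor, kv_indptr: torch.Tensor,
                             max_q_len: int, max_kv_len: int) -> None:
    print(f"max_q_len={max_q_len}, max_kv_len={max_kv_len}")
    print(f"cum_seq_lens_q={qo_indptr.tolist()}")   # backs cumSeqLensQPtr on the device
    print(f"cum_seq_lens_kv={kv_indptr.tolist()}")  # backs cumSeqLensKvPtr on the device
    # For batch_size == 1, the suspicious case is max_q_len != qo_indptr[-1].
    print(f"qo_indptr[-1]={int(qo_indptr[-1])}, kv_indptr[-1]={int(kv_indptr[-1])}")
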
From 806096d836f7cfb304385b39996bbefa212fa4dc Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Thu, 9 Oct 2025 22:05:18 +0000
Subject: [PATCH 02/12] Add repro script

---
 tests/attention/run_trtllm_prefill_bs1.py | 162 ++++++++++++++++++++++
 1 file changed, 162 insertions(+)
 create mode 100644 tests/attention/run_trtllm_prefill_bs1.py

diff --git a/tests/attention/run_trtllm_prefill_bs1.py b/tests/attention/run_trtllm_prefill_bs1.py
new file mode 100644
index 0000000000..8832fc1da6
--- /dev/null
+++ b/tests/attention/run_trtllm_prefill_bs1.py
@@ -0,0 +1,162 @@
+import math
+
+import pytest
+import torch
+
+import flashinfer
+from flashinfer.utils import FP4Tensor, ceil_div, round_up, get_compute_capability
+
+global_workspace_buffer = None  # can be empty initialized
+global_trtllm_gen_fmha_workspace_buffer = None  # must be zero initialized
+workspace_size = 256 * 1024 * 1024
+
+
+def create_workspace_buffers(device):
+    # Lazily initialize and reuse global workspace buffers
+    global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
+    if global_workspace_buffer is None:
+        global_workspace_buffer = torch.empty(
+            workspace_size, dtype=torch.int8, device=device
+        )
+    if global_trtllm_gen_fmha_workspace_buffer is None:
+        global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
+            workspace_size, dtype=torch.int8, device=device
+        )
+    return global_trtllm_gen_fmha_workspace_buffer, global_workspace_buffer
+
+
+def test_trtllm_gen_prefill_deepseek(
+    batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
+):
+    compute_capability = get_compute_capability(torch.device(device="cuda"))
+    if compute_capability[0] in [11, 12]:
+        pytest.skip("trtllm-gen does not support SM110/SM120/SM121 GPUs.")
+    if s_qo > s_kv:
+        pytest.skip("s_qo > s_kv, skipping test as causal")
+
+    num_qo_heads = num_kv_heads * head_grp_size
+    head_dim_qk = 192
+    head_dim_vo = 128
+
+    seed = 0
+    torch.manual_seed(seed)
+    device = "cuda:0"
+
+    actual_seq_lens_q = torch.randint(
+        1, s_qo + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device
+    )
+
+    actual_seq_lens_kv = torch.randint(
+        s_qo, s_kv + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device
+    )
+
+    cumsum_s_qo = torch.sum(actual_seq_lens_q)
+    cumsum_s_kv = torch.sum(actual_seq_lens_kv)
+
+    q = torch.randn(
+        cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=torch.bfloat16
+    )
+
+    k_cache = torch.randn(
+        (cumsum_s_kv, num_kv_heads, head_dim_qk),
+        device=device,
+        dtype=torch.bfloat16,
+    )
+    v_cache = torch.randn(
+        (cumsum_s_kv, num_kv_heads, head_dim_vo),
+        device=device,
+        dtype=torch.bfloat16,
+    )
+
+    # Initialize scale
+    scale = float(1.0 / (head_dim_qk**0.5))
+
+    workspace_buffer, workspace_buffer_ref = create_workspace_buffers(device)
+
+    qo_indptr = torch.cat(
+        [
+            torch.tensor([0], device=device),
+            torch.cumsum(actual_seq_lens_q.view(-1), dim=0),
+        ]
+    ).int()
+
+    # kv_indptr = torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * s_kv
+
+    # Create kv_indptr as cumulative sum of actual_seq_lens_kv
+    kv_indptr = torch.cat(
+        [
+            torch.tensor(
+                [0],
+                device=device,
+            ),
+            torch.cumsum(actual_seq_lens_kv.view(-1), dim=0),
+        ]
+    ).int()
+
+    wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
+        workspace_buffer_ref,
+        kv_layout="NHD",
+        backend="cutlass",
+    )
+    wrapper.plan(
+        qo_indptr,
+        kv_indptr,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim_qk,
+        head_dim_vo=head_dim_vo,
+        causal=causal,
+        sm_scale=scale,
+        q_data_type=torch.bfloat16,
+        kv_data_type=torch.bfloat16,
+    )
+    output_ref, lse_ref = wrapper.run(q, k_cache, v_cache, return_lse=True)
+    output = torch.empty_like(output_ref)
+
+    bmm1_scale = scale
+    bmm2_scale = 1.0
+    output_trtllm, lse_trtllm = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+        q,
+        k_cache,
+        v_cache,
+        workspace_buffer,
+        actual_seq_lens_kv,
+        s_qo,
+        s_kv,
+        bmm1_scale,
+        bmm2_scale,
+        -1,
+        batch_size,
+        -1,
+        qo_indptr,
+        kv_indptr,
+        False,
+        causal,
+        True,
+        out=output,
+    )
+    torch.testing.assert_close(
+        output_trtllm,
+        output_ref,
+        atol=1e-2,
+        rtol=1e-2,
+    )
+    torch.testing.assert_close(
+        lse_trtllm,
+        lse_ref,
+        atol=1e-3,
+        rtol=1e-3,
+    )
+    # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero
+    # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future
+    assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
+
+if __name__ == "__main__":
+    test_trtllm_gen_prefill_deepseek(
+        batch_size=1,
+        s_qo=1024,
+        s_kv=1024,
+        num_kv_heads=128,
+        head_grp_size=1,
+        causal=True,
+    )
\ No newline at end of file

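The repro drives the same single-request workload through the cutlass ragged-prefill wrapper (reference) and through trtllm_ragged_attention_deepseek, then compares outputs and LSE. A quick way to confirm that only batch_size == 1 misbehaves is to sweep the batch size with the same entry point; the import path below is an assumption (the script is not packaged), so adjust it or inline the function as needed.

# Illustrative sweep: re-run the repro entry point for a few batch sizes. With the
# cubins in use at this point in the series, only batch_size=1 is expected to fail
# (mismatch or illegal memory access); larger batches act as a sanity check.
from run_trtllm_prefill_bs1 import test_trtllm_gen_prefill_deepseek  # assumed import path

for bs in (1, 2, 4):
    print(f"--- batch_size={bs} ---")
    test_trtllm_gen_prefill_deepseek(
        batch_size=bs, s_qo=1024, s_kv=1024, num_kv_heads=128, head_grp_size=1, causal=True
    )
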
From 16ec20650707bdec4aaa427cfd6d3b25ecbe8256 Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Fri, 10 Oct 2025 17:29:57 +0000
Subject: [PATCH 03/12] Prefill benchmark and test code has been changed

---
 benchmarks/routines/attention.py             | 24 ++++++++++++--------
 csrc/trtllm_fmha_kernel_launcher.cu          |  6 +++--
 flashinfer/prefill.py                        |  6 ++---
 tests/attention/run_trtllm_prefill_bs1.py    |  7 +++---
 tests/attention/test_trtllm_gen_attention.py |  4 +++-
 5 files changed, 26 insertions(+), 21 deletions(-)

diff --git a/benchmarks/routines/attention.py b/benchmarks/routines/attention.py
index a75dea0928..cdb9977f02 100644
--- a/benchmarks/routines/attention.py
+++ b/benchmarks/routines/attention.py
@@ -696,10 +696,6 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
         backends.remove("trtllm-gen")
     if "trtllm-gen-native" in backends:
         remove_trtllm_native = False
-        if batch_size == 1:
-            # TO-DO: trtllm-gen-native hits IMA on batch size 1. Investigate and fix.
-            print("[INFO] trtllm-gen-native backend currently requires batch size > 1")
-            remove_trtllm_native = True
         if not causal:
             print("[INFO] trtllm-gen-native backend currently requires causal = True")
             remove_trtllm_native = True
@@ -932,6 +928,12 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
         v_fp8 = (v_data / v_scale).to(kv_dtype)
         kv_cache = torch.cat([k_fp8, v_fp8], dim=1)
 
+    if batch_size == 1:
+        # trtllm kernel requires max_q_len and max_kv_len to be the same as cum_seq_lens_q and cum_seq_lens_kv when batch_size=1
+        s_qo_trtllm = qo_indptr[-1].item()
+    else:
+        s_qo_trtllm = s_qo
+
     def run_backend_wrapper(backend):
         if backend in ["fa2", "fa3", "trtllm-gen"]:
             return backend_wrappers[backend].run(
@@ -962,7 +964,7 @@ def run_backend_wrapper(backend):
                 workspace_buffer=workspace_buffer,
                 block_tables=block_tables,
                 seq_lens=actual_seq_lens_kv_device,
-                max_q_len=s_qo,
+                max_q_len=s_qo_trtllm,
                 max_kv_len=s_kv,
                 bmm1_scale=scale if k_scale is None else k_scale * scale,
                 bmm2_scale=1.0 if v_scale is None else v_scale,
@@ -1184,10 +1186,6 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
         ]:
             print("[INFO] trtllm-gen-native backend does not support FP8. Skipping.")
             remove_trtllm_native = True
-        if batch_size == 1:
-            # TO-DO: trtllm-gen-native hits IMA on batch size 1. Investigate and fix.
-            print("[INFO] trtllm-gen-native backend currently requires batch size > 1")
-            remove_trtllm_native = True
         if not (head_dim_qk == 192 and head_dim_vo == 128):
             print(
                 "[INFO] trtllm-gen-native backend requires head_dim_qk == 192 and head_dim_vo == 128"
             )
@@ -1382,6 +1380,12 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
         k = (k / k_scale).to(kv_dtype)
         v = (v / v_scale).to(kv_dtype)
 
+    if batch_size == 1:
+        # trtllm kernel requires max_q_len and max_kv_len to be the same as cum_seq_lens_q and cum_seq_lens_kv when batch_size=1
+        s_qo_trtllm = qo_indptr[-1].item()
+    else:
+        s_qo_trtllm = s_qo
+
     def run_backend_wrapper(backend):
         if backend in ["cutlass", "fa2", "fa3", "trtllm-gen"]:
             return backend_wrappers[backend].run_return_lse(q, k, v)[0]
@@ -1413,7 +1417,7 @@ def run_backend_wrapper(backend):
                 value=v,
                 workspace_buffer=workspace_buffer,
                 seq_lens=actual_seq_lens_kv_device,
-                max_q_len=s_qo,
+                max_q_len=s_qo_trtllm,
                 max_kv_len=s_kv,
                 bmm1_scale=scale,
                 bmm2_scale=1.0,
diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu
index 089a51acfe..d686c4961e 100644
--- a/csrc/trtllm_fmha_kernel_launcher.cu
+++ b/csrc/trtllm_fmha_kernel_launcher.cu
@@ -375,8 +375,10 @@ void trtllm_ragged_attention_launcher(
   runner_params.ptrAttentionSinks = attention_sinks;
   runner_params.enable_pdl = enable_pdl;
 
-  std::cout << "mMaxSeqLenQ, mMaxSeqLenKv:" << runner_params.mMaxSeqLenQ << ", " << runner_params.mMaxSeqLenKv << std::endl;
-  std::cout << "runner_params.cumSeqLensQPtr, runner_params.cumSeqLensKvPtr:" << runner_params.cumSeqLensQPtr << ", " << runner_params.cumSeqLensKvPtr << std::endl;
+  std::cout << "mMaxSeqLenQ, mMaxSeqLenKv:" << runner_params.mMaxSeqLenQ << ", "
+            << runner_params.mMaxSeqLenKv << std::endl;
+  std::cout << "runner_params.cumSeqLensQPtr, runner_params.cumSeqLensKvPtr:"
+            << runner_params.cumSeqLensQPtr << ", " << runner_params.cumSeqLensKvPtr << std::endl;
 
   runner_params.kStrideKeysValues = k_stride_keys_values;
   runner_params.kStrideHeads = k_stride_heads;
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index 9e860eb8b0..55cd18c7fd 100755
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -3221,9 +3221,9 @@ def trtllm_ragged_attention_deepseek(
     seq_lens : torch.Tensor
         sequence lengths
     max_q_len : int
-        max query length
+        max query length. If batch_size == 1, must be equal to cum_seq_lens_q[-1]
     max_kv_len : int
-        max key/value length
+        max key/value length. If batch_size == 1, must be equal to cum_seq_lens_kv[-1]
     bmm1_scale : float
         scale for bmm1, scale_q * scale_k * 1.0 / (head_dim_qk ** 0.5)
     bmm2_scale : float
@@ -3282,8 +3282,6 @@ def trtllm_ragged_attention_deepseek(
     )
 
     workspace_size = workspace_buffer.numel() * workspace_buffer.element_size()
-    print(f"cum_seq_lens_q: {cum_seq_lens_q}")
-    print(f"cum_seq_lens_kv: {cum_seq_lens_kv}")
     run_func(
         out,
         query,
diff --git a/tests/attention/run_trtllm_prefill_bs1.py b/tests/attention/run_trtllm_prefill_bs1.py
index 8832fc1da6..5f01daf391 100644
--- a/tests/attention/run_trtllm_prefill_bs1.py
+++ b/tests/attention/run_trtllm_prefill_bs1.py
@@ -1,10 +1,8 @@
-import math
-
-import pytest
+import pytest
 import torch
 
 import flashinfer
-from flashinfer.utils import FP4Tensor, ceil_div, round_up, get_compute_capability
+from flashinfer.utils import get_compute_capability
 
 global_workspace_buffer = None  # can be empty initialized
 global_trtllm_gen_fmha_workspace_buffer = None  # must be zero initialized
 workspace_size = 256 * 1024 * 1024
@@ -151,6 +149,7 @@ def test_trtllm_gen_prefill_deepseek(
     # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future
     assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
 
+
 if __name__ == "__main__":
     test_trtllm_gen_prefill_deepseek(
         num_kv_heads=128,
         head_grp_size=1,
         causal=True,
-    )
\ No newline at end of file
+    )
diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py
index 80853c7dbf..04eb68de17 100755
--- a/tests/attention/test_trtllm_gen_attention.py
+++ b/tests/attention/test_trtllm_gen_attention.py
@@ -952,6 +952,9 @@ def test_trtllm_gen_prefill_deepseek(
     bmm1_scale = scale
     bmm2_scale = 1.0
+    if batch_size == 1:
+        # trtllm kernel requires max_q_len to be the same as cum_seq_lens_q when batch_size=1
+        s_qo = qo_indptr[-1].item()
     output_trtllm, lse_trtllm = flashinfer.prefill.trtllm_ragged_attention_deepseek(
         q,
         k_cache,
         v_cache,
@@ -998,7 +1001,6 @@ def test_trtllm_gen_prefill_deepseek(
 def test_trtllm_gen_prefill_deepseek_bs1(
     batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
 ):
-    pytest.xfail("trtllm-gen prefill triggers an IMA with bs1")
     test_trtllm_gen_prefill_deepseek(
         batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
     )

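Every call site patched above applies the same adjustment: with a single request, the kernel only behaves correctly if max_q_len equals the request's actual (cumulative) query length rather than the configured upper bound. A minimal sketch of that rule, using the qo_indptr/s_qo names from the diffs (the helper itself is illustrative; patches 09-10 later drop the workaround once the updated cubins handle batch_size == 1 directly):

import torch

def effective_max_q_len(qo_indptr: torch.Tensor, batch_size: int, s_qo: int) -> int:
    # Workaround used by the benchmark/tests at this point in the series: for a
    # single request the trtllm-gen prefill kernel needs the true query length
    # (qo_indptr[-1]); otherwise the configured maximum s_qo is passed through.
    if batch_size == 1:
        return int(qo_indptr[-1].item())
    return s_qo
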
From 0bfdf6bf2f4780eaf2b8934d2824ac0414712bd3 Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Fri, 10 Oct 2025 17:36:30 +0000
Subject: [PATCH 04/12] Cleanup. Add paged prefill batch size 1 case

---
 benchmarks/routines/attention.py              |  4 +-
 csrc/trtllm_fmha_kernel_launcher.cu           |  5 ---
 flashinfer/prefill.py                         |  2 +-
 .../flashinfer/trtllm/fmha/fmhaKernels.cuh    |  2 -
 tests/attention/test_trtllm_gen_attention.py  | 38 +++++++++++++++++++
 5 files changed, 41 insertions(+), 10 deletions(-)

diff --git a/benchmarks/routines/attention.py b/benchmarks/routines/attention.py
index cdb9977f02..50f002c4ad 100644
--- a/benchmarks/routines/attention.py
+++ b/benchmarks/routines/attention.py
@@ -929,7 +929,7 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
         kv_cache = torch.cat([k_fp8, v_fp8], dim=1)
 
     if batch_size == 1:
-        # trtllm kernel requires max_q_len and max_kv_len to be the same as cum_seq_lens_q and cum_seq_lens_kv when batch_size=1
+        # trtllm kernel requires max_q_len to be the same as cum_seq_lens_q when batch_size=1
         s_qo_trtllm = qo_indptr[-1].item()
     else:
         s_qo_trtllm = s_qo
@@ -1381,7 +1381,7 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
         v = (v / v_scale).to(kv_dtype)
 
     if batch_size == 1:
-        # trtllm kernel requires max_q_len and max_kv_len to be the same as cum_seq_lens_q and cum_seq_lens_kv when batch_size=1
+        # trtllm kernel requires max_q_len to be the same as cum_seq_lens_q when batch_size=1
         s_qo_trtllm = qo_indptr[-1].item()
     else:
         s_qo_trtllm = s_qo
diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu
index d686c4961e..05e92a1721 100644
--- a/csrc/trtllm_fmha_kernel_launcher.cu
+++ b/csrc/trtllm_fmha_kernel_launcher.cu
@@ -375,11 +375,6 @@ void trtllm_ragged_attention_launcher(
   runner_params.ptrAttentionSinks = attention_sinks;
   runner_params.enable_pdl = enable_pdl;
 
-  std::cout << "mMaxSeqLenQ, mMaxSeqLenKv:" << runner_params.mMaxSeqLenQ << ", "
-            << runner_params.mMaxSeqLenKv << std::endl;
-  std::cout << "runner_params.cumSeqLensQPtr, runner_params.cumSeqLensKvPtr:"
-            << runner_params.cumSeqLensQPtr << ", " << runner_params.cumSeqLensKvPtr << std::endl;
-
   runner_params.kStrideKeysValues = k_stride_keys_values;
   runner_params.kStrideHeads = k_stride_heads;
   runner_params.kStrideBatch = k_stride_batch;
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index 55cd18c7fd..39c13f7953 100755
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -3223,7 +3223,7 @@ def trtllm_ragged_attention_deepseek(
     max_q_len : int
         max query length. If batch_size == 1, must be equal to cum_seq_lens_q[-1]
     max_kv_len : int
-        max key/value length. If batch_size == 1, must be equal to cum_seq_lens_kv[-1]
+        max key/value length.
     bmm1_scale : float
         scale for bmm1, scale_q * scale_k * 1.0 / (head_dim_qk ** 0.5)
     bmm2_scale : float
diff --git a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh
index 8b9b95c177..d3e2b89c85 100644
--- a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh
+++ b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh
@@ -191,8 +191,6 @@ class TllmGenFmhaKernel {
     // Prepare the kernel parameters.
     auto kernelParams =
         KernelParams::setKernelParams(params, kernelMeta, maxNumCtasQ, maxNumCtasKv);
-    std::cout << "params.ptrCumSeqLensQ:" << kernelParams.ptrCumSeqLensQ << std::endl;
-    std::cout << "params.ptrCumSeqLensKv:" << kernelParams.ptrCumSeqLensKv << std::endl;
 
     // Prepare kernel parameters list for cuLaunchKernelEx.
     void* kernelParamsList[] = {&kernelParams};
diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py
index 04eb68de17..22228fa710 100755
--- a/tests/attention/test_trtllm_gen_attention.py
+++ b/tests/attention/test_trtllm_gen_attention.py
@@ -525,6 +525,44 @@ def test_trtllm_batch_prefill(
     assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
 
 
+@pytest.mark.parametrize("kv_layout", ["HND"])  # trtllm-gen only support HND
+@pytest.mark.parametrize(
+    "batch_size,page_size,num_kv_heads,head_grp_size",
+    [
+        (1, 16, 8, 8),
+    ],
+)
+@pytest.mark.parametrize("window_left", [-1])  # todo(Siyuan): add 127 window_left
+@pytest.mark.parametrize(
+    "q_dtype,kv_dtype,o_dtype",
+    [
+        ("bf16", "bf16", "bf16"),
+    ],
+)
+@pytest.mark.parametrize("enable_pdl", [None])
+@pytest.mark.parametrize("enable_sink", [False])
+def test_trtllm_batch_prefill_bs1(
+    kv_layout,
+    batch_size,
+    page_size,
+    num_kv_heads,
+    head_grp_size,
+    window_left,
+    q_dtype,
+    o_dtype,
+    kv_dtype,
+    enable_pdl,
+    enable_sink,
+):
+    test_trtllm_batch_prefill(
+        kv_layout,
+        batch_size,
+        page_size,
+        num_kv_heads,
+        head_grp_size,
+    )
+
+
 @pytest.mark.parametrize("kv_layout", ["HND"])  # trtllm-gen only support HND
 @pytest.mark.parametrize(
     "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",

From a9b19e7a815b2479595f28cabebb328b99e103c0 Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Fri, 10 Oct 2025 17:47:45 +0000
Subject: [PATCH 05/12] Cleanup test case

---
 tests/attention/test_trtllm_gen_attention.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py
index 22228fa710..51b0de360b 100755
--- a/tests/attention/test_trtllm_gen_attention.py
+++ b/tests/attention/test_trtllm_gen_attention.py
@@ -331,6 +331,8 @@ def unpack_compare_nvfp4(
 )
 @pytest.mark.parametrize("enable_pdl", [True, False, None])
 @pytest.mark.parametrize("enable_sink", [True, False])
+@pytest.mark.parametrize("max_q_len", [511])
+@pytest.mark.parametrize("max_kv_len", [2047])
 def test_trtllm_batch_prefill(
     kv_layout,
     batch_size,
@@ -343,6 +345,8 @@ def test_trtllm_batch_prefill(
     kv_dtype,
     enable_pdl,
     enable_sink,
+    max_q_len,
+    max_kv_len,
 ):
     compute_capability = get_compute_capability(torch.device(device="cuda"))
     if compute_capability[0] in [11, 12]:
@@ -350,13 +354,11 @@ def test_trtllm_batch_prefill(
     # Set up test parameters
     torch.manual_seed(0)
     head_dim = 128
-    MAX_Q_LEN = 511
-    MAX_IN_KV_LEN = 2047
 
     # Generate random sequence lengths
     num_qo_heads = num_kv_heads * head_grp_size
     q_lens, in_kv_lens, seq_lens = generate_seq_lens_prefill(
-        batch_size, MAX_Q_LEN, MAX_IN_KV_LEN
+        batch_size, max_q_len, max_kv_len
     )
 
     # Create query tensor and related data
@@ -541,6 +543,8 @@ def test_trtllm_batch_prefill(
 )
 @pytest.mark.parametrize("enable_pdl", [None])
 @pytest.mark.parametrize("enable_sink", [False])
+@pytest.mark.parametrize("max_q_len", [8192])
+@pytest.mark.parametrize("max_kv_len", [8192])
 def test_trtllm_batch_prefill_bs1(
     kv_layout,
     batch_size,
@@ -560,6 +564,12 @@ def test_trtllm_batch_prefill_bs1(
         page_size,
         num_kv_heads,
         head_grp_size,
+        window_left,
+        q_dtype,
+        o_dtype,
+        kv_dtype,
+        enable_pdl,
+        enable_sink,
     )

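With max_q_len/max_kv_len promoted to pytest parameters, the bs1 variant can stress a single 8192-token request instead of the 511/2047 defaults used elsewhere. The real generate_seq_lens_prefill helper lives in the test module and is not shown in this series; the stand-in below only illustrates its apparent contract (per-request query lengths, incremental KV lengths, and their sums), so treat it as an assumption.

import torch

# Hypothetical stand-in for generate_seq_lens_prefill, for illustration only.
def generate_seq_lens_prefill(batch_size: int, max_q_len: int, max_in_kv_len: int):
    q_lens = torch.randint(1, max_q_len + 1, (batch_size,), dtype=torch.int32)
    in_kv_lens = torch.randint(0, max_in_kv_len + 1, (batch_size,), dtype=torch.int32)
    seq_lens = q_lens + in_kv_lens  # total KV length each request attends over
    return q_lens, in_kv_lens, seq_lens
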
From 9cd0546eb7f02176f384255e170fd532922ee507 Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Fri, 10 Oct 2025 17:52:54 +0000
Subject: [PATCH 06/12] Cleanup for creating a prefill MR

---
 benchmarks/README.md                      |   1 +
 flashinfer/prefill.py                     |   2 +-
 tests/attention/run_trtllm_prefill_bs1.py | 161 ----------------------
 3 files changed, 2 insertions(+), 162 deletions(-)
 delete mode 100644 tests/attention/run_trtllm_prefill_bs1.py

diff --git a/benchmarks/README.md b/benchmarks/README.md
index 6b1a30c9d1..f41d695cdc 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -16,6 +16,7 @@ Currently supports testing most attention, gemm, and fused MOE APIs:
   - `BatchPrefillWithPagedKVCacheWrapper` - Prefill attention with paged KV cache.
     - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_batch_context_with_kv_cache`.
   - `BatchPrefillWithRaggedKVCacheWrapper` - Prefill attention with ragged KV cache.
+    - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_ragged_attention_deepseek`.
   - `BatchMLAPagedAttentionWrapper` - MLA attention proposed in DeepSeek series of models.
     - Also supports computationally similar `trtllm_batch_decode_with_kv_cache_mla`.
 - GEMM:
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index 39c13f7953..36a111f741 100755
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -3347,7 +3347,7 @@ def trtllm_batch_context_with_kv_cache(
     seq_lens : torch.Tensor
         A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``
     max_q_len : int
-        max sequence length for query
+        max sequence length for query. If batch_size == 1, must be equal to cum_seq_lens_q[-1]
     max_kv_len : int
         max sequence length for kv_cache
     bmm1_scale : float
diff --git a/tests/attention/run_trtllm_prefill_bs1.py b/tests/attention/run_trtllm_prefill_bs1.py
deleted file mode 100644
index 5f01daf391..0000000000
--- a/tests/attention/run_trtllm_prefill_bs1.py
+++ /dev/null
@@ -1,161 +0,0 @@
-import pytest
-import torch
-
-import flashinfer
-from flashinfer.utils import get_compute_capability
-
-global_workspace_buffer = None  # can be empty initialized
-global_trtllm_gen_fmha_workspace_buffer = None  # must be zero initialized
-workspace_size = 256 * 1024 * 1024
-
-
-def create_workspace_buffers(device):
-    # Lazily initialize and reuse global workspace buffers
-    global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
-    if global_workspace_buffer is None:
-        global_workspace_buffer = torch.empty(
-            workspace_size, dtype=torch.int8, device=device
-        )
-    if global_trtllm_gen_fmha_workspace_buffer is None:
-        global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
-            workspace_size, dtype=torch.int8, device=device
-        )
-    return global_trtllm_gen_fmha_workspace_buffer, global_workspace_buffer
-
-
-def test_trtllm_gen_prefill_deepseek(
-    batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
-):
-    compute_capability = get_compute_capability(torch.device(device="cuda"))
-    if compute_capability[0] in [11, 12]:
-        pytest.skip("trtllm-gen does not support SM110/SM120/SM121 GPUs.")
-    if s_qo > s_kv:
-        pytest.skip("s_qo > s_kv, skipping test as causal")
-
-    num_qo_heads = num_kv_heads * head_grp_size
-    head_dim_qk = 192
-    head_dim_vo = 128
-
-    seed = 0
-    torch.manual_seed(seed)
-    device = "cuda:0"
-
-    actual_seq_lens_q = torch.randint(
-        1, s_qo + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device
-    )
-
-    actual_seq_lens_kv = torch.randint(
-        s_qo, s_kv + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device
-    )
-
-    cumsum_s_qo = torch.sum(actual_seq_lens_q)
-    cumsum_s_kv = torch.sum(actual_seq_lens_kv)
-
-    q = torch.randn(
-        cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=torch.bfloat16
-    )
-
-    k_cache = torch.randn(
-        (cumsum_s_kv, num_kv_heads, head_dim_qk),
-        device=device,
-        dtype=torch.bfloat16,
-    )
-    v_cache = torch.randn(
-        (cumsum_s_kv, num_kv_heads, head_dim_vo),
-        device=device,
-        dtype=torch.bfloat16,
-    )
-
-    # Initialize scale
-    scale = float(1.0 / (head_dim_qk**0.5))
-
-    workspace_buffer, workspace_buffer_ref = create_workspace_buffers(device)
-
-    qo_indptr = torch.cat(
-        [
-            torch.tensor([0], device=device),
-            torch.cumsum(actual_seq_lens_q.view(-1), dim=0),
-        ]
-    ).int()
-
-    # kv_indptr = torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * s_kv
-
-    # Create kv_indptr as cumulative sum of actual_seq_lens_kv
-    kv_indptr = torch.cat(
-        [
-            torch.tensor(
-                [0],
-                device=device,
-            ),
-            torch.cumsum(actual_seq_lens_kv.view(-1), dim=0),
-        ]
-    ).int()
-
-    wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
-        workspace_buffer_ref,
-        kv_layout="NHD",
-        backend="cutlass",
-    )
-    wrapper.plan(
-        qo_indptr,
-        kv_indptr,
-        num_qo_heads,
-        num_kv_heads,
-        head_dim_qk,
-        head_dim_vo=head_dim_vo,
-        causal=causal,
-        sm_scale=scale,
-        q_data_type=torch.bfloat16,
-        kv_data_type=torch.bfloat16,
-    )
-    output_ref, lse_ref = wrapper.run(q, k_cache, v_cache, return_lse=True)
-    output = torch.empty_like(output_ref)
-
-    bmm1_scale = scale
-    bmm2_scale = 1.0
-    output_trtllm, lse_trtllm = flashinfer.prefill.trtllm_ragged_attention_deepseek(
-        q,
-        k_cache,
-        v_cache,
-        workspace_buffer,
-        actual_seq_lens_kv,
-        s_qo,
-        s_kv,
-        bmm1_scale,
-        bmm2_scale,
-        -1,
-        batch_size,
-        -1,
-        qo_indptr,
-        kv_indptr,
-        False,
-        causal,
-        True,
-        out=output,
-    )
-    torch.testing.assert_close(
-        output_trtllm,
-        output_ref,
-        atol=1e-2,
-        rtol=1e-2,
-    )
-    torch.testing.assert_close(
-        lse_trtllm,
-        lse_ref,
-        atol=1e-3,
-        rtol=1e-3,
-    )
-    # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero
-    # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future
-    assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
-
-
-if __name__ == "__main__":
-    test_trtllm_gen_prefill_deepseek(
-        batch_size=1,
-        s_qo=1024,
-        s_kv=1024,
-        num_kv_heads=128,
-        head_grp_size=1,
-        causal=True,
-    )

From ac233715c929e9562d0be2782028c655e74de20c Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Fri, 10 Oct 2025 17:58:04 +0000
Subject: [PATCH 07/12] Fixing comment

---
 benchmarks/routines/attention.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/benchmarks/routines/attention.py b/benchmarks/routines/attention.py
index 50f002c4ad..086c5e269a 100644
--- a/benchmarks/routines/attention.py
+++ b/benchmarks/routines/attention.py
@@ -929,7 +929,7 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
         kv_cache = torch.cat([k_fp8, v_fp8], dim=1)
 
     if batch_size == 1:
-        # trtllm kernel requires max_q_len to be the same as cum_seq_lens_q when batch_size=1
+        # trtllm kernel requires max_q_len to be the same as the seqlen of the query when batch_size=1
         s_qo_trtllm = qo_indptr[-1].item()
     else:
         s_qo_trtllm = s_qo
@@ -1381,7 +1381,7 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
         v = (v / v_scale).to(kv_dtype)
 
     if batch_size == 1:
-        # trtllm kernel requires max_q_len to be the same as cum_seq_lens_q when batch_size=1
+        # trtllm kernel requires max_q_len to be the same as the seqlen of the query when batch_size=1
         s_qo_trtllm = qo_indptr[-1].item()
     else:
         s_qo_trtllm = s_qo

From 8e04cbdb8f3105a0c08e4e4767f05aa27d4cd6ec Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Fri, 10 Oct 2025 18:08:52 +0000
Subject: [PATCH 08/12] Add missing params

---
 tests/attention/test_trtllm_gen_attention.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py
index 51b0de360b..0da439a26b 100755
--- a/tests/attention/test_trtllm_gen_attention.py
+++ b/tests/attention/test_trtllm_gen_attention.py
@@ -557,6 +557,8 @@ def test_trtllm_batch_prefill_bs1(
     kv_dtype,
     enable_pdl,
     enable_sink,
+    max_q_len,
+    max_kv_len,
 ):
     test_trtllm_batch_prefill(
         kv_layout,
@@ -570,6 +572,8 @@ def test_trtllm_batch_prefill_bs1(
         kv_dtype,
         enable_pdl,
         enable_sink,
+        max_q_len,
+        max_kv_len,
     )

From 71fca41837a683284032c68cfa99e3dd1f057c53 Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Thu, 16 Oct 2025 17:28:14 +0000
Subject: [PATCH 09/12] Update cubin artifacts to latest trtllm-gen fmha. Undo
 batch_size==1 special treatments

---
 benchmarks/routines/attention.py             | 16 ++--------------
 flashinfer/artifacts.py                      |  2 +-
 flashinfer/prefill.py                        |  4 ++--
 tests/attention/test_trtllm_gen_attention.py |  3 ---
 4 files changed, 5 insertions(+), 20 deletions(-)

diff --git a/benchmarks/routines/attention.py b/benchmarks/routines/attention.py
index 086c5e269a..acdf9ce7ab 100644
--- a/benchmarks/routines/attention.py
+++ b/benchmarks/routines/attention.py
@@ -928,12 +928,6 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
         v_fp8 = (v_data / v_scale).to(kv_dtype)
         kv_cache = torch.cat([k_fp8, v_fp8], dim=1)
 
-    if batch_size == 1:
-        # trtllm kernel requires max_q_len to be the same as the seqlen of the query when batch_size=1
-        s_qo_trtllm = qo_indptr[-1].item()
-    else:
-        s_qo_trtllm = s_qo
-
     def run_backend_wrapper(backend):
         if backend in ["fa2", "fa3", "trtllm-gen"]:
             return backend_wrappers[backend].run(
@@ -964,7 +958,7 @@ def run_backend_wrapper(backend):
                 workspace_buffer=workspace_buffer,
                 block_tables=block_tables,
                 seq_lens=actual_seq_lens_kv_device,
-                max_q_len=s_qo_trtllm,
+                max_q_len=s_qo,
                 max_kv_len=s_kv,
                 bmm1_scale=scale if k_scale is None else k_scale * scale,
                 bmm2_scale=1.0 if v_scale is None else v_scale,
@@ -1380,12 +1374,6 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
         k = (k / k_scale).to(kv_dtype)
         v = (v / v_scale).to(kv_dtype)
 
-    if batch_size == 1:
-        # trtllm kernel requires max_q_len to be the same as the seqlen of the query when batch_size=1
-        s_qo_trtllm = qo_indptr[-1].item()
-    else:
-        s_qo_trtllm = s_qo
-
     def run_backend_wrapper(backend):
         if backend in ["cutlass", "fa2", "fa3", "trtllm-gen"]:
             return backend_wrappers[backend].run_return_lse(q, k, v)[0]
@@ -1417,7 +1405,7 @@ def run_backend_wrapper(backend):
                 value=v,
                 workspace_buffer=workspace_buffer,
                 seq_lens=actual_seq_lens_kv_device,
-                max_q_len=s_qo_trtllm,
+                max_q_len=s_qo,
                 max_kv_len=s_kv,
                 bmm1_scale=scale,
                 bmm2_scale=1.0,
diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py
index 1b5cde7542..3d56f3ae9e 100644
--- a/flashinfer/artifacts.py
+++ b/flashinfer/artifacts.py
@@ -80,7 +80,7 @@ def get_available_cubin_files(


 class ArtifactPath:
-    TRTLLM_GEN_FMHA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/trtllm-gen/"
+    TRTLLM_GEN_FMHA: str = "463def7494c9fc6792b5aa5b5beef34025e247ac/fmha/trtllm-gen/"
     TRTLLM_GEN_BMM: str = (
         "56fea80cb22f8b2ef2a2c6a822a075fb20b36803/batched_gemm-074aec4-cc00b23"
     )
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index 36a111f741..b945e7ff36 100755
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -3221,7 +3221,7 @@ def trtllm_ragged_attention_deepseek(
     seq_lens : torch.Tensor
         sequence lengths
     max_q_len : int
-        max query length. If batch_size == 1, must be equal to cum_seq_lens_q[-1]
+        max query length
     max_kv_len : int
         max key/value length.
     bmm1_scale : float
@@ -3347,7 +3347,7 @@ def trtllm_batch_context_with_kv_cache(
     seq_lens : torch.Tensor
         A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``
     max_q_len : int
-        max sequence length for query. If batch_size == 1, must be equal to cum_seq_lens_q[-1]
+        max sequence length for query
     max_kv_len : int
         max sequence length for kv_cache
     bmm1_scale : float
diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py
index 0da439a26b..6bd2065b3d 100755
--- a/tests/attention/test_trtllm_gen_attention.py
+++ b/tests/attention/test_trtllm_gen_attention.py
@@ -1004,9 +1004,6 @@ def test_trtllm_gen_prefill_deepseek(
     bmm1_scale = scale
     bmm2_scale = 1.0
-    if batch_size == 1:
-        # trtllm kernel requires max_q_len to be the same as cum_seq_lens_q when batch_size=1
-        s_qo = qo_indptr[-1].item()
     output_trtllm, lse_trtllm = flashinfer.prefill.trtllm_ragged_attention_deepseek(
         q,
         k_cache,
         v_cache,

From 56ed50a0149a6e7a5e66993aebd9028e838c4a1b Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Thu, 16 Oct 2025 17:29:07 +0000
Subject: [PATCH 10/12] Undo change in prefill

---
 flashinfer/prefill.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index b945e7ff36..7399bd4268 100755
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -3223,7 +3223,7 @@ def trtllm_ragged_attention_deepseek(
     max_q_len : int
         max query length
     max_kv_len : int
-        max key/value length.
+        max key/value length
     bmm1_scale : float
         scale for bmm1, scale_q * scale_k * 1.0 / (head_dim_qk ** 0.5)
     bmm2_scale : float

From 9cb2b382c80edb6c00845fee115323c5ce916c0a Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Thu, 16 Oct 2025 18:07:07 +0000
Subject: [PATCH 11/12] Update checksums

---
 flashinfer/artifacts.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py
index 3d56f3ae9e..89b458e8d0 100644
--- a/flashinfer/artifacts.py
+++ b/flashinfer/artifacts.py
@@ -95,7 +95,7 @@ class ArtifactPath:
 class MetaInfoHash:
     DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
     TRTLLM_GEN_FMHA: str = (
-        "d26dbf837f40ff2dcd964094ab6e1b3f2424edda5979c313f5262655161fce98"
+        "2b8a485f2af84768bc769e678eb6014a8181ad95a7ea9e699de5efca4b18ec6a"
     )
     TRTLLM_GEN_BMM: str = (
         "4a8ceeb356fc5339021acf884061e97e49e01da5c75dbf0f7cf4932c37a70152"
     )
@@ -107,7 +107,7 @@ class MetaInfoHash:
 class CheckSumHash:
     TRTLLM_GEN_FMHA: str = (
-        "b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4"
+        "639c534614e9fdf5a9cfa91f7ea8f53989613019c0e1f8b755f461e1fcc7546f"
     )
     TRTLLM_GEN_BMM: str = (
         "8df2aae8f3aa39d64d2c723e775640beb4ac602a6cbb02e497c2a7316e349934"
     )

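The checksum update tracks the new TRTLLM_GEN_FMHA artifact path introduced in patch 09. The constants look like SHA-256 hex digests, but exactly which bytes the library hashes (a metainfo file, a checksum manifest, ...) is not visible in this diff, so the snippet below is only a generic recipe for computing such a digest locally, not FlashInfer's actual verification code.

import hashlib
from pathlib import Path

# Illustrative only: compute a SHA-256 digest for a downloaded artifact file so it
# can be compared against an expected value such as the constants above.
def sha256_of_file(path: Path, chunk_size: int = 1 << 20) -> str:
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
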
From 1b7f9e86c5c78f4fbf70fa636561e5b7597edab9 Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Fri, 17 Oct 2025 16:51:37 +0000
Subject: [PATCH 12/12] Add new mUseBlockSparseAttention parameter to
 KernelParams

---
 include/flashinfer/trtllm/fmha/kernelParams.h | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/include/flashinfer/trtllm/fmha/kernelParams.h b/include/flashinfer/trtllm/fmha/kernelParams.h
index 0d592c63e0..57adc57914 100644
--- a/include/flashinfer/trtllm/fmha/kernelParams.h
+++ b/include/flashinfer/trtllm/fmha/kernelParams.h
@@ -152,6 +152,8 @@ struct KernelParams {
   int32_t mStartTokenIdxSfO;
   // The sum of sequence lengths for Q and K/V.
   int32_t mSumOfSeqLensQ, mSumOfSeqLensKv;
+  // The flag to use block sparse attention.
+  bool mUseBlockSparseAttention;
 
   // Create the TMA shape/stride for Q.
   template
@@ -699,6 +701,8 @@ struct KernelParams {
     params.mStartTokenIdxSfO = options.mSfStartTokenIdx;
     params.mScaleSfKv = options.mScaleSfKv;
     params.ptrSoftmaxStats = options.softmaxStatsPtr;
+    // TODO: Integrate trtllm block-sparse attention kernels when needed.
+    params.mUseBlockSparseAttention = false;

     return params;
   }
 };