
Conversation

@bkryu (Collaborator) commented Oct 10, 2025

📌 Description

Current PR fixes the IMAs (illegal memory accesses) in the test and benchmark code when running trtllm-gen paged & ragged prefill with batch size 1 -- the issue was described in #1898.

Root cause of the issue: flashinfer.prefill.trtllm_ragged_attention_deepseek and flashinfer.prefill.trtllm_batch_context_with_kv_cache both require max_q_len to match the length of the query when batch size is 1.

Updated PR:
The issue has now been addressed on the kernel side, so the requirement that max_q_len match the query length when batch size is 1 no longer applies.

The current PR updates the trtllm-gen FMHA cubins to the latest version and makes minor updates to the kernel metadata.

Unit test results after PR:

$ pytest tests/attention/test_trtllm_gen_attention.py 
...
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 2320 items   
...
2055 passed, 264 skipped, 1 xfailed in 224.43s (0:03:44)

Description of the previous solution:
Updating max_q_len to cum_seq_lens_q[-1].item() within the trtllm_ragged_attention_deepseek or trtllm_batch_context_with_kv_cache functions is not a viable option, because the CPU-side synchronization breaks the deterministic, fully device-side execution required during CUDA graph capture. The workaround was therefore to update the test & benchmark code that calls the trtllm prefill functions, and to state clearly in the docstring that when batch_size == 1, max_q_len must match the query size.
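
For illustration, here is a minimal sketch (plain PyTorch, not FlashInfer code) of that constraint: .item() copies a device value to the host and synchronizes, which is disallowed while a CUDA graph is being captured, so max_q_len has to be a plain host integer known before capture.

import torch

cum_seq_lens_q = torch.tensor([0, 1024], device="cuda", dtype=torch.int32)

# Outside capture this is fine, and is what a caller can do ahead of time:
# .item() copies the value to the host and synchronizes with the device.
max_q_len = int(cum_seq_lens_q[-1].item())

buf = torch.zeros(1, device="cuda")
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    buf += 1
    # Doing cum_seq_lens_q[-1].item() here instead would raise a CUDA
    # "operation not permitted when stream is capturing" error, because
    # capture must remain fully device-side and synchronization-free.
g.replay()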

🔍 Related Issues

#1898

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Removed the automatic batch_size=1 restriction for a native backend, enabling its use in more scenarios while other constraints remain.
  • New Features

    • Added configurable block-sparse attention support to kernel parameters.
  • Documentation

    • Clarified supported attention optimizations and backend capabilities in the benchmarks docs.
  • Tests

    • Expanded tests with configurable sequence lengths and added dedicated batch-size-1 test coverage.

@bkryu bkryu self-assigned this Oct 10, 2025
kv_cache = torch.cat([k_fp8, v_fp8], dim=1)

if batch_size == 1:
    # trtllm kernel requires max_q_len to be the same as the seqlen of the query when batch_size=1
Collaborator commented:
Why could qo_indptr[-1] be different from s_qo? Is it because we want to be compatible with CUDA graphs, and s_qo will always be the maximum length?

Collaborator Author (@bkryu) replied:
Short answer is yes.

Longer answer: in a batch_size > 1 situation, the CUDA graph containing prefill.trtllm_batch_context_with_kv_cache() can be reused across multiple sequence lengths, but not when batch_size==1. For example:

  • If batch_size is 3 and we have two batches with query lengths [100, 200, 300] and [16, 500, 1024], we can set s_qo=1024 when we construct the CUDA graph and reuse the same CUDA graph for both batches.
  • However, for batch_size=1, where we have batches with query lengths [100] and [1024], a separate CUDA graph must be constructed each time -- first with s_qo=100 and then with s_qo=1024.

I'm not sure whether the above is a real concern at the framework level. Nevertheless, s_qo is passed as the max_q_len input argument, which is the maximum sequence length for the query. We may at least want to consider whether the wording in the documentation is clear 😄
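
The reuse pattern described above can be illustrated with a small, self-contained sketch (plain PyTorch, not FlashInfer code; the buffer shapes and attention_step body are placeholder assumptions): the graph is captured once with buffers sized to the fixed s_qo / max_q_len upper bound, and later batches only overwrite the buffer contents before replay.

import torch

MAX_Q_LEN = 1024  # plays the role of s_qo / max_q_len, fixed at capture time
static_q = torch.zeros(MAX_Q_LEN, 8, 64, device="cuda", dtype=torch.bfloat16)
static_out = torch.empty_like(static_q)

def attention_step():
    # Placeholder for the real prefill kernel; every shape is bounded by MAX_Q_LEN.
    static_out.copy_(static_q * 2.0)

# Warm up on a side stream, then capture the graph once.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    attention_step()
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    attention_step()

# Replay the same graph for batches with different actual query lengths
# (all <= MAX_Q_LEN): only the buffer contents change, never the captured shapes.
for q_len in (100, 512, 1024):
    static_q[:q_len].fill_(1.0)  # stand-in for copying the new batch's queries
    g.replay()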

@bkryu bkryu force-pushed the trtllm-attention-debug branch from 4dade1b to 197a7a0 on October 16, 2025 17:23
@yzh119 (Collaborator) commented Oct 16, 2025

Hi @bkryu, does upgrading to the latest trtllm-gen fix the issue?

@bkryu (Collaborator, Author) commented Oct 16, 2025

Hi @bkryu, does upgrading to the latest trtllm-gen fix the issue?

Hi @yzh119, I'm currently checking. Upgrading to the latest trtllm-gen does fix the batch size 1 unit test, but I am seeing some errors in other places. I will verify what is happening before marking the PR as ready.

@bkryu (Collaborator, Author) commented Oct 16, 2025

/bot run

@flashinfer-bot (Collaborator) commented:
GitLab MR !83 has been created, and the CI pipeline #36750562 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator) commented:
[FAILED] Pipeline #36750562: 1/17 passed

@coderabbitai coderabbitai bot (Contributor) commented Oct 17, 2025

Walkthrough

Removes the automatic skip of trtllm-gen-native for batch_size==1 in benchmark routines, updates three TRTLLM_GEN_FMHA artifact hash constants, adds a bool mUseBlockSparseAttention to KernelParams, expands tests to parameterize sequence lengths and add a bs1 test, and updates benchmark docs.

Changes

  • Backend constraint change (benchmarks/routines/attention.py): Removed the auto-skip that excluded trtllm-gen-native when batch_size == 1 in testBatchPrefillWithPagedKVCacheWrapper and testBatchPrefillWithRaggedKVCacheWrapper, leaving other constraints unchanged.
  • Documentation (benchmarks/README.md): Added a note that BatchPrefillWithRaggedKVCacheWrapper also supports trtllm_ragged_attention_deepseek (in addition to cudnn_batch_prefill_with_kv_cache and others).
  • Artifact hash updates (flashinfer/artifacts.py): Updated three hard-coded constants for TRTLLM_GEN_FMHA: ArtifactPath.TRTLLM_GEN_FMHA, MetaInfoHash.TRTLLM_GEN_FMHA, and CheckSumHash.TRTLLM_GEN_FMHA (string value changes only).
  • Kernel parameters (include/flashinfer/trtllm/fmha/kernelParams.h): Added public data member bool mUseBlockSparseAttention to KernelParams, initialized to false in setKernelParams; no control-flow changes.
  • Tests (tests/attention/test_trtllm_gen_attention.py): Added max_q_len and max_kv_len parameters to test_trtllm_batch_prefill, added test_trtllm_batch_prefill_bs1 (which delegates to the parametrized test), and removed an explicit xfail for a bs1 deepseek case.

Sequence Diagram(s)

sequenceDiagram
  participant Test as Test Runner
  participant Bench as Benchmark Routine
  participant Selector as Backend Selector
  participant Backend as trtllm-gen-native
  participant Kernel as KernelParams

  Note over Test,Bench: Test invokes batch prefill tests (parametrized)

  Test->>Bench: call testBatchPrefill(...)
  Bench->>Selector: determine eligible backends (now includes bs==1)
  Selector-->>Backend: select trtllm-gen-native when constraints met
  Backend->>Kernel: construct KernelParams (mUseBlockSparseAttention=false)
  Kernel-->>Backend: return params
  Backend-->>Bench: run prefill using params
  Bench-->>Test: report results

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related issues

Poem

🐇 I nudged the benches, let one hop through,
Hashes refreshed, flags set false by default too,
Tests stretch lengths, a tiny rabbit cheer,
Block-sparse dreams tucked near the gear,
Hop on, trtllm — the path is new!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 40.00%, which is insufficient; the required threshold is 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.

✅ Passed checks (2 passed)
  • Title Check ✅ Passed: The PR title "fix: Fix trtllm-gen prefill IMA when batch_size==1" directly and clearly describes the main objective of the changeset. The changes focus on resolving the batch_size==1 issue with trtllm-gen prefill by updating kernel cubins, removing unnecessary auto-skip logic, and adjusting test infrastructure to support this fix. The title is concise, uses the standard commit convention ("fix:"), and accurately captures the primary change without being vague or overly broad.
  • Description Check ✅ Passed: The PR description follows the required template structure with sections for Description, Related Issues, and Pre-commit Checks. The Description section thoroughly explains the problem, root cause, and solution, and includes unit test results showing 2055 passed tests. The Related Issues section properly links to issue #1898. The Pre-commit Checks are marked as completed. However, the Tests checklist section shows both items unchecked despite the description documenting passing test results. This is a minor checkbox oversight that doesn't substantially detract from the overall completeness of the description.

@bkryu (Collaborator, Author) commented Oct 17, 2025

/bot run

@flashinfer-bot (Collaborator) commented:
GitLab MR !83 has been updated with latest changes, and the CI pipeline #36805526 is currently running. I'll report back once the pipeline job completes.

@bkryu bkryu marked this pull request as ready for review October 17, 2025 16:59
@bkryu bkryu changed the title from "fix: Fix test and benchmark for trtllm-gen prefill batch size 1" to "fix: Fix trtllm-gen prefill IMA when batch_size==1" on Oct 17, 2025
@bkryu bkryu force-pushed the trtllm-attention-debug branch from 57e47ea to 003ef55 on October 17, 2025 21:36
@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
benchmarks/README.md (1)

19-19: LGTM! Documentation correctly updated.

The documentation now accurately reflects that BatchPrefillWithRaggedKVCacheWrapper supports trtllm_ragged_attention_deepseek for ragged attention operations.

Optional: Fix list indentation for consistency.

The static analysis tool flags that this line uses 8 spaces for indentation instead of the expected 4 for nested list items.

-        - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and  `trtllm_ragged_attention_deepseek`.
+    - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and  `trtllm_ragged_attention_deepseek`.
tests/attention/test_trtllm_gen_attention.py (1)

348-361: LGTM! Function signature correctly updated.

The new max_q_len and max_kv_len parameters are properly integrated into the function signature and correctly passed to generate_seq_lens_prefill.

Optional: Prefix unused variable with underscore.

Line 360 unpacks in_kv_lens from generate_seq_lens_prefill, but the variable is never used in the function body. Consider prefixing it with an underscore to indicate it's intentionally unused:

-    q_lens, in_kv_lens, seq_lens = generate_seq_lens_prefill(
+    q_lens, _in_kv_lens, seq_lens = generate_seq_lens_prefill(
         batch_size, max_q_len, max_kv_len
     )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 57e47ea and 003ef55.

📒 Files selected for processing (5)
  • benchmarks/README.md (1 hunks)
  • benchmarks/routines/attention.py (0 hunks)
  • flashinfer/artifacts.py (3 hunks)
  • include/flashinfer/trtllm/fmha/kernelParams.h (2 hunks)
  • tests/attention/test_trtllm_gen_attention.py (3 hunks)
💤 Files with no reviewable changes (1)
  • benchmarks/routines/attention.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • include/flashinfer/trtllm/fmha/kernelParams.h
  • flashinfer/artifacts.py
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_trtllm_gen_attention.py (1)
flashinfer/utils.py (1)
  • get_compute_capability (251-254)
🪛 markdownlint-cli2 (0.18.1)
benchmarks/README.md

19-19: Unordered list indentation
Expected: 4; Actual: 8

(MD007, ul-indent)

🪛 Ruff (0.14.0)
tests/attention/test_trtllm_gen_attention.py

360-360: Unpacked variable in_kv_lens is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
tests/attention/test_trtllm_gen_attention.py (2)

334-335: LGTM! Parameterization enhances test flexibility.

Adding max_q_len and max_kv_len as test parameters allows testing different sequence length combinations, which is essential for validating the batch_size==1 fix across various configurations.


530-578: LGTM! Dedicated batch_size=1 test addresses PR objective.

The new test_trtllm_batch_prefill_bs1 function specifically tests the batch_size==1 scenario with large sequence lengths (8192), which directly addresses the issue described in #1898. The test properly delegates to the main test function with appropriate parameters and minimal configuration to focus on the batch_size==1 edge case.
