fix: Fix trtllm-gen prefill IMA when batch_size==1 #1912
Conversation
benchmarks/routines/attention.py (outdated)

```python
kv_cache = torch.cat([k_fp8, v_fp8], dim=1)

if batch_size == 1:
    # trtllm kernel requires max_q_len to be the same as the seqlen of the query when batch_size=1
```
Why could `qo_indptr[-1]` be different from `s_qo`? Is it because we want to be compatible with CUDA graphs and `s_qo` will always be the maximum length?
Short answer is yes.

Longer answer: In a batch_size > 1 situation, the CUDA graph containing `prefill.trtllm_batch_context_with_kv_cache()` can be reused with multiple sequence lengths, but not when batch_size == 1. For example:
- If batch_size is 3 and we have two batches with query lengths [100, 200, 300] and [16, 500, 1024], we can set `s_qo=1024` when we construct the CUDA graph and use the same CUDA graph for the two batches.
- However, for batch_size = 1, where we have batches of query lengths [100] and [1024], a CUDA graph must be constructed each time -- first with `s_qo=100` and second with `s_qo=1024`.

Not sure whether the above is a real concern at the framework level. Nevertheless, `s_qo` goes in as the `max_q_len` input argument, where it is the max sequence length for the query. We may at least want to consider whether the wording in the documentation is clear 😄
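A minimal sketch (not the benchmark's actual code) of how `max_q_len` could be chosen, reusing the names `qo_indptr`, `s_qo`, and `batch_size` from the discussion above:

```python
import torch

def choose_max_q_len(qo_indptr: torch.Tensor, s_qo: int, batch_size: int) -> int:
    """Pick the max_q_len argument for a trtllm-gen prefill call."""
    if batch_size == 1:
        # Before the kernel-side fix, trtllm-gen required max_q_len to equal
        # the actual query length, i.e. the last entry of the cumulative
        # query-length array.
        return int(qo_indptr[-1])
    # With batch_size > 1, s_qo can be a fixed upper bound on the per-request
    # query length, which lets a single CUDA graph be reused across batches.
    return s_qo
```

For batch_size == 1 the last indptr entry equals the single request's query length, which is why the two quantities coincide only in that case.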
Force-pushed from 4dade1b to 197a7a0
Hi @bkryu, does upgrading to the latest trtllm-gen fix the issue?
/bot run
[FAILED] Pipeline #36750562: 1/17 passed
Walkthrough
Removes the automatic skip of trtllm-gen-native for batch_size==1 in benchmark routines, updates three TRTLLM_GEN_FMHA artifact hash constants, and adds a bool `mUseBlockSparseAttention` field to KernelParams.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Test as Test Runner
    participant Bench as Benchmark Routine
    participant Selector as Backend Selector
    participant Backend as trtllm-gen-native
    participant Kernel as KernelParams

    Note over Test,Bench: Test invokes batch prefill tests (parametrized)
    Test->>Bench: call testBatchPrefill(...)
    Bench->>Selector: determine eligible backends (now includes bs==1)
    Selector-->>Backend: select trtllm-gen-native when constraints met
    Backend->>Kernel: construct KernelParams (mUseBlockSparseAttention=false)
    Kernel-->>Backend: return params
    Backend-->>Bench: run prefill using params
    Bench-->>Test: report results
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning), ✅ Passed checks (2 passed)
/bot run
…special treatments
Force-pushed from 57e47ea to 003ef55
Actionable comments posted: 0
🧹 Nitpick comments (2)

benchmarks/README.md (1)

Line 19: LGTM! Documentation correctly updated. The documentation now accurately reflects that `BatchPrefillWithRaggedKVCacheWrapper` supports `trtllm_ragged_attention_deepseek` for ragged attention operations.

Optional: fix list indentation for consistency. The static analysis tool flags that this line uses 8 spaces of indentation instead of the expected 4 for nested list items:

```diff
-        - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_ragged_attention_deepseek`.
+    - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_ragged_attention_deepseek`.
```

tests/attention/test_trtllm_gen_attention.py (1)

Lines 348-361: LGTM! Function signature correctly updated. The new `max_q_len` and `max_kv_len` parameters are properly integrated into the function signature and correctly passed to `generate_seq_lens_prefill`.

Optional: prefix the unused variable with an underscore. Line 360 unpacks `in_kv_lens` from `generate_seq_lens_prefill`, but the variable is never used in the function body. Consider prefixing it with an underscore to indicate it is intentionally unused:

```diff
-    q_lens, in_kv_lens, seq_lens = generate_seq_lens_prefill(
+    q_lens, _in_kv_lens, seq_lens = generate_seq_lens_prefill(
         batch_size, max_q_len, max_kv_len
     )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- benchmarks/README.md (1 hunks)
- benchmarks/routines/attention.py (0 hunks)
- flashinfer/artifacts.py (3 hunks)
- include/flashinfer/trtllm/fmha/kernelParams.h (2 hunks)
- tests/attention/test_trtllm_gen_attention.py (3 hunks)
💤 Files with no reviewable changes (1)
- benchmarks/routines/attention.py
🚧 Files skipped from review as they are similar to previous changes (2)
- include/flashinfer/trtllm/fmha/kernelParams.h
- flashinfer/artifacts.py
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_trtllm_gen_attention.py (1)
- flashinfer/utils.py (1): `get_compute_capability` (251-254)
🪛 markdownlint-cli2 (0.18.1)
benchmarks/README.md
19-19: Unordered list indentation. Expected: 4; Actual: 8 (MD007, ul-indent)
🪛 Ruff (0.14.0)
tests/attention/test_trtllm_gen_attention.py
360-360: Unpacked variable `in_kv_lens` is never used. Prefix it with an underscore or any other dummy variable pattern. (RUF059)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
tests/attention/test_trtllm_gen_attention.py (2)
Lines 334-335: LGTM! Parameterization enhances test flexibility. Adding `max_q_len` and `max_kv_len` as test parameters allows testing different sequence length combinations, which is essential for validating the batch_size==1 fix across various configurations.
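As an illustration of this parameterization style, a self-contained sketch (not the actual test file; the parameter values and test name are assumptions) of how the sequence-length bounds can be parametrized so one test body covers both the multi-request and the large batch_size==1 configurations:

```python
import pytest

@pytest.mark.parametrize(
    "batch_size,max_q_len,max_kv_len",
    [
        (4, 511, 2047),   # illustrative multi-request configuration
        (1, 8192, 8192),  # large single-request configuration from this PR
    ],
)
def test_prefill_seq_len_combinations(batch_size, max_q_len, max_kv_len):
    # The real test generates per-request q_lens / kv_lens from these bounds
    # and runs the trtllm-gen prefill kernels; this stub only checks that
    # the bounds themselves are consistent.
    assert batch_size >= 1
    assert 0 < max_q_len <= max_kv_len
```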
Lines 530-578: LGTM! Dedicated batch_size=1 test addresses the PR objective. The new `test_trtllm_batch_prefill_bs1` function specifically tests the batch_size==1 scenario with large sequence lengths (8192), which directly addresses the issue described in #1898. The test properly delegates to the main test function with appropriate parameters and a minimal configuration to focus on the batch_size==1 edge case.
📌 Description
Current PR fixes the IMAs in the test and benchmark code when running trtllm-gen paged & ragged prefill with batch size 1 -- the issue was described in #1898.
Root cause of the issue: `flashinfer.prefill.trtllm_ragged_attention_deepseek` and `flashinfer.prefill.trtllm_batch_context_with_kv_cache` both require `max_q_len` to match the length of the query when batch size is 1.

Updated PR: The issue has been addressed on the kernel side, so the "`max_q_len` must match the length of the query when batch size is 1" requirement is no longer needed. The current PR updates the trtllm-gen FMHA cubins to the latest version and brings minor updates to the kernel metadata.
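To make the root cause concrete, here is an illustrative sketch (not the actual test or kernel code; the tensor names and lengths are assumptions):

```python
import torch

# With batch_size == 1, the single request here has 100 query tokens, so the
# cumulative query-length indptr ends at 100.
batch_size = 1
actual_q_len = 100
padded_max_q_len = 8192  # an upper bound, as one would pass for CUDA graph reuse

qo_indptr = torch.tensor([0, actual_q_len], dtype=torch.int32)
assert qo_indptr.numel() == batch_size + 1

# Before this PR, the trtllm-gen prefill kernels effectively required
#     max_q_len == int(qo_indptr[-1])   (here: 100) when batch_size == 1,
# and passing the padded bound (8192) instead could trigger the illegal
# memory access described in #1898. After the cubin update in this PR the
# padded bound is accepted, as for batch_size > 1.
```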
Unit test results after PR:
Description of previous solution (see the sketch below): Updating `max_q_len` to `cum_seq_lens_q[-1].item()` within the `trtllm_ragged_attention_deepseek` or `trtllm_batch_context_with_kv_cache` functions is not a viable option, because the CPU-side synchronization breaks the deterministic, fully device-side execution required during CUDA graph capture. The workaround was thus to update the test & benchmark code that calls the trtllm prefill functions, and to state clearly in the docstrings that when batch_size == 1, max_q_len must match the query size.
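A minimal sketch (assuming a CUDA-capable device; this is generic PyTorch, not flashinfer code) of why the `.item()` approach is not graph-safe: reading a device tensor back to the host synchronizes the stream, which is not permitted while a CUDA graph is being captured, so `max_q_len` must be a plain Python int known before capture begins.

```python
import torch

cum_seq_lens_q = torch.tensor([0, 100], device="cuda", dtype=torch.int32)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # Device-side work like this can be recorded into the graph:
    _ = cum_seq_lens_q * 2
    # A host read such as the following would synchronize the stream and
    # break the capture:
    # max_q_len = cum_seq_lens_q[-1].item()
```

🔍 Related Issues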
#1898
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- All tests are passing (`unittest`, etc.).
Summary by CodeRabbit
Bug Fixes
New Features
Documentation
Tests