[Kernel] [V1] Further optimizations to ROCm (Triton) Backend to better handle GQA. #14431
Conversation
Co-authored-by: Jan van Lunteren <[email protected]> Co-authored-by: Burkhard Ringlein <[email protected]> Co-authored-by: Chih-Chieh Yang <[email protected]> Signed-off-by: Thomas Parnell <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀 […]
Please publish accuracy tests as well.
skip_decode=True,
)

block_size = value_cache.shape[3]
Just to understand it right: should there be a return after the call to context_attention_fwd? Otherwise, for max_query_len > 1, you are calling two kernels that might compute the same thing.
Right, we added this option to context_attention_fwd on main; if enabled, it skips the sequences in the batch with query_length=1. We then launch another kernel concurrently to handle the ones that were skipped.
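To illustrate the dispatch pattern described above, here is a minimal sketch; the function and argument names are simplified stand-ins, not the actual vLLM signatures:

```python
# Hedged sketch of the two-kernel dispatch described above. Names and
# signatures are simplified placeholders, not the real vLLM code.

def chunked_prefill_paged_decode_sketch(query, key, value, kv_cache, query_lens,
                                        context_attention_fwd, paged_decode_attention):
    """Run prefill and decode attention for a mixed batch."""
    max_query_len = max(query_lens)

    if max_query_len > 1:
        # Handles sequences with query_length > 1 (prefill / chunked prefill).
        # With skip_decode=True, rows with query_length == 1 are ignored here.
        context_attention_fwd(query, key, value, kv_cache,
                              query_lens=query_lens, skip_decode=True)

    # Handles only the query_length == 1 (decode) sequences, i.e. exactly the
    # rows the first kernel skipped. No early return is needed because the two
    # kernels touch disjoint subsets of the batch and can overlap on the GPU.
    paged_decode_attention(query, kv_cache, query_lens=query_lens)
```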
Accuracy results using V1 (detailed output omitted here). cc @maleksan85
SageMoore left a comment
Looks great. Thanks for the contribution!
Here are the lm_eval results from an MI300X machine. Results look good.
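For readers who want to reproduce an accuracy check like this, a typical lm_eval invocation against vLLM looks roughly like the following; the model, task, and parallelism settings are illustrative placeholders, not necessarily what was used for these results:

```bash
# Illustrative only: model, task, and settings are assumptions,
# not the exact configuration behind the results above.
lm_eval --model vllm \
    --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct,tensor_parallel_size=1 \
    --tasks gsm8k \
    --batch_size auto
```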
@tdoublep you have probably already seen this, but if not: https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#triton-kernel-performance-optimization
@tdoublep do you understand the increase in mean TTFT versus main?
One thing I recommend for these types of performance comparisons is adding […]
@tlrmchlsmth hmm good catch, I hadn't noticed that. Will have another look.
Makes sense, will re-run with that enabled.
@tlrmchlsmth I've re-run everything using the benchmark command (detailed results for ROCmAttentionBackend at each commit omitted here). I think it looks OK.
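For context, serving benchmarks like the ones discussed in this thread usually follow the pattern of starting a vLLM server and driving it with benchmarks/benchmark_serving.py. The sketch below is illustrative only; the model, dataset, and backend selection are placeholders rather than the exact command used for these runs:

```bash
# Illustrative sketch only; the exact flags and backend-selection env vars
# behind the numbers in this thread are not reproduced here.
VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct --disable-log-requests &

python benchmarks/benchmark_serving.py \
    --backend vllm \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
    --num-prompts 1000 \
    --request-rate inf
```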
tlrmchlsmth left a comment
@tdoublep thanks for rerunning those tests!
Changes look good and the performance optimization makes sense.
@tlrmchlsmth The multi-modal test that is failing does not look related to these changes.
TLDR: This PR adds some further optimizations to the chunked_prefill_paged_decode op to better handle models with GQA. Serving benchmarks using V1 indicate that with these changes we see a 25% improvement in throughput for llama3.1-8b on an H100 vs. the current Triton implementation. With these changes, the throughput of the Triton implementation is only 8% worse than the V1 CUDA backend (FlashAttention).

Using FlashAttentionBackend from main on H100: (benchmark results omitted)

Using ROCmAttentionBackend from main on H100: (benchmark results omitted)

And finally, using ROCmAttentionBackend from this PR on H100: (benchmark results omitted)

cc @SageMoore @maleksan85
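To make the GQA point above concrete, here is a small PyTorch sketch (not the PR's Triton kernel) showing why grouping query heads that share a KV head lets a kernel load each K/V block once per group instead of once per query head:

```python
# Hedged illustration of grouped-query attention (GQA); sizes are arbitrary
# and this is not the kernel implemented in this PR.
import torch

num_q_heads, num_kv_heads, head_dim, seq_len = 32, 8, 128, 16
group_size = num_q_heads // num_kv_heads  # query heads sharing one KV head

q = torch.randn(num_q_heads, seq_len, head_dim)
k = torch.randn(num_kv_heads, seq_len, head_dim)
v = torch.randn(num_kv_heads, seq_len, head_dim)

# Naive: expand K/V so every query head gets its own copy (redundant loads).
k_exp = k.repeat_interleave(group_size, dim=0)
v_exp = v.repeat_interleave(group_size, dim=0)
out_naive = torch.softmax(q @ k_exp.transpose(-1, -2) / head_dim**0.5, dim=-1) @ v_exp

# GQA-aware: process each KV head once, attending its whole query-head group.
out_grouped = torch.empty_like(q)
for kv in range(num_kv_heads):
    q_group = q[kv * group_size:(kv + 1) * group_size]           # (group, seq, d)
    scores = q_group @ k[kv].transpose(-1, -2) / head_dim**0.5   # K loaded once per group
    out_grouped[kv * group_size:(kv + 1) * group_size] = (
        torch.softmax(scores, dim=-1) @ v[kv]                    # V loaded once per group
    )

assert torch.allclose(out_naive, out_grouped, atol=1e-5)
```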