[Kernel] SM90 CUTLASS FP8 GEMM: add support for swap AB + kernel tuning #20396
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀
Summary of Changes
Hello @LyrisZhong, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces optimizations to the SM90 CUTLASS FP8 GEMM kernel by enabling swap AB and tuning kernel configurations based on profiler results. These changes aim to improve performance, especially for smaller M values, by reducing padding overhead. The update includes new kernel configurations tailored to different M and N ranges and modifications to the kernel dispatch logic.
Highlights
- Swap AB Optimization: Modified the SM90 CUTLASS FP8 GEMM kernel dispatch code to support swap AB optimization, which reduces padding overhead for M <= 64 cases. This optimization is already supported for the SM100 FP8 blockwise kernel.
- Kernel Tuning: Updated kernel configurations based on CUTLASS profiler results to improve performance.
- New Kernel Configurations: Added specific kernel configurations for different M and N values (sm90_fp8_config_M64_N1280, sm90_fp8_config_M64_N8192, sm90_fp8_config_M16_N1280, sm90_fp8_config_M16_N8192) to enable swap AB optimization for M < 64.
- Dispatch Logic: Modified the kernel dispatch logic in cutlass_gemm_sm90_fp8_dispatch to select the appropriate kernel based on M and N values, enabling the swap AB optimization for specific cases (see the illustrative sketch after this list).
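For illustration, the sketch below shows the shape-based selection idea described above. The thresholds and the simplified structure are assumptions for exposition only; the PR's actual dispatch works with templated CUTLASS tile/cluster configurations rather than plain strings.

```cpp
// Hypothetical sketch of M/N-based kernel selection with a swap-AB path.
// Config suffixes mirror the names listed above (M16_N1280, M64_N8192, ...),
// but the exact thresholds and fallback are illustrative, not the PR's code.
#include <cstdint>
#include <iostream>
#include <string>

struct KernelChoice {
  std::string config;  // which tile/cluster configuration to launch
  bool swap_ab;        // if true, run the transposed problem so the small M
                       // dimension is not padded up to the tile size
};

KernelChoice select_sm90_fp8_kernel(int64_t m, int64_t n) {
  if (m <= 16) {
    return {n <= 1280 ? "M16_N1280" : "M16_N8192", /*swap_ab=*/true};
  }
  if (m <= 64) {
    return {n <= 1280 ? "M64_N1280" : "M64_N8192", /*swap_ab=*/true};
  }
  // Larger M: keep the original operand order and default configs.
  return {"default", /*swap_ab=*/false};
}

int main() {
  auto choice = select_sm90_fp8_kernel(/*m=*/8, /*n=*/4096);
  std::cout << choice.config << " swap_ab=" << choice.swap_ab << "\n";
}
```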
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.
Code Review
This pull request introduces an optimization for FP8 GEMM on SM90 by adding support for swapping A and B matrices, along with kernel tuning. There are correctness issues in the cutlass_gemm_caller_sm90_fp8 function related to tensor stride calculations and argument passing to the CUTLASS kernel when swap_ab is enabled.
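To make the stride and argument concern concrete, here is a hedged sketch of the bookkeeping a swap-AB path typically implies. The struct, field names, and the assumed layouts (A row-major [M, K], B column-major [K, N], per-row/per-column scales) are illustrative assumptions, not the PR's actual CUTLASS argument construction in cutlass_gemm_caller_sm90_fp8.

```cpp
// Hedged illustration of swap-AB bookkeeping; layouts are assumed
// (A row-major [M,K], B column-major [K,N]), not read from the PR.
#include <cstdint>

struct ScaledMmArgs {
  int64_t m, n, k;        // problem shape as seen by the kernel
  int64_t lda, ldb, ldd;  // leading dimensions of the operands passed in
  const void* a;          // "A"-operand pointer
  const void* b;          // "B"-operand pointer
  const void* scale_a;    // scales associated with the "A"-operand rows
  const void* scale_b;    // scales associated with the "B"-operand columns
};

// Normal path: D[M,N] = A[M,K] * B[K,N].
ScaledMmArgs make_args(const void* A, const void* B, const void* sA,
                       const void* sB, int64_t m, int64_t n, int64_t k) {
  return {m, n, k, /*lda=*/k, /*ldb=*/k, /*ldd=*/n, A, B, sA, sB};
}

// Swap-AB path: launch the transposed problem D^T[N,M] = B^T[N,K] * A^T[K,M].
// Pointers, scales, and the M/N extents all exchange roles, and the output is
// written as an [N,M] row-major buffer, i.e. a column-major view of D[M,N].
// Getting any one of these swaps wrong silently corrupts the result, which is
// the class of issue the review comment points at.
ScaledMmArgs make_args_swapped(const void* A, const void* B, const void* sA,
                               const void* sB, int64_t m, int64_t n,
                               int64_t k) {
  return {/*m=*/n, /*n=*/m, k, /*lda=*/k, /*ldb=*/k, /*ldd=*/m,
          /*a=*/B, /*b=*/A, /*scale_a=*/sB, /*scale_b=*/sA};
}
```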
@djmmoss @mgoin I have done additional tests for the latest commit:
1) Confirmed the benchmark results (from running bench_fp8_gemm.py) mentioned in the PR description are still valid.
2) Ran the following:
python3 -m pytest tests/kernels/quantization/test_cutlass_scaled_mm.py
lm_eval --model vllm --model_args pretrained=RedHatAI/Mixtral-8x7B-Instruct-v0.1-FP8,tensor_parallel_size=8,max_model_len=2048,gpu_memory_utilization=0.9 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
mgoin left a comment:
Looks reasonable to me, thanks for the update and evaluations
That is a fairly large increase in rtol... will accept given the accuracy evals.
@mgoin is there anything pending from my side for the failed checks above? I haven't seen anything related to the PR yet.
@LyrisZhong There is general flakiness with the CI at the moment, I'll take care of getting this merged. Thanks!
Essential Elements of an Effective PR Description Checklist
- (Optional) Documentation update, including supported_models.md and examples for a new model.

Purpose
Enable the swap AB optimization in the SM90 CUTLASS FP8 GEMM kernel and tune kernel configurations based on CUTLASS profiler results, reducing padding overhead and improving performance for small M.
Test Plan
Test Result
5% to 25% improvement across different M/N/K shapes and quantization schemes for Mistral-7B-v0.1 and Llama-2-7b-hf.
Result for Mistral-7B-v0.1 - Before optimization: [benchmark screenshot]
Result for Mistral-7B-v0.1 - After optimization: [benchmark screenshot]
Result for Llama-2-7b-hf - Before optimization: [benchmark screenshot]
Result for Llama-2-7b-hf - After optimization: [benchmark screenshot]
No diff observed when comparing output from cutlass_scaled_mm and triton_scaled_mm for the following M/N/K:
M_values = [1, 2, 4, 8, 16, 32, 64, 128, 256]
N_values = [1024, 1280, 2048, 4096, 8192, 16384, 32768]
K_values = [1024, 1280, 2048, 4096, 8192, 16384, 32768]
Test script:
(Optional) Documentation Update