
Conversation

@LyrisZhong
Contributor

@LyrisZhong LyrisZhong commented Jul 2, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as the commands used to run the tests.
  • The test results, such as a before/after results comparison or e2e results.
  • (Optional) Any necessary documentation updates, such as updating supported_models.md and examples for a new model.

Purpose

  1. Modify the SM90 CUTLASS FP8 GEMM kernel dispatch code to support the swap-AB optimization, which reduces padding overhead for M <= 64 cases (swap AB is currently supported for the SM100 FP8 blockwise kernel, related PR). A sketch of the underlying identity follows this list.
  2. Update kernel configs based on CUTLASS profiler results.
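
The swap relies on the transpose identity A · Bᵀ = (B · Aᵀ)ᵀ: when M is small, the GEMM can be launched on the swapped problem so the small dimension no longer has to be padded up to the tile's M extent. Below is a minimal PyTorch sketch of that identity only; it does not reflect the actual CUTLASS code path, and the tensor names are illustrative.

import torch

# Small M is where padding to the tile size hurts the direct launch.
M, N, K = 16, 4096, 1024
A = torch.randn(M, K)    # activations
B = torch.randn(N, K)    # weights, laid out so that D = A @ B.T

D_direct = A @ B.T        # direct (M, N) problem
D_swapped = (B @ A.T).T   # same values, computed as the swapped (N, M) problem

torch.testing.assert_close(D_direct, D_swapped, rtol=1e-3, atol=1e-3)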

Test Plan

  • performance: used vllm/benchmarks/kernels/bench_fp8_gemm.py to sweep the problem spaces for Mistral-7B-v0.1 and Llama-2-7b-hf (an illustrative timing sketch follows this list)
  • accuracy: compared outputs from cutlass_scaled_mm and triton_scaled_mm to confirm the results are identical
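
For reference, the sketch below shows a minimal CUDA-event timing harness along the lines of what bench_fp8_gemm.py measures; it is an illustration only, not the benchmark script itself, and the shapes, scales, and iteration counts are arbitrary.

import torch
from vllm._custom_ops import cutlass_scaled_mm, scaled_fp8_quant

def time_cutlass_scaled_mm(M, N, K, iters=100, device="cuda", dtype=torch.bfloat16):
    # Quantize random operands to FP8 with unit per-tensor scales,
    # mirroring the accuracy script in the Test Result section.
    a = torch.randn((M, K), dtype=dtype, device=device)
    b = torch.randn((N, K), dtype=dtype, device=device)
    scale = torch.ones(1, dtype=torch.float32, device=device)
    a_fp8, scale_a = scaled_fp8_quant(a, scale)
    b_fp8, scale_b = scaled_fp8_quant(b, scale)
    b_fp8 = b_fp8.t()

    # Warm up, then time with CUDA events.
    for _ in range(10):
        cutlass_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, dtype)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        cutlass_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, dtype)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per call

for M in (1, 16, 64, 128):
    print(f"M={M}: {time_cutlass_scaled_mm(M, 4096, 4096):.4f} ms")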

Test Result

5% to 25% improvement across different M/N/K and quantization schemes for Mistral-7B-v0.1 and Llama-2-7b-hf

Result for Mistral-7B-v0.1 - Before optimization
mixtral_without_optimization

Result for Mistral-7B-v0.1 - After optimization
mixtral_with_optimization

Result for Llama-2-7b-hf - Before optimization
llama2_7b_without_optimization

Result for Llama-2-7b-hf - After optimization
llama2_7b_with_optimization

No diff observed when comparing output from cutlass_scaled_mm and triton_scaled_mm for the following M/N/K:
M_values = [1, 2, 4, 8, 16, 32, 64, 128, 256]
N_values = [1024, 1280, 2048, 4096, 8192, 16384, 32768]
K_values = [1024, 1280, 2048, 4096, 8192, 16384, 32768]

Test script:

import torch
import importlib
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant

device = "cuda"
dtype = torch.bfloat16

M_values = [1, 2, 4, 8, 16, 32, 64, 128, 256]
N_values = [1024, 1280, 2048, 4096, 8192, 16384, 32768]
K_values = [1024, 1280, 2048, 4096, 8192, 16384, 32768]

def cutlass_scaled_mm_sm_90_with_swap_ab(a_fp8, b_fp8, scale_a, scale_b):
	return vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, dtype)

def triton_scaled_mm(a_fp8, b_fp8, scale_a, scale_b):
	triton_scaled_mm_module = importlib.import_module(
		"vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm")
	triton_scaled_mm_fn = triton_scaled_mm_module.triton_scaled_mm
	return triton_scaled_mm_fn(a_fp8, b_fp8, scale_a, scale_b, dtype)

for M in M_values:
	for N in N_values:
		for K in K_values:
			print(f"\n==== Testing M={M}, N={N}, K={K} ====")
			try:
				# Generate input tensors
				a = torch.randn((M, K), dtype=dtype, device=device)
				b = torch.randn((N, K), dtype=dtype, device=device)

				# Quantize to FP8
				scale_a = torch.ones(1, dtype=torch.float32, device=device)
				scale_b = torch.ones(1, dtype=torch.float32, device=device)

				a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
				b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
				b_fp8 = b_fp8.t()

				# Run kernels
				out_cutlass = cutlass_scaled_mm_sm_90_with_swap_ab(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8)
				out_triton = triton_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8)

				# Convert to float32
				out_cutlass_fp32 = out_cutlass.to(torch.float32)
				out_triton_fp32 = out_triton.to(torch.float32)

				# Compute differences
				abs_diff = torch.abs(out_cutlass_fp32 - out_triton_fp32)
				rel_diff = abs_diff / (out_triton_fp32.abs().clamp_min(1e-5))

				print(f"Max abs diff: {abs_diff.max().item():.6f}")
				print(f"Mean abs diff: {abs_diff.mean().item():.6f}")
				print(f"Max relative diff: {rel_diff.max().item() * 100:.4f}%")
				print(f"Mean relative diff: {rel_diff.mean().item() * 100:.4f}%")

				num_large_errors = (rel_diff > 0.001).sum().item()
				total_elements = rel_diff.numel()
				print(f"Elements with >0.1% relative diff: {num_large_errors}/{total_elements}")

				torch.testing.assert_close(out_cutlass_fp32, out_triton_fp32, rtol=1e-1, atol=1e-1)

			except RuntimeError as e:
				print(f"Runtime error at M={M}, N={N}, K={K}: {e}")
			except AssertionError as e:
				# torch.testing.assert_close raises AssertionError on mismatch
				print(f"Assertion failed at M={M}, N={N}, K={K}: {e}")

print("\n==== Completed all tests ====")

(Optional) Documentation Update

@github-actions

github-actions bot commented Jul 2, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @LyrisZhong, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces optimizations to the SM90 CUTLASS FP8 GEMM kernel by enabling swap AB and tuning kernel configurations based on profiler results. These changes aim to improve performance, especially for smaller M values, by reducing padding overhead. The update includes new kernel configurations tailored to different M and N ranges and modifications to the kernel dispatch logic.

Highlights

  • Swap AB Optimization: Modified the SM90 CUTLASS FP8 GEMM kernel dispatch code to support swap AB optimization, which reduces padding overhead for M <= 64 cases. This optimization is already supported for the SM100 FP8 blockwise kernel.
  • Kernel Tuning: Updated kernel configurations based on CUTLASS profiler results to improve performance.
  • New Kernel Configurations: Added specific kernel configurations for different M and N values (sm90_fp8_config_M64_N1280, sm90_fp8_config_M64_N8192, sm90_fp8_config_M16_N1280, sm90_fp8_config_M16_N8192) to enable swap AB optimization for M < 64.
  • Dispatch Logic: Modified the kernel dispatch logic in cutlass_gemm_sm90_fp8_dispatch to select the appropriate kernel based on M and N values, enabling the swap AB optimization for specific cases.
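
The dispatch itself is C++ template code in the CUTLASS kernel sources; the sketch below is only a Python-level illustration of the shape-based selection described in these highlights. The config names come from the bullets above, while the helper name, the fallback name, and the exact N cutoff are assumptions.

def select_sm90_fp8_config(M: int, N: int) -> str:
    # Hypothetical mirror of cutlass_gemm_sm90_fp8_dispatch's branching;
    # the real logic is C++ templates and the N cutoff here is assumed.
    if M <= 16:
        # swap-AB configs for very small M
        return "sm90_fp8_config_M16_N1280" if N <= 1280 else "sm90_fp8_config_M16_N8192"
    if M <= 64:
        # swap-AB configs for the remaining small-M cases
        return "sm90_fp8_config_M64_N1280" if N <= 1280 else "sm90_fp8_config_M64_N8192"
    # larger M keeps the pre-existing, non-swapped configurations
    return "sm90_fp8_config_default"
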
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. commenting on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces an optimization for FP8 GEMM on SM90 by adding support for swapping A and B matrices, along with kernel tuning. There are correctness issues in the cutlass_gemm_caller_sm90_fp8 function related to tensor stride calculations and argument passing to the CUTLASS kernel when swap_ab is enabled.
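
For intuition on why the swap touches strides and kernel arguments: the swapped launch produces the transposed (N x M) result, so the output strides and any per-row/per-column scale broadcasts have to be exchanged along with the A and B operands. The PyTorch check below illustrates that bookkeeping only; it is not the CUTLASS argument construction, and the per-token/per-channel scales are hypothetical.

import torch

M, N, K = 8, 128, 64
A = torch.randn(M, K)
B = torch.randn(N, K)
scale_a = torch.rand(M, 1)    # per-token activation scales
scale_b = torch.rand(N, 1)    # per-channel weight scales

# Direct problem: scales broadcast over rows of A and columns of the output.
D_direct = scale_a * (A @ B.T) * scale_b.t()

# Swapped problem: the (N, M) result is scaled with the broadcast axes
# exchanged, then mapped back to (M, N) by swapping the output strides.
D_swapped = (scale_b * (B @ A.T) * scale_a.t()).t()

torch.testing.assert_close(D_direct, D_swapped, rtol=1e-4, atol=1e-4)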

@mgoin
Member

mgoin commented Jul 9, 2025

cc @ilmarkov @pavanimajety

@djmmoss djmmoss force-pushed the sm90_cutlass_fp8_gemm_swap_ab_integration branch from 15922ea to 2f28bfe on July 16, 2025 14:28
@LyrisZhong
Contributor Author

@djmmoss @mgoin I have done additional tests for the latest commit:

1) Confirmed the benchmark results (from running bench_fp8_gemm.py) mentioned in the PR description are still valid.

2) Ran python3 -m pytest tests/kernels/quantization/test_cutlass_scaled_mm.py; only one test (test_cutlass_subset) failed. That test exercises int8 inputs, which are not affected by this PR, and it also fails on the latest vLLM mainline.

3) lm_eval results:
lm_eval --model vllm --model_args pretrained=RedHatAI/Llama-4-Maverick-17B-128E-Instruct-FP8,tensor_parallel_size=8,max_model_len=2048,gpu_memory_utilization=0.9,max_num_seqs=32 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|   |0.9212|±  |0.0074|
|     |       |strict-match    |     5|exact_match|   |0.9257|±  |0.0072|

lm_eval --model vllm --model_args pretrained=RedHatAI/Mixtral-8x7B-Instruct-v0.1-FP8,tensor_parallel_size=8,max_model_len=2048,gpu_memory_utilization=0.9 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|   |0.6171|±  |0.0134|
|     |       |strict-match    |     5|exact_match|   |0.6149|±  |0.0134|


@LyrisZhong LyrisZhong force-pushed the sm90_cutlass_fp8_gemm_swap_ab_integration branch from ffbd413 to 3269693 on July 19, 2025 00:31
Member

@mgoin mgoin left a comment


Looks reasonable to me, thanks for the update and evaluations

Member


That is a fairly large increase in rtol, but I will accept it given the accuracy evals.

@mgoin mgoin added the performance (Performance-related issues) and ready (ONLY add when PR is ready to merge/full CI is needed) labels Jul 21, 2025
@LyrisZhong
Contributor Author

@mgoin is there anything pending from my side for the failed checks above? I haven't seen anything related to the PR yet

@mgoin mgoin added the kernel label and removed the speculative-decoding, ci/build, v1, multi-modality (Related to multi-modality (#4194)), tool-calling, llama (Related to Llama models), qwen (Related to Qwen models), and deepseek (Related to DeepSeek models) labels Jul 28, 2025
@mgoin
Member

mgoin commented Jul 28, 2025

@LyrisZhong There is general flakiness with the CI at the moment, I'll take care of getting this merged. Thanks!

@mgoin mgoin enabled auto-merge (squash) July 28, 2025 21:40
@mgoin mgoin merged commit c6c9122 into vllm-project:main Jul 28, 2025
96 checks passed
liuyumoye pushed a commit to liuyumoye/vllm that referenced this pull request Jul 31, 2025
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
noamgat pushed a commit to noamgat/vllm that referenced this pull request Aug 9, 2025
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025

Labels

kernel, performance (Performance-related issues), ready (ONLY add when PR is ready to merge/full CI is needed)

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

4 participants