Conversation

@varun-sundar-rabindranath (Contributor) commented Oct 31, 2025:

Purpose

DeepGEMM requires the activation and weights scales to be in a specific format. When the scales are not provided in the desired format, DeepGEMM transforms the scales itself. This is usually very slow.

On H100, DeepGEMM needs the scales to be ColumnMajor and in float32.
On B200, DeepGEMM needs the scales to be ColumnMajor but in a packed E8M0 format.

Note that we handle the H100 case already on main. This PR adds partial support for the B200 case. Concretely,

  • Import transform_sf_into_required_layout from DeepGEMM to perform the weight-scales transformation. This function handles both SM90 and SM100 internally, so it can be used as a common code path.
  • The DeepEP low-latency dispatch supports dispatching the activation scales in packed E8M0 format. This PR enables that (a rough sketch of the device check that gates this is shown below).
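For reference, here is a minimal standalone sketch of the architecture check this relies on. It is illustrative only: the helper name below is made up, and the PR itself gates on vLLM's current_platform.is_device_capability(100).

```python
import torch

def deep_gemm_wants_packed_ue8m0() -> bool:
    # Hypothetical helper, for illustration only.
    # SM100 (B200): DeepGEMM expects column-major, packed E8M0 scale factors.
    # SM90 (H100): DeepGEMM expects column-major float32 scale factors.
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return (major, minor) == (10, 0)
```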

main: [image omitted]

PR: [image omitted]

Benchmark

full benchmark numbers - link

Server command: VLLM_ALL2ALL_BACKEND=${A2A} VLLM_USE_DEEP_GEMM=1 canhazgpu run -g2 -- vllm serve Qwen/Qwen3-30B-A3B-FP8 --trust-remote-code --tensor-parallel-size 1 --data-parallel-size 2 --enable-expert-parallel --no-enable-prefix-caching --port 9010

Decode bench command: vllm bench serve --model Qwen/Qwen3-30B-A3B-FP8 --dataset-name random --num-prompts 128 --random-input-len 1 --random-output-len 1024 --request-rate 128 --ignore-eos --port 9010

Prefill bench command: vllm bench serve --model Qwen/Qwen3-30B-A3B-FP8 --dataset-name random --num-prompts 256 --random-input-len 8192 --random-output-len 1 --request-rate 256 --ignore-eos --port 9010 --backend vllm

B200 + deepep_low_latency + decode

|                                      | main | PR    |
|--------------------------------------|-----:|------:|
| Peak output token throughput (tok/s) | 9600 | 12028 |

B200 + deepep_high_throughput + prefill

|                                | main     | PR       |
|--------------------------------|---------:|---------:|
| Total Token throughput (tok/s) | 89736.95 | 91310.99 |

H100 + deepep_low_latency + decode

|                                      | main | PR   |
|--------------------------------------|-----:|-----:|
| Peak output token throughput (tok/s) | 8408 | 8422 |

H100 + deepep_high_throughput + prefill

|                                | main     | PR       |
|--------------------------------|---------:|---------:|
| Total Token throughput (tok/s) | 73950.14 | 74095.51 |

Test Plan

Server command:
vllm serve Qwen/Qwen3-30B-A3B-FP8 --trust-remote-code --tensor-parallel-size ${tp_size} --data-parallel-size ${dp_size} --enable-expert-parallel --no-enable-prefix-caching --port 9010

Server config combinations (VLLM_ALL2ALL_BACKEND, VLLM_USE_DEEP_GEMM, dp_size, tp_size):

  1. (deepep_low_latency, 1, 2, 1)
  2. (deepep_high_throughput, 1, 2, 1)
  3. (deepep_low_latency, 0, 2, 1) // This PR touches fp8 weight loading when DeepGEMM is enabled; this config tests for regressions
  4. (deepep_high_throughput, 0, 2, 1)
  5. (n/a, 1, 1, 2)
  6. (n/a, 0, 1, 2)

lm_eval command :

lm_eval --model local-completions --tasks gsm8k --model_args model=Qwen/Qwen3-30B-A3B-FP8,base_url=http://localhost:9010/v1/completions,num_concurrent=30,max_retries=3 --limit 100

Test Result

lm_eval produces the desired results on the PR, on both H100 and B200.
Example of a desired result:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.87|±  |0.0338|
|     |       |strict-match    |     5|exact_match|↑  | 0.92|±  |0.0273|

@gemini-code-assist (bot) left a comment:


Code Review

This pull request introduces performance optimizations for DeepGEMM on B200 hardware by correctly handling weight and activation scales in the required E8M0 format. The changes refactor the scale transformation logic into a centralized function, which improves code structure and adds support for B200 while maintaining H100 compatibility. The MoE framework is also correctly extended to support packed activation scales. My review includes one minor fix for an incorrect logging statement.

@chatgpt-codex-connector (bot) left a comment:


💡 Codex Review

Here are some automated review suggestions for this pull request.


layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

# DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons.
Contributor Author:

cc @yewentao256 for the changes to this file.
I have replaced the Hopper-specific get_col_major_tma_aligned_tensor with the generic (H100 and B200) transform_sf_into_required_layout utility from DeepGEMM. PTAL! Thanks 🙌

Member:

So the main change is that we switch from get_col_major_tma_aligned_tensor to transform_sf_into_required_layout? What is the difference between them, and why does that give a perf improvement?

Contributor Author:

The main difference is that get_col_major_tma_aligned_tensor is Hopper-specific, while transform_sf_into_required_layout works for both Hopper and Blackwell. See https://github.com/deepseek-ai/DeepGEMM/blob/c9f8b34dcdacc20aa746b786f983492c51072870/csrc/apis/layout.hpp#L8

The perf comes from the fact that on Blackwell we were not transforming the weight scales during model weight setup. Now we do it at weight setup (via transform_sf_into_required_layout), so DeepGEMM doesn't have to do it on every call.
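As a rough illustration of that pattern (a sketch only: transform_weight_scales below is a stand-in for DeepGEMM's transform_sf_into_required_layout, and the attribute name w2_weight_scale_inv is an assumption):

```python
import torch

def transform_weight_scales(ws: torch.Tensor) -> torch.Tensor:
    # Stand-in for DeepGEMM's transform_sf_into_required_layout; as a toy
    # example we just re-lay a 2D scale tensor out in column-major order.
    return ws.t().contiguous().t()

def process_weights_after_loading(layer) -> None:
    # Do the layout transform once, at weight-load time, so that every
    # subsequent DeepGEMM call receives scales already in the required layout
    # instead of re-transforming them on each call.
    layer.w2_weight_scale_inv = torch.nn.Parameter(
        transform_weight_scales(layer.w2_weight_scale_inv), requires_grad=False
    )
```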

@mgoin (Member) left a comment:

LGTM, nice find. I just have a concern about applying the right ue8m0 format on both Hopper and Blackwell if the model requires it.

"""
DeepGemm supports packed ue8m0 activation scales format in devices >= sm100
"""
return current_platform.is_device_capability(100)
Member:

The comment doesn't match this line since "is" is ==
Also isn't it the case though that we still want to use UE8M0 on hopper for cases like DeepSeek terminus?

@yewentao256 (Member) commented Nov 3, 2025:

+1, actually we are using e8m0 for hopper currently, this seems a breaking change for me.
We should carefully test and benchmark before we use this.

Contributor Author:

> The comment doesn't match this line since "is" is ==

Updated the comment to == sm100, since the deepgemm readme specifies sm100 explicitly. We can upgrade as needed.

> Also isn't it the case though that we still want to use UE8M0 on hopper for cases like DeepSeek terminus?

IIUC, this is the state of main:

Let ws be a weight-scales tensor of shape [X, 4096] and datatype float32.

  • On Hopper and Blackwell - When we use DeepGEMM, we always (for block-fp8 models) cast the weight scales to UE8M0 but keep the weight scales in float32, i.e. each float32 value actually holds UE8M0 content. Look here. I.e. only the exponent bits (roughly one byte) of each float32 value carry the actual contents.
    ws will be of shape [X, 4096] and of datatype float32.

This PR:

  • On Hopper - We don't change the behaviour: scales are cast to UE8M0 but the tensors are still in float32.
  • On Blackwell - We first cast the scales to UE8M0 while still representing them in float32. We then use transform_sf_into_required_layout() from DeepGEMM to pack the scales into an int32 tensor, i.e. ws will be of shape [X, 1024] and of datatype int32 (see the sketch below).
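To make the packing concrete, here is a small standalone sketch of the idea. This is not DeepGEMM's transform_sf_into_required_layout (which additionally handles the column-major / TMA-aligned layout), and the exact packing order is an assumption; it only shows how UE8M0 content stored in float32 collapses 4x along the last dimension when packed into int32.

```python
import torch

def pack_ue8m0_scales(ws: torch.Tensor) -> torch.Tensor:
    # ws holds power-of-two scales (UE8M0 content) stored as float32.
    assert ws.dtype == torch.float32 and ws.shape[-1] % 4 == 0
    # The UE8M0 encoding of a power-of-two float32 is its biased exponent byte.
    exp_bytes = ((ws.contiguous().view(torch.int32) >> 23) & 0xFF).to(torch.uint8)
    # Reinterpret each group of 4 exponent bytes as one int32:
    # [X, 4096] uint8 -> [X, 1024] int32.
    return exp_bytes.view(torch.int32)

ws = torch.exp2(torch.randint(-16, 16, (8, 4096)).float())  # fake UE8M0-in-float32 scales
packed = pack_ue8m0_scales(ws)
assert packed.shape == (8, 1024) and packed.dtype == torch.int32
```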

> +1, actually we are using e8m0 for hopper currently, this seems a breaking change for me.
> We should carefully test and benchmark before we use this.

@yewentao256 I have added some benchmark and lm-eval numbers in the PR description.

@yewentao256 (Member) left a comment:

Nice find and great performance improvement! Thanks for the work!
A few thoughts:

"""
DeepGemm supports packed ue8m0 activation scales format in devices >= sm100
"""
return current_platform.is_device_capability(100)
Copy link
Member

@yewentao256 yewentao256 Nov 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, actually we are using e8m0 for hopper currently, this seems a breaking change for me.
We should carefully test and benchmark before we use this.

layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

# DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the main change is that we convert get_col_major_tma_aligned_tensor to transform_sf_into_required_layout? What is the difference of these and why could we get perf?

Comment on lines -52 to -54
if envs.VLLM_USE_FLASHINFER_MOE_FP8:
    logger.info_once("DeepGEMM E8M0 disabled: FlashInfer MOE is enabled.")
    return False
Member:

I am not sure who added this; could you take a further look?

Collaborator:

May I know why is this removed? Is this because of MOE vs Gemm impl differences?
DeepGemm seems to always be enabled even when other MOE backends are enabled. We need to have a better check to identify the moe backend.

Collaborator:

For example, to run FlashInfer MOE we now need to run:
VLLM_USE_FLASHINFER_MOE_FP8=1 VLLM_FLASHINFER_MOE_BACKEND=latency VLLM_USE_DEEP_GEMM=0 python ....

Contributor Author:

Looks like #25895 added it. @pavanimajety, can you please take a look? Thanks.

@varun-sundar-rabindranath (Contributor Author) commented Nov 3, 2025:

@pavanimajety sorry I missed your comment.

> May I know why is this removed? Is this because of MOE vs Gemm impl differences?

I removed it in an effort to clean up. I think this function should depend only on DeepGEMM-specific attributes / envs.

> DeepGemm seems to always be enabled even when other MOE backends are enabled. We need to have a better check to identify the moe backend.

Yes, this is a problem. I have addressed this in fp8.py. Please take a look at comment https://github.com/vllm-project/vllm/pull/27897/files#r2487520906 .

Contributor Author:

I tried running,

VLLM_USE_FLASHINFER_MOE_FP8=1 VLLM_FLASHINFER_MOE_BACKEND="latency"  canhazgpu run -g8 --  lm_eval --model vllm --model_args pretrained=deepseek-ai/DeepSeek-R1,quantization=fp8,tensor_parallel_size=8,gpu_memory_utilization=0.90,add_bos_token=True --gen_kwargs temperature=0.0,max_gen_toks=32768 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size 200 --limit 1319

from #25895 and got,

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9613|±  |0.0053|
|     |       |strict-match    |     5|exact_match|↑  |0.9621|±  |0.0053|

Varun Sundar Rabindranath added 2 commits November 3, 2025 13:27
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
del layer.w13_input_scale
del layer.w2_input_scale

if is_deep_gemm_e8m0_used() and self.block_quant:
Contributor Author:

Here we perform weight requant and weight-scale transformation based on is_deep_gemm_e8m0_used() and self.block_quant. However, this does not consider which Fp8MoeBackend is used, i.e. regardless of the backend, which could be

    FLASHINFER_TRTLLM = 1
    FLASHINFER_CUTLASS = 2
    DEEPGEMM = 3
    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4
    MARLIN = 5
    TRITON = 6

we perform weight requant and scale transformation whenever DeepGEMM is available. This seems like a bug, so I have moved this logic above and guarded it with the self.allow_deep_gemm flag, which is True only when the Fp8MoeBackend is DEEPGEMM (rough sketch below).

@yewentao256 - This block was first introduced in #20087. Can you confirm whether this is okay? Thanks.
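For clarity, a rough sketch of the guard described above (a fragment only, not the exact fp8.py code; attribute names are as used in this thread):

```python
# Before (main): requant + scale transform whenever DeepGEMM e8m0 is usable.
#   if is_deep_gemm_e8m0_used() and self.block_quant: ...
#
# After (this PR, as described above): also require that the DeepGEMM MoE
# backend was actually selected.
if self.allow_deep_gemm and is_deep_gemm_e8m0_used() and self.block_quant:
    ...  # requant weights to UE8M0 and transform weight scales for DeepGEMM
```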
