@@ -231,6 +231,7 @@ def __init__(
"""
super().__init__(quant_config)
assert self.block_shape == get_mk_alignment_for_contiguous_layout()
assert self.quant_config.use_fp8_w8a8
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers

@@ -249,6 +250,12 @@ def supports_chunking(self) -> bool:
def supports_expert_map(self) -> bool:
return False

def supports_packed_ue8m0_act_scales(self) -> bool:
"""
DeepGemm supports packed ue8m0 activation scales format in devices == sm100
"""
return current_platform.is_device_capability(100)

Member commented:
The comment doesn't match this line since "is" is ==
Also isn't it the case though that we still want to use UE8M0 on hopper for cases like DeepSeek terminus?

@yewentao256 (Member) commented on Nov 3, 2025:
+1, actually we are using e8m0 for Hopper currently; this seems like a breaking change to me.
We should carefully test and benchmark before using this.

@varun-sundar-rabindranath (Contributor, Author) commented on Nov 3, 2025:
> The comment doesn't match this line since "is" is ==

Updated the comment to == sm100, since the DeepGEMM README specifies sm100 explicitly. We can upgrade as needed.

> Also isn't it the case though that we still want to use UE8M0 on hopper for cases like DeepSeek terminus?

IIUC, this is the state of main:

Let ws be a weight scales tensor of shape [X, 4096] with datatype float32.

  • On Hopper and Blackwell: when we use DeepGemm, we always (for block fp8 models) cast the weight scales to UE8M0 but keep the weight scales in float32, ~~i.e. each float32 value actually holds UE8M0 content. Look here. i.e. only the first byte of each float32 value will have the actual contents.~~

[EDIT] The struck-out portion was wrong. We actually cast the scales to ue8m0 and then expand them back to float32; effectively the scale values can be one of {2^i where i is in [-127, 127]}.

ws will be of shape [X, 4096] and of datatype float32.

This PR:

  • On Hopper: we don't change the behaviour.
  • On Blackwell: we requant to UE8M0 and then use transform_sf_into_required_layout() from DeepGemm to pack the scales into an int32 tensor, i.e. ws will be of shape [X, 1024] and of datatype int32. Effectively the scale values can be one of {i where i is in [-127, 127]}.

> +1, actually we are using e8m0 for Hopper currently; this seems like a breaking change to me.
> We should carefully test and benchmark before using this.

@yewentao256 I have added some benchmark and lm-eval numbers in the PR description.
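
To make the two scale layouts described above concrete, here is a minimal, self-contained sketch. It is not vLLM's or DeepGemm's actual code; the rounding direction (ceil) and the E8M0-style bias of 127 are illustrative assumptions.

```python
import torch

def ue8m0_scales_as_float32(ws: torch.Tensor) -> torch.Tensor:
    # "main" behaviour, conceptually: snap every block scale to a power of two
    # (UE8M0 semantics) but keep float32 storage, so a [X, 4096] float32 tensor
    # stays [X, 4096] float32 with values of the form 2**i.
    return torch.exp2(torch.ceil(torch.log2(ws)))

def ue8m0_scales_packed_int32(ws: torch.Tensor) -> torch.Tensor:
    # "this PR" on Blackwell, conceptually: keep one exponent byte per scale and
    # pack 4 bytes into each int32, so [X, 4096] float32 -> [X, 1024] int32.
    exp = torch.ceil(torch.log2(ws)).to(torch.int32)
    byte = (exp + 127).clamp(0, 255).to(torch.uint8)  # assumed E8M0-style bias of 127
    return byte.contiguous().view(torch.int32)        # reinterpret 4 bytes as one int32

ws = torch.rand(8, 4096, dtype=torch.float32) + 0.5   # strictly positive scales
print(ue8m0_scales_as_float32(ws).shape)    # torch.Size([8, 4096])
print(ue8m0_scales_packed_int32(ws).shape)  # torch.Size([8, 1024])
```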


def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
@@ -6,6 +6,7 @@
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
@@ -20,6 +21,8 @@
dbo_maybe_run_recv_hook,
)

logger = init_logger(__name__)

# DeepEP kernels quantize dispatch inputs in 128 element chunks.
DEEPEP_QUANT_BLOCK_SIZE = 128
DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE]
@@ -94,6 +97,28 @@ def __init__(
self.handles: list[tuple | None] = [None, None]
self.num_dispatchers_ = num_dispatchers

# We don't have enough information to determine if we should dispatch
# activation scales in a packed ue8m0 format during object construction
# time. This setting is handled by setup_packed_ue8m0_scales_dispatch.
self.use_ue8m0 = False

Member commented:
So is this flag only used in low-latency dispatch, and weight requant doesn't require it?


def supports_packed_ue8m0_scales_dispatch(self) -> bool:
return True

def setup_packed_ue8m0_scales_dispatch(self) -> None:
if self.use_fp8_dispatch:
logger.debug_once(
"Update DeepEPLLPrepareFinalize to do packed ue8m0 scales dispatch"
)
self.use_ue8m0 = True
else:
logger.warning_once(
"Ignoring request to dispatch activation scales in a packed "
"ue8m0 format as DeepEPLLPrepareAndFinalize is setup to"
"dispatch raw/unquantized activations.",
scope="local",
)

def num_dispatchers(self) -> int:
return self.num_dispatchers_

@@ -206,6 +231,9 @@ def prepare_async(
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch,
# round_scale needs to be set to dispatch in ue8m0
round_scale=self.use_ue8m0,
use_ue8m0=self.use_ue8m0,
async_finish=False,
return_recv_hook=True,
)
34 changes: 34 additions & 0 deletions vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -347,6 +347,20 @@ def output_is_reduced(self) -> bool:
"""
raise NotImplementedError

def supports_packed_ue8m0_scales_dispatch(self) -> bool:
"""
Return true if the implementation can dispatch activation scales in
packed ue8m0 format.
"""
return False

def setup_packed_ue8m0_scales_dispatch(self) -> None:
"""
Setup internal state of the implementation to dispatch activation scales
in packed ue8m0 format.
"""
raise NotImplementedError


# TODO: add supported activations method (return string)
class FusedMoEPermuteExpertsUnpermute(ABC):
@@ -503,6 +517,13 @@ def supports_expert_map(self) -> bool:
"""
raise NotImplementedError

def supports_packed_ue8m0_act_scales(self) -> bool:
"""
A flag indicating whether or not this class can process packed ue8m0
activation scales.
"""
return False

def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
"""
Workspace type: The dtype to use for the workspace tensors.
@@ -698,6 +719,8 @@ def __init__(
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
self.shared_experts = shared_experts

self._post_init_setup()
assert (
prepare_finalize.activation_format == fused_experts.activation_formats[0]
), (
@@ -707,6 +730,17 @@
f"{fused_experts.activation_formats[0]}"
)

def _post_init_setup(self):

Collaborator commented:
Can this method be overridden by subclasses? If so, then you could avoid adding ue8m0 related methods to all the modular kernels and make an implementation specific to DeepEPLL.../BatchedDeepGemmExperts

Contributor (Author) commented:
I am not sure I understand. _post_init_setup is in the FusedMoEModularKernel class, which doesn't have any subclasses. Can you clarify? Thanks.

Collaborator commented:
I guess I was thinking you could have a generic _post_init_setup sort of method on prepare_finalize so that we wouldn't need to put such specific methods/logic here, e.g.

def _post_init_setup(self):
    self.prepare_finalize._post_init_setup(self.fused_experts)

And _post_init_setup will be a nop for everything except for DeepEPLL...
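
A hypothetical sketch of that suggestion (names simplified; this is not code from the PR): a post-init hook that is a no-op on the prepare/finalize base class and is overridden only by the DeepEP low-latency implementation, keeping the ue8m0-specific methods out of the generic modular-kernel interfaces.

```python
from abc import ABC

class PrepareAndFinalizeBase(ABC):
    def post_init_setup(self, fused_experts) -> None:
        # Default: nothing to resolve between prepare/finalize and the experts impl.
        pass

class DeepEPLLLikePrepareAndFinalize(PrepareAndFinalizeBase):
    def __init__(self, use_fp8_dispatch: bool) -> None:
        self.use_fp8_dispatch = use_fp8_dispatch
        self.use_ue8m0 = False

    def post_init_setup(self, fused_experts) -> None:
        # Enable packed ue8m0 scales only when both sides can handle them.
        if fused_experts.supports_packed_ue8m0_act_scales() and self.use_fp8_dispatch:
            self.use_ue8m0 = True

# FusedMoEModularKernel.__init__ would then reduce to a single delegation call:
#     self.prepare_finalize.post_init_setup(self.fused_experts)
```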

"""
Resolve any leftover setup dependencies between self.prepare_finalize
and self.fused_experts here.
"""
if (
self.fused_experts.supports_packed_ue8m0_act_scales()
and self.prepare_finalize.supports_packed_ue8m0_scales_dispatch()
):
self.prepare_finalize.setup_packed_ue8m0_scales_dispatch()

def supports_expert_map(self) -> bool:
"""
A flag indicating whether or not this class supports expert maps.
74 changes: 34 additions & 40 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -60,11 +60,10 @@
create_fp8_input_scale,
create_fp8_scale_parameter,
create_fp8_weight_parameter,
expert_weight_is_col_major,
deepgemm_post_process_fp8_weight_block,
maybe_post_process_fp8_weight_block,
process_fp8_weight_block_strategy,
process_fp8_weight_tensor_strategy,
requant_weight_ue8m0_inplace,
validate_fp8_block_shape,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -95,7 +94,6 @@
from vllm.scalar_type import scalar_types
from vllm.utils.deep_gemm import (
fp8_gemm_nt,
get_col_major_tma_aligned_tensor,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear,
@@ -918,15 +916,31 @@ def process_weights_after_loading(self, layer: Module) -> None:

# DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons.
if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
if expert_weight_is_col_major(layer.w13_weight_scale_inv):
layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w13_weight_scale_inv
if self.allow_deep_gemm:
dg_w13_weight, dg_w13_weight_scale_inv = (
deepgemm_post_process_fp8_weight_block(
wq=layer.w13_weight.data,
ws=layer.w13_weight_scale_inv.data,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(),
)
if expert_weight_is_col_major(layer.w2_weight_scale_inv):
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w2_weight_scale_inv
)
dg_w2_weight, dg_w2_weight_scale_inv = (
deepgemm_post_process_fp8_weight_block(
wq=layer.w2_weight.data,
ws=layer.w2_weight_scale_inv.data,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(),
)
)
layer.w13_weight = Parameter(dg_w13_weight, requires_grad=False)
layer.w13_weight_scale_inv = Parameter(
dg_w13_weight_scale_inv, requires_grad=False
)
layer.w2_weight = Parameter(dg_w2_weight, requires_grad=False)
layer.w2_weight_scale_inv = Parameter(
dg_w2_weight_scale_inv, requires_grad=False
)

# If checkpoint is fp16, quantize in place.
elif not self.quant_config.is_checkpoint_fp8_serialized:
Expand Down Expand Up @@ -1062,31 +1076,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
del layer.w13_input_scale
del layer.w2_input_scale

if is_deep_gemm_e8m0_used() and self.block_quant:

Contributor (Author) commented:
Here we perform the weight requant and weight scale transformation based on is_deep_gemm_e8m0_used() and self.block_quant. However, this does not consider which Fp8MoeBackend is used; i.e. regardless of the backend, which could be

    FLASHINFER_TRTLLM = 1
    FLASHINFER_CUTLASS = 2
    DEEPGEMM = 3
    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4
    MARLIN = 5
    TRITON = 6

we perform the weight requant and scales transform whenever DeepGEMM is available. This seems like a bug, so I have moved this logic above and guarded it with the self.allow_deep_gemm flag, which is True only when the Fp8MoeBackend is DEEPGEMM.

@yewentao256, this block was first introduced in #20087. Can you confirm if this is okay? Thanks.

Member commented:
Makes sense to me; where is the code you have for the guard?
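
For reference, a toy, self-contained illustration of the guard change discussed above; the Fp8MoeBackend values come from the comment, while the helper functions are hypothetical.

```python
from enum import Enum

class Fp8MoeBackend(Enum):
    FLASHINFER_TRTLLM = 1
    FLASHINFER_CUTLASS = 2
    DEEPGEMM = 3
    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4
    MARLIN = 5
    TRITON = 6

def should_requant_old(deepgemm_e8m0_used: bool, block_quant: bool) -> bool:
    # Old behaviour: requant whenever DeepGEMM/e8m0 is usable, for any backend.
    return deepgemm_e8m0_used and block_quant

def should_requant_new(backend: Fp8MoeBackend) -> bool:
    # New behaviour: requant only when the DeepGEMM backend is actually selected,
    # mirroring the allow_deep_gemm guard in the hunk above.
    return backend is Fp8MoeBackend.DEEPGEMM

assert should_requant_old(True, True)                 # fires even for TRITON, MARLIN, ...
assert not should_requant_new(Fp8MoeBackend.TRITON)
assert should_requant_new(Fp8MoeBackend.DEEPGEMM)
```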

assert layer.weight_block_size is not None
# Re-quantise the expert weights so their scales are UE8M0.
block_sz = tuple(layer.weight_block_size)
requant_weight_ue8m0_inplace(
layer.w13_weight.data,
layer.w13_weight_scale_inv.data,
block_sz,
)
requant_weight_ue8m0_inplace(
layer.w2_weight.data,
layer.w2_weight_scale_inv.data,
block_sz,
)

# Ensure column-major TMA alignment expected by DeepGEMM.
if expert_weight_is_col_major(layer.w13_weight_scale_inv):
layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w13_weight_scale_inv
)
if expert_weight_is_col_major(layer.w2_weight_scale_inv):
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w2_weight_scale_inv
)

def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
if (
self.rocm_aiter_moe_enabled
@@ -1109,7 +1098,8 @@ def select_gemm_impl(
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe import (
BatchedTritonOrDeepGemmExperts,
BatchedDeepGemmExperts,
BatchedTritonExperts,
TritonOrDeepGemmExperts,
)

@@ -1125,20 +1115,24 @@
):
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None

experts_impl = (
BatchedDeepGemmExperts if self.allow_deep_gemm else BatchedTritonExperts
)
logger.debug(
"BatchedTritonOrDeepGemmExperts(%s): "
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
"%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
experts_impl.__name__,
self.__class__.__name__,
max_num_tokens_per_rank,
self.weight_block_size,
False,
)
return BatchedTritonOrDeepGemmExperts(
return experts_impl(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
allow_deep_gemm=self.allow_deep_gemm,
)

elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
experts = select_cutlass_fp8_gemm_impl(
self.moe,
58 changes: 54 additions & 4 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -33,6 +33,7 @@
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear,
transform_sf_into_required_layout,
)
from vllm.utils.torch_utils import direct_register_custom_op

@@ -938,6 +939,50 @@ def requant_weight_ue8m0_inplace(
s_old.copy_(s_requant)


def deepgemm_post_process_fp8_weight_block(
wq: torch.Tensor, ws: torch.Tensor, quant_block_shape: tuple[int], use_e8m0: bool
) -> tuple[torch.Tensor, torch.Tensor]:
assert wq.dtype == torch.float8_e4m3fn, (
"Expected quantized tensor dtype "
f"to be torch.float8_e4m3fn, got {wq.dtype} instead."
)
assert ws.dtype == torch.float32, (
f"Expected tensor scales dtype to be torch.float32, got {ws.dtype} instead"
)

if use_e8m0:
requant_weight_ue8m0_inplace(wq, ws, block_size=quant_block_shape)

original_ndim = wq.ndim
if wq.ndim == 2:
assert ws.ndim == 2
wq = wq.unsqueeze(0)
ws = ws.unsqueeze(0)

# From https://github.com/deepseek-ai/DeepGEMM/blob/c9f8b34dcdacc20aa746b786f983492c51072870/csrc/utils/layout.hpp#L46
recipe = (1, 128, 128)

# Ref : https://github.com/deepseek-ai/DeepGEMM/blob/c9f8b34dcdacc20aa746b786f983492c51072870/csrc/apis/gemm.hpp
# DeepGemm uses the `transform_sf_into_required_layout` function to
# represent scales in the correct format.
dg_ws = transform_sf_into_required_layout(
sf=ws,
mn=wq.size(1),
k=wq.size(2),
recipe=recipe,
num_groups=wq.size(0),
# is_sfa indicates whether these are the scale factors for A (the A in A @ B);
# the weights here are B, so pass False.
is_sfa=False,
)

if original_ndim == 2:
wq = wq.squeeze(0)
dg_ws = dg_ws.squeeze(0)

return wq, dg_ws


def check_aiter_fp8_linear_support() -> bool:
"""AITER is only supported on ROCm and only for FP8_FNUZ
and at the moment are MI300 series"""
@@ -1163,11 +1208,16 @@ def maybe_post_process_fp8_weight_block(
should_use_deepgemm = should_use_deepgemm_for_fp8_linear(
layer.orig_dtype, layer.weight
)
if is_deep_gemm_e8m0_used() and should_use_deepgemm:
block_sz = tuple(layer.weight_block_size)
requant_weight_ue8m0_inplace(
layer.weight.data, layer.weight_scale.data, block_sz
if should_use_deepgemm:
dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
wq=layer.weight.data,
ws=layer.weight_scale.data,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(),
)
layer.weight = torch.nn.Parameter(dg_weight, requires_grad=False)
layer.weight_scale = torch.nn.Parameter(dg_weight_scale, requires_grad=False)

# SM90 Block FP8 CUTLASS requires row-major weight scales
elif (
current_platform.is_device_capability(90)
Expand Down