Merged
Changes from 1 commit
45 changes: 27 additions & 18 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -434,14 +434,9 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant = self.weight_block_size is not None

self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
self.fused_experts: Optional[
mk.FusedMoEModularKernel] = None # type: ignore
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
)

# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (not current_platform.has_device_capability(89)
@@ -450,14 +445,28 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
if current_platform.is_rocm():
self.use_marlin = False

# First check for Flashinfer MOE on Blackwell GPUs
self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
if (current_platform.is_cuda() and
current_platform.is_device_capability(100) and
envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()):
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
f"Detected Blackwell GPUs, using FlashInfer "
f"{self.flashinfer_moe_backend.value} kernels for FP8 MOE."
)

# Check for DeepGemm support.
self.allow_deep_gemm = False
if envs.VLLM_USE_DEEP_GEMM:
if not has_deep_gemm():
logger.warning_once("Failed to import DeepGemm kernels.")
elif not self.block_quant:
logger.warning_once("Model is not block quantized. Not using "
"DeepGemm kernels")
logger.warning_once("Model is not block quantized. Not using"
" DeepGemm kernels")
elif self.flashinfer_moe_backend:
logger.info_once("DeepGemm disabled: FlashInfer MOE is"
" enabled.")
Comment on lines 463 to +468
Member

Note for future self: we should clean up these logs now that VLLM_USE_DEEP_GEMM=1 by default
elif (is_deep_gemm_supported()):
logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
self.allow_deep_gemm = True
@@ -471,15 +480,13 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
logger.debug_once("Model is not block quantized. Not using "
"CutlassBlockScaledGroupedGemm kernels")
elif (current_platform.is_cuda()
and current_platform.is_device_capability(100)):
logger.info_once(
"Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
)
self.allow_cutlass_block_scaled_grouped_gemm = True
else:
logger.warning_once(
"CutlassBlockScaledGroupedGemm not supported on the current "
"platform.")
and current_platform.is_device_capability(100)
and not self.flashinfer_moe_backend):
logger.info_once(
"Using CutlassBlockScaledGroupedGemm kernels for Fp8 MOE "
"on SM100."
)
self.allow_cutlass_block_scaled_grouped_gemm = True

def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int,
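Taken together, the __init__ hunks above establish a precedence among the FP8 MoE backends: FlashInfer is considered first and only on Blackwell (SM100) CUDA devices, DeepGemm is skipped whenever FlashInfer was selected, and the Cutlass block-scaled grouped GEMM path is likewise restricted to SM100 and to the case where FlashInfer is not active. The standalone sketch below mirrors that ordering with plain booleans standing in for the env flags, platform checks, and availability probes; select_fp8_moe_backends is a hypothetical helper written for illustration, not a vLLM API, and it deliberately omits details such as is_deep_gemm_supported().

```python
def select_fp8_moe_backends(is_cuda: bool,
                            is_sm100: bool,
                            use_flashinfer_env: bool,
                            has_flashinfer: bool,
                            use_deep_gemm_env: bool,
                            has_deep_gemm: bool,
                            block_quant: bool) -> dict:
    # 1) FlashInfer FP8 MOE is only picked on Blackwell (SM100) CUDA devices.
    flashinfer = (is_cuda and is_sm100
                  and use_flashinfer_env and has_flashinfer)

    # 2) DeepGemm is considered next, but is skipped whenever FlashInfer was
    #    selected (and, as in the diff, it requires a block-quantized model).
    allow_deep_gemm = (use_deep_gemm_env and has_deep_gemm
                       and block_quant and not flashinfer)

    # 3) Cutlass block-scaled grouped GEMM is likewise limited to SM100 and
    #    only enabled when FlashInfer MOE is not active.
    allow_cutlass = (block_quant and is_cuda and is_sm100
                     and not flashinfer)

    return {
        "flashinfer_moe": flashinfer,
        "allow_deep_gemm": allow_deep_gemm,
        "allow_cutlass_block_scaled_grouped_gemm": allow_cutlass,
    }


# A block-quantized model on a Blackwell GPU with both env flags set ends up
# on the FlashInfer path only.
print(select_fp8_moe_backends(is_cuda=True, is_sm100=True,
                              use_flashinfer_env=True, has_flashinfer=True,
                              use_deep_gemm_env=True, has_deep_gemm=True,
                              block_quant=True))
```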
@@ -934,7 +941,9 @@ def apply(
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
assert (renormalize and use_grouped_topk
and custom_routing_function is None)
result = torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
e_score_correction_bias = (e_score_correction_bias.to(x.dtype)
if e_score_correction_bias is not None else None)
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
Contributor

critical

This is a great fix that addresses two important issues:

  1. Casting e_score_correction_bias to x.dtype resolves a potential dtype mismatch that could lead to accuracy problems, which is critical for correctness.
  2. Adding the return statement corrects a significant control flow bug. Previously, the code would fall through and execute select_experts and other logic even after the complete MoE operation was performed by flashinfer_fused_moe_blockscale_fp8. This early return ensures the function exits correctly.

This change is crucial for both the correctness and the control flow of the FP8 MoE path.

routing_logits=router_logits.to(torch.float32),
routing_bias=e_score_correction_bias,
x=x,
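The reviewer's point is easiest to see side by side. The minimal sketch below contrasts the old fall-through behaviour with the fixed early return and dtype cast; flashinfer_kernel and fallback_path are placeholder functions standing in for torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8 and the select_experts fallback, not the real vLLM code paths.

```python
from typing import Optional

import torch


def flashinfer_kernel(x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
    # Placeholder for torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8.
    return x


def fallback_path(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for the generic select_experts + fused-experts path.
    return x * 2


def apply_before_fix(x, bias, use_flashinfer):
    if use_flashinfer:
        # Bug: the result was assigned but never returned, so execution fell
        # through into the fallback path after the MoE op had already run.
        result = flashinfer_kernel(x, bias)
    return fallback_path(x)


def apply_after_fix(x, bias, use_flashinfer):
    if use_flashinfer:
        # Cast the routing bias to the activation dtype to avoid a mismatch,
        # then return immediately so the fallback path is skipped.
        bias = bias.to(x.dtype) if bias is not None else None
        return flashinfer_kernel(x, bias)
    return fallback_path(x)


x = torch.randn(4, 8, dtype=torch.bfloat16)
bias = torch.zeros(8, dtype=torch.float32)
# The two versions now produce different results on the FlashInfer branch.
assert not torch.equal(apply_before_fix(x, bias, True),
                       apply_after_fix(x, bias, True))
```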
7 changes: 6 additions & 1 deletion vllm/utils/deep_gemm.py
@@ -27,7 +27,8 @@ def is_deep_gemm_supported() -> bool:
is_supported_arch = current_platform.is_cuda() and (
current_platform.is_device_capability(90)
or current_platform.is_device_capability(100))
return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch
return (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and
is_supported_arch and not envs.VLLM_USE_FLASHINFER_MOE_FP8)


@functools.cache
@@ -45,6 +46,10 @@ def is_deep_gemm_e8m0_used() -> bool:
if _fp8_gemm_nt_impl is None:
logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
return False

if envs.VLLM_USE_FLASHINFER_MOE_FP8:
logger.info_once("DeepGEMM E8M0 disabled: FlashInfer MOE is enabled.")
return False

if current_platform.is_device_capability(100) and \
envs.VLLM_USE_DEEP_GEMM_E8M0:
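The deep_gemm.py changes above apply the same precedence at the utility level: both is_deep_gemm_supported() and is_deep_gemm_e8m0_used() now report False whenever the FlashInfer FP8 MOE path is enabled. A minimal sketch of the resulting gate, with plain booleans in place of the env flags and platform checks (deep_gemm_supported is a hypothetical helper, not the vLLM function):

```python
def deep_gemm_supported(use_deep_gemm_env: bool,
                        has_deep_gemm: bool,
                        is_sm90_or_sm100: bool,
                        use_flashinfer_moe_fp8_env: bool) -> bool:
    # Mirrors the updated is_deep_gemm_supported(): even with DeepGEMM
    # installed on a supported arch, the FlashInfer FP8 MOE flag vetoes it.
    return (use_deep_gemm_env and has_deep_gemm and is_sm90_or_sm100
            and not use_flashinfer_moe_fp8_env)


# With VLLM_USE_FLASHINFER_MOE_FP8 set, the check flips to False even on
# Hopper/Blackwell with DeepGEMM available.
assert deep_gemm_supported(True, True, True,
                           use_flashinfer_moe_fp8_env=False) is True
assert deep_gemm_supported(True, True, True,
                           use_flashinfer_moe_fp8_env=True) is False
```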