-
-
Notifications
You must be signed in to change notification settings - Fork 11.1k
[Bugfix] Fix accuracy issue of TRTLLM FP8 MOE and improve logging #25895
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -434,14 +434,9 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): | |
| self.weight_block_size = self.quant_config.weight_block_size | ||
| self.block_quant = self.weight_block_size is not None | ||
|
|
||
| self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None | ||
| self.fused_experts: Optional[ | ||
| mk.FusedMoEModularKernel] = None # type: ignore | ||
| if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): | ||
| self.flashinfer_moe_backend = get_flashinfer_moe_backend() | ||
| logger.info_once( | ||
| f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" | ||
| ) | ||
|
|
||
| # For GPUs that lack FP8 hardware support, we can leverage the Marlin | ||
| # kernel for fast weight-only FP8 quantization | ||
| self.use_marlin = (not current_platform.has_device_capability(89) | ||
|
|
@@ -450,14 +445,28 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): | |
| if current_platform.is_rocm(): | ||
| self.use_marlin = False | ||
|
|
||
| # First check for Flashinfer MOE on Blackwell GPUs | ||
| self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None | ||
| if (current_platform.is_cuda() and | ||
| current_platform.is_device_capability(100) and | ||
| envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()): | ||
| self.flashinfer_moe_backend = get_flashinfer_moe_backend() | ||
| logger.info_once( | ||
| f"Detected Blackwell GPUs, using FlashInfer " | ||
| f"{self.flashinfer_moe_backend.value} kernels for FP8 MOE." | ||
| ) | ||
|
|
||
| # Check for DeepGemm support. | ||
| self.allow_deep_gemm = False | ||
| if envs.VLLM_USE_DEEP_GEMM: | ||
| if not has_deep_gemm(): | ||
| logger.warning_once("Failed to import DeepGemm kernels.") | ||
| elif not self.block_quant: | ||
| logger.warning_once("Model is not block quantized. Not using " | ||
| "DeepGemm kernels") | ||
| logger.warning_once("Model is not block quantized. Not using" | ||
| " DeepGemm kernels") | ||
| elif self.flashinfer_moe_backend: | ||
| logger.info_once("DeepGemm disabled: FlashInfer MOE is" | ||
| " enabled.") | ||
| elif (is_deep_gemm_supported()): | ||
| logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.") | ||
| self.allow_deep_gemm = True | ||
|
|
@@ -471,15 +480,13 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): | |
| logger.debug_once("Model is not block quantized. Not using " | ||
| "CutlassBlockScaledGroupedGemm kernels") | ||
| elif (current_platform.is_cuda() | ||
| and current_platform.is_device_capability(100)): | ||
| logger.info_once( | ||
| "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod." | ||
| ) | ||
| self.allow_cutlass_block_scaled_grouped_gemm = True | ||
| else: | ||
| logger.warning_once( | ||
| "CutlassBlockScaledGroupedGemm not supported on the current " | ||
| "platform.") | ||
| and current_platform.is_device_capability(100) | ||
| and not self.flashinfer_moe_backend): | ||
| logger.info_once( | ||
| "Using CutlassBlockScaledGroupedGemm kernels for Fp8 MOE " | ||
| "on SM100." | ||
| ) | ||
| self.allow_cutlass_block_scaled_grouped_gemm = True | ||
|
|
||
| def create_weights(self, layer: Module, num_experts: int, hidden_size: int, | ||
| intermediate_size_per_partition: int, | ||
|
|
@@ -934,7 +941,9 @@ def apply( | |
| import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 | ||
| assert (renormalize and use_grouped_topk | ||
| and custom_routing_function is None) | ||
| result = torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( | ||
| e_score_correction_bias = (e_score_correction_bias.to(x.dtype) | ||
| if e_score_correction_bias is not None else None) | ||
| return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( | ||
|
||
| routing_logits=router_logits.to(torch.float32), | ||
| routing_bias=e_score_correction_bias, | ||
| x=x, | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note for future self: we should clean up these logs now that VLLM_USE_DEEP_GEMM=1 by default