[Performance][B200] Fix deepgemm prologue #27897
Changes from all commits:
**DeepEP low-latency prepare/finalize** (`DeepEPLLPrepareAndFinalize`):

```diff
@@ -6,6 +6,7 @@
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate,
@@ -20,6 +21,8 @@
     dbo_maybe_run_recv_hook,
 )

+logger = init_logger(__name__)
+
 # DeepEP kernels quantize dispatch inputs in 128 element chunks.
 DEEPEP_QUANT_BLOCK_SIZE = 128
 DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE]
@@ -94,6 +97,28 @@ def __init__(
         self.handles: list[tuple | None] = [None, None]
         self.num_dispatchers_ = num_dispatchers

+        # We don't have enough information to determine if we should dispatch
+        # activation scales in a packed ue8m0 format during object construction
+        # time. This setting is handled by setup_packed_ue8m0_scales_dispatch.
+        self.use_ue8m0 = False
```
**Member:** So is this flag only used in low-latency dispatch, and weight requant doesn't require it?
```diff
+    def supports_packed_ue8m0_scales_dispatch(self) -> bool:
+        return True
+
+    def setup_packed_ue8m0_scales_dispatch(self) -> None:
+        if self.use_fp8_dispatch:
+            logger.debug_once(
+                "Update DeepEPLLPrepareFinalize to do packed ue8m0 scales dispatch"
+            )
+            self.use_ue8m0 = True
+        else:
+            logger.warning_once(
+                "Ignoring request to dispatch activation scales in a packed "
+                "ue8m0 format as DeepEPLLPrepareAndFinalize is setup to "
+                "dispatch raw/unquantized activations.",
+                scope="local",
+            )
+
     def num_dispatchers(self) -> int:
         return self.num_dispatchers_
@@ -206,6 +231,9 @@ def prepare_async(
             self.max_tokens_per_rank,
             num_experts,
             use_fp8=self.use_fp8_dispatch,
+            # round_scale needs to be set to dispatch in ue8m0
+            round_scale=self.use_ue8m0,
+            use_ue8m0=self.use_ue8m0,
             async_finish=False,
             return_recv_hook=True,
         )
```
**Modular kernel interfaces** (`FusedMoEPrepareAndFinalize`, `FusedMoEPermuteExpertsUnpermute`, and the modular kernel wrapper):
```diff
@@ -347,6 +347,20 @@ def output_is_reduced(self) -> bool:
         """
         raise NotImplementedError

+    def supports_packed_ue8m0_scales_dispatch(self) -> bool:
+        """
+        Return true if the implementation can dispatch activation scales in
+        packed ue8m0 format.
+        """
+        return False
+
+    def setup_packed_ue8m0_scales_dispatch(self) -> None:
+        """
+        Setup internal state of the implementation to dispatch activation scales
+        in packed ue8m0 format.
+        """
+        raise NotImplementedError
+

 # TODO: add supported activations method (return string)
 class FusedMoEPermuteExpertsUnpermute(ABC):
@@ -503,6 +517,13 @@ def supports_expert_map(self) -> bool:
         """
         raise NotImplementedError

+    def supports_packed_ue8m0_act_scales(self) -> bool:
+        """
+        A flag indicating whether or not this class can process packed ue8m0
+        activation scales.
+        """
+        return False
+
     def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
         """
         Workspace type: The dtype to use for the workspace tensors.
@@ -698,6 +719,8 @@ def __init__(
         self.prepare_finalize = prepare_finalize
         self.fused_experts = fused_experts
         self.shared_experts = shared_experts
+
+        self._post_init_setup()
         assert (
             prepare_finalize.activation_format == fused_experts.activation_formats[0]
         ), (
@@ -707,6 +730,17 @@
             f"{fused_experts.activation_formats[0]}"
         )

+    def _post_init_setup(self):
```
**Collaborator:** Can this method be overridden by subclasses? If so, then you could avoid adding […]

**Contributor (Author):** I am not sure I understand.

**Collaborator:** I guess I was thinking you could have a generic […] And […]
```diff
+        """
+        Resolve any leftover setup dependencies between self.prepare_finalize
+        and self.fused_experts here.
+        """
+        if (
+            self.fused_experts.supports_packed_ue8m0_act_scales()
+            and self.prepare_finalize.supports_packed_ue8m0_scales_dispatch()
+        ):
+            self.prepare_finalize.setup_packed_ue8m0_scales_dispatch()
+
     def supports_expert_map(self) -> bool:
         """
         A flag indicating whether or not this class supports expert maps.
```
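To make the new hooks concrete, here is a minimal, runnable sketch of the opt-in handshake that `_post_init_setup` performs. The class names below are placeholders invented for illustration; only the three hook method names and the guard condition mirror the PR.

```python
# Sketch only: placeholder classes standing in for a prepare/finalize object
# and an experts implementation; they are not vLLM classes.
class PrepareFinalizeSketch:
    def __init__(self, use_fp8_dispatch: bool) -> None:
        self.use_fp8_dispatch = use_fp8_dispatch
        self.use_ue8m0 = False  # same default the PR adds to DeepEPLLPrepareAndFinalize

    def supports_packed_ue8m0_scales_dispatch(self) -> bool:
        return True

    def setup_packed_ue8m0_scales_dispatch(self) -> None:
        # Only an fp8 dispatch can carry packed ue8m0 scales; otherwise the
        # request is ignored (the PR logs a warning in that case).
        self.use_ue8m0 = self.use_fp8_dispatch


class ExpertsSketch:
    def supports_packed_ue8m0_act_scales(self) -> bool:
        return True  # e.g. a DeepGEMM-backed experts implementation


def post_init_setup(prepare_finalize, fused_experts) -> None:
    # Same guard as the _post_init_setup added in this PR.
    if (
        fused_experts.supports_packed_ue8m0_act_scales()
        and prepare_finalize.supports_packed_ue8m0_scales_dispatch()
    ):
        prepare_finalize.setup_packed_ue8m0_scales_dispatch()


pf = PrepareFinalizeSketch(use_fp8_dispatch=True)
post_init_setup(pf, ExpertsSketch())
assert pf.use_ue8m0
```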
**Fp8 quantization method** (imports, `process_weights_after_loading`, `select_gemm_impl`):
```diff
@@ -60,11 +60,10 @@
     create_fp8_input_scale,
     create_fp8_scale_parameter,
     create_fp8_weight_parameter,
-    expert_weight_is_col_major,
+    deepgemm_post_process_fp8_weight_block,
     maybe_post_process_fp8_weight_block,
     process_fp8_weight_block_strategy,
     process_fp8_weight_tensor_strategy,
-    requant_weight_ue8m0_inplace,
     validate_fp8_block_shape,
 )
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -95,7 +94,6 @@
 from vllm.scalar_type import scalar_types
 from vllm.utils.deep_gemm import (
     fp8_gemm_nt,
-    get_col_major_tma_aligned_tensor,
     is_deep_gemm_e8m0_used,
     is_deep_gemm_supported,
     should_use_deepgemm_for_fp8_linear,
```
```diff
@@ -918,15 +916,31 @@ def process_weights_after_loading(self, layer: Module) -> None:

-            # DeepGemm scales need to be transposed and aligned. We try to do
-            # it ahead of time for performance reasons.
-            if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
-                if expert_weight_is_col_major(layer.w13_weight_scale_inv):
-                    layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                        layer.w13_weight_scale_inv
-                    )
-                if expert_weight_is_col_major(layer.w2_weight_scale_inv):
-                    layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                        layer.w2_weight_scale_inv
-                    )
+            # DeepGemm scales need to be transposed and aligned. We try to do
+            # it ahead of time for performance reasons.
+            if self.allow_deep_gemm:
+                dg_w13_weight, dg_w13_weight_scale_inv = (
+                    deepgemm_post_process_fp8_weight_block(
+                        wq=layer.w13_weight.data,
+                        ws=layer.w13_weight_scale_inv.data,
+                        quant_block_shape=tuple(layer.weight_block_size),
+                        use_e8m0=is_deep_gemm_e8m0_used(),
+                    )
+                )
+                dg_w2_weight, dg_w2_weight_scale_inv = (
+                    deepgemm_post_process_fp8_weight_block(
+                        wq=layer.w2_weight.data,
+                        ws=layer.w2_weight_scale_inv.data,
+                        quant_block_shape=tuple(layer.weight_block_size),
+                        use_e8m0=is_deep_gemm_e8m0_used(),
+                    )
+                )
+                layer.w13_weight = Parameter(dg_w13_weight, requires_grad=False)
+                layer.w13_weight_scale_inv = Parameter(
+                    dg_w13_weight_scale_inv, requires_grad=False
+                )
+                layer.w2_weight = Parameter(dg_w2_weight, requires_grad=False)
+                layer.w2_weight_scale_inv = Parameter(
+                    dg_w2_weight_scale_inv, requires_grad=False
+                )

         # If checkpoint is fp16, quantize in place.
         elif not self.quant_config.is_checkpoint_fp8_serialized:
```
```diff
@@ -1062,31 +1076,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
             del layer.w13_input_scale
             del layer.w2_input_scale

-        if is_deep_gemm_e8m0_used() and self.block_quant:
```
**Contributor (Author):** Here we perform weight requant and weight scale transformation based on `is_deep_gemm_e8m0_used()`, i.e. we perform weight requant and scales transform if DeepGEMM is available. This seems like a bug, and I have moved this logic above, guarded with the `self.allow_deep_gemm` check. @yewentao256 - This block was first introduced in #20087. Can you confirm if this is okay? Thanks.

**Member:** Makes sense to me; where is the code you have for the guard?
```diff
-            assert layer.weight_block_size is not None
-            # Re-quantise the expert weights so their scales are UE8M0.
-            block_sz = tuple(layer.weight_block_size)
-            requant_weight_ue8m0_inplace(
-                layer.w13_weight.data,
-                layer.w13_weight_scale_inv.data,
-                block_sz,
-            )
-            requant_weight_ue8m0_inplace(
-                layer.w2_weight.data,
-                layer.w2_weight_scale_inv.data,
-                block_sz,
-            )
-
-            # Ensure column-major TMA alignment expected by DeepGEMM.
-            if expert_weight_is_col_major(layer.w13_weight_scale_inv):
-                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                    layer.w13_weight_scale_inv
-                )
-            if expert_weight_is_col_major(layer.w2_weight_scale_inv):
-                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                    layer.w2_weight_scale_inv
-                )
-
     def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
         if (
             self.rocm_aiter_moe_enabled
```
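For context on the re-quantization removed in the hunk above (the logic now lives behind the `allow_deep_gemm` guard in the earlier hunk), making scales UE8M0-representable amounts to snapping each block scale to a power of two and rescaling the fp8 block so dequantization is preserved. Below is a rough, single-block sketch of that idea; the rounding direction and the helper internals are assumptions, and the real `requant_weight_ue8m0_inplace` works in place over the configured weight blocks.

```python
import torch

def requant_block_ue8m0_sketch(
    w_fp8: torch.Tensor, scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    # Round the block scale up to a power of two (so it fits in an 8-bit
    # exponent) and rescale the quantized block so that
    # w_fp8 * scale ~= w_requant * new_scale.
    new_scale = torch.exp2(torch.ceil(torch.log2(scale)))
    w_requant = (w_fp8.float() * (scale / new_scale)).to(torch.float8_e4m3fn)
    return w_requant, new_scale

# Example: one [128, 128] block with a single scalar scale.
block = torch.randn(128, 128).to(torch.float8_e4m3fn)
block_q, s_q = requant_block_ue8m0_sketch(block, torch.tensor(0.37))
assert s_q == 0.5  # 2 ** ceil(log2(0.37)) == 2 ** -1
```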
```diff
@@ -1109,7 +1098,8 @@ def select_gemm_impl(
         layer: torch.nn.Module,
     ) -> FusedMoEPermuteExpertsUnpermute:
         from vllm.model_executor.layers.fused_moe import (
-            BatchedTritonOrDeepGemmExperts,
+            BatchedDeepGemmExperts,
+            BatchedTritonExperts,
             TritonOrDeepGemmExperts,
         )
```
```diff
@@ -1125,20 +1115,24 @@
         ):
             max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
             assert max_num_tokens_per_rank is not None

+            experts_impl = (
+                BatchedDeepGemmExperts if self.allow_deep_gemm else BatchedTritonExperts
+            )
             logger.debug(
-                "BatchedTritonOrDeepGemmExperts(%s): "
-                "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
+                "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
+                experts_impl.__name__,
                 self.__class__.__name__,
                 max_num_tokens_per_rank,
                 self.weight_block_size,
                 False,
             )
-            return BatchedTritonOrDeepGemmExperts(
+            return experts_impl(
                 max_num_tokens=max_num_tokens_per_rank,
                 num_dispatchers=prepare_finalize.num_dispatchers(),
                 quant_config=self.moe_quant_config,
-                allow_deep_gemm=self.allow_deep_gemm,
             )

         elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
             experts = select_cutlass_fp8_gemm_impl(
                 self.moe,
```
Review discussion:

The comment doesn't match this line, since the "is" is actually `==`. Also, isn't it the case that we still want to use UE8M0 on Hopper for cases like DeepSeek terminus?
+1, actually we are using e8m0 on Hopper currently, so this seems like a breaking change to me. We should carefully test and benchmark before we use this.
IIUC, this is the state of `ws` on main vs. in this PR:

main: let `ws` be a weight scales tensor of shape `[X, 4096]` and datatype `float32`, ~~but keep the weight scales in `float32`, i.e. each `float32` value actually holds UE8M0 content (look here), i.e. only the first byte of each `float32` value will have the actual contents~~. [EDIT] The stricken-out portion was wrong. We actually cast the weight scales to ue8m0 and then expand them back to `float32`; effectively the scale values can be one of `{2^i where i in [-127, 127]}`. `ws` will be of shape `[X, 4096]` and of datatype `float32`.

This PR: we quantize the scales to UE8M0 and then use `transform_sf_into_required_layout()` from deepgemm to pack the scales into an int32 tensor, i.e. `ws` will be of shape `[X, 1024]` and of datatype `int32`. Effectively the scale values can be one of `{i where i in [-127, 127]}`.
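To make the storage change concrete, here is a naive sketch of packing power-of-two float32 scales into int32 (four 8-bit exponents per int32, so `[X, 4096]` float32 becomes `[X, 1024]` int32). This only illustrates the shape/dtype change; the actual layout produced by DeepGEMM's `transform_sf_into_required_layout()` also does TMA-friendly reordering and is not this simple row-wise packing.

```python
import torch

def pack_pow2_scales_to_int32(ws: torch.Tensor) -> torch.Tensor:
    # Assumes every value of `ws` is already a power of two. Each float32
    # power of two is reduced to its 8-bit biased exponent (bits 23..30 of
    # the float32 encoding) and four such bytes are packed into one int32.
    exps = (ws.view(torch.int32) >> 23) & 0xFF       # [X, 4096] exponent bytes
    exps = exps.reshape(*ws.shape[:-1], -1, 4)       # group 4 bytes per int32
    shifts = torch.tensor([0, 8, 16, 24], dtype=torch.int32)
    return (exps << shifts).sum(dim=-1, dtype=torch.int32)  # [X, 1024]

ws = torch.exp2(torch.randint(-20, 0, (8, 4096)).float())  # power-of-two scales
packed = pack_pow2_scales_to_int32(ws)
assert packed.shape == (8, 1024) and packed.dtype == torch.int32
```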