From 51eb8c169e47c43798889b2ef196c4d5fad0bdab Mon Sep 17 00:00:00 2001
From: Perry Zhang
Date: Wed, 29 Oct 2025 05:54:51 +0000
Subject: [PATCH 1/5] [EPLB][ROCm]: support EPLB for ROCm backend

Signed-off-by: Perry Zhang

---
 vllm/config/parallel.py                           |  4 ++--
 vllm/model_executor/layers/fused_moe/layer.py     |  2 +-
 .../compressed_tensors/compressed_tensors_moe.py  | 12 +++++++++---
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py
index 82d575f24690..79ee9eb095e6 100644
--- a/vllm/config/parallel.py
+++ b/vllm/config/parallel.py
@@ -267,10 +267,10 @@ def _validate_parallel_config(self) -> Self:
             )
 
         if self.enable_eplb:
-            if not current_platform.is_cuda():
+            if not (current_platform.is_cuda() or current_platform.is_rocm()) :
                 raise ValueError(
                     "Expert parallelism load balancing is only supported on "
-                    "CUDA devices now."
+                    "CUDA devices or ROCm devices now."
                 )
             if not self.enable_expert_parallel:
                 raise ValueError("enable_expert_parallel must be True to use EPLB.")
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 1236116386c9..4dc62c3b5c71 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -2080,7 +2080,7 @@ def load_weights(
 
     def get_expert_weights(self) -> Iterable[torch.Tensor]:
         weights = list(self.named_parameters())
-        assert all(weight.is_contiguous() for _, weight in weights)
+        # assert all(weight.is_contiguous() for _, weight in weights)
 
         # Filter out the non-expert weights.
         # `e_score_correction_bias` is a bias for each logical expert,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index d95d49eddfe3..27341414670c 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -1020,9 +1020,10 @@ def apply(
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `CompressedTensorsW8A8Fp8MoEMethod` yet."
-            )
+            assert expert_load_view is not None
+            assert logical_to_physical_map is not None
+            assert logical_replica_count is not None
+            assert isinstance(layer, FusedMoE)
 
         topk_weights, topk_ids, _ = FusedMoE.select_experts(
             hidden_states=x,
@@ -1038,6 +1039,11 @@ def apply(
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype,
             num_fused_shared_experts=layer.num_fused_shared_experts,
+            enable_eplb=enable_eplb,
+            expert_map=expert_map,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count,
         )
 
         per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN

From 1438a7707a7ba8b0af4e97e435eaad9bf71d4ba6 Mon Sep 17 00:00:00 2001
From: Perry Zhang
Date: Wed, 29 Oct 2025 11:22:26 +0000
Subject: [PATCH 2/5] [EPLB][fix]: fix assert error for shared experts

Signed-off-by: Perry Zhang

---
 vllm/model_executor/layers/fused_moe/layer.py | 21 ++++++++++-----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 4dc62c3b5c71..bd026b06783e 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -2080,23 +2080,22 @@ def load_weights(
 
     def get_expert_weights(self) -> Iterable[torch.Tensor]:
         weights = list(self.named_parameters())
-        # assert all(weight.is_contiguous() for _, weight in weights)
 
-        # Filter out the non-expert weights.
-        # `e_score_correction_bias` is a bias for each logical expert,
-        # with shape (num_logical_experts,), not an expert weight.
-        NON_EXPERT_WEIGHTS = {
-            "e_score_correction_bias",
+        ROUTED_EXPERT_WEIGHTS = {
+            "w13_weight",
+            "w13_weight_scale",
+            "w2_weight",
+            "w2_weight_scale",
         }
+
+        assert all(weight.is_contiguous()
+                   for name, weight in weights
+                   if name in ROUTED_EXPERT_WEIGHTS)
 
         return [
             weight.view(self.local_num_experts, -1)
             for name, weight in weights
-            if name not in NON_EXPERT_WEIGHTS
-            and weight.shape != torch.Size([])
-            and not name.startswith("_shared_experts.")
-            # exclude parameters from non-expert submodules (e.g. gate/shared)
-            and not name.startswith("_gate.")
+            if name in ROUTED_EXPERT_WEIGHTS
         ]
 
     def set_eplb_state(

From 0101f5fe0e9935039b2c231cf94ec02849ec7c06 Mon Sep 17 00:00:00 2001
From: Perry Zhang
Date: Wed, 29 Oct 2025 13:07:04 +0000
Subject: [PATCH 3/5] [EPLB][fix]: reuse weight filter for other models

Signed-off-by: Perry Zhang

---
 vllm/config/parallel.py                       |  2 +-
 vllm/model_executor/layers/fused_moe/layer.py | 25 +++++++++++++++----------
 2 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py
index 79ee9eb095e6..a8712f7c8721 100644
--- a/vllm/config/parallel.py
+++ b/vllm/config/parallel.py
@@ -267,7 +267,7 @@ def _validate_parallel_config(self) -> Self:
             )
 
         if self.enable_eplb:
-            if not (current_platform.is_cuda() or current_platform.is_rocm()) :
+            if not current_platform.is_cuda_alike():
                 raise ValueError(
                     "Expert parallelism load balancing is only supported on "
                     "CUDA devices or ROCm devices now."
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index bd026b06783e..13293a78e217 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -2080,22 +2080,27 @@ def load_weights(
 
     def get_expert_weights(self) -> Iterable[torch.Tensor]:
         weights = list(self.named_parameters())
+        assert all(
+            weight.is_contiguous()
+            for name, weight in weights
+            if not name.startswith("_shared_experts.")
+        )
 
-        ROUTED_EXPERT_WEIGHTS = {
-            "w13_weight",
-            "w13_weight_scale",
-            "w2_weight",
-            "w2_weight_scale",
+        # Filter out the non-expert weights.
+        # `e_score_correction_bias` is a bias for each logical expert,
+        # with shape (num_logical_experts,), not an expert weight.
+        NON_EXPERT_WEIGHTS = {
+            "e_score_correction_bias",
         }
-
-        assert all(weight.is_contiguous()
-                   for name, weight in weights
-                   if name in ROUTED_EXPERT_WEIGHTS)
 
         return [
             weight.view(self.local_num_experts, -1)
             for name, weight in weights
-            if name in ROUTED_EXPERT_WEIGHTS
+            if name not in NON_EXPERT_WEIGHTS
+            and weight.shape != torch.Size([])
+            and not name.startswith("_shared_experts.")
+            # exclude parameters from non-expert submodules (e.g. gate/shared)
+            and not name.startswith("_gate.")
         ]
 
     def set_eplb_state(

From 5f7b78f868ba700b15eee28b33f21bdbaa955c10 Mon Sep 17 00:00:00 2001
From: Perry Zhang
Date: Thu, 30 Oct 2025 11:55:39 +0000
Subject: [PATCH 4/5] [EPLB][typo]: adjust imports to satisfy pre-commit
 formatting

Signed-off-by: Perry Zhang

---
 .../compressed_tensors/__init__.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py
index e69de29bb2d1..926cd994e9eb 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py
@@ -0,0 +1,23 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from .compressed_tensors import CompressedTensorsLinearMethod
+from .compressed_tensors_moe import (
+    CompressedTensorsMoEMethod,
+    CompressedTensorsW4A4MoeMethod,
+    CompressedTensorsW4A8Int8MoEMethod,
+    CompressedTensorsW8A8Fp8MoEMethod,
+    CompressedTensorsW8A8Int8MoEMethod,
+    CompressedTensorsWNA16MarlinMoEMethod,
+    CompressedTensorsWNA16MoEMethod,
+)
+
+__all__ = [
+    "CompressedTensorsLinearMethod",
+    "CompressedTensorsMoEMethod",
+    "CompressedTensorsW8A8Fp8MoEMethod",
+    "CompressedTensorsW8A8Int8MoEMethod",
+    "CompressedTensorsWNA16MarlinMoEMethod",
+    "CompressedTensorsWNA16MoEMethod",
+    "CompressedTensorsW4A4MoeMethod",
+    "CompressedTensorsW4A8Int8MoEMethod",
+]

From cabe8bacbb2404b7cad6088d1d5daa34fa312a0e Mon Sep 17 00:00:00 2001
From: Perry Zhang
Date: Thu, 6 Nov 2025 08:57:46 +0000
Subject: [PATCH 5/5] [EPLB][fix]: make compatible with the latest code

Signed-off-by: Perry Zhang

---
 .../compressed_tensors/__init__.py  | 20 --------------------
 .../compressed_tensors_moe.py       |  4 ++++
 2 files changed, 4 insertions(+), 20 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py
index 926cd994e9eb..6655f8913623 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py
@@ -1,23 +1,3 @@
 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .compressed_tensors import CompressedTensorsLinearMethod -from .compressed_tensors_moe import ( - CompressedTensorsMoEMethod, - CompressedTensorsW4A4MoeMethod, - CompressedTensorsW4A8Int8MoEMethod, - CompressedTensorsW8A8Fp8MoEMethod, - CompressedTensorsW8A8Int8MoEMethod, - CompressedTensorsWNA16MarlinMoEMethod, - CompressedTensorsWNA16MoEMethod, -) -__all__ = [ - "CompressedTensorsLinearMethod", - "CompressedTensorsMoEMethod", - "CompressedTensorsW8A8Fp8MoEMethod", - "CompressedTensorsW8A8Int8MoEMethod", - "CompressedTensorsWNA16MarlinMoEMethod", - "CompressedTensorsWNA16MoEMethod", - "CompressedTensorsW4A4MoeMethod", - "CompressedTensorsW4A8Int8MoEMethod", -] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 27341414670c..af78a075228c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -1152,6 +1152,10 @@ def apply( quant_config=self.moe_quant_config, ) + @property + def supports_eplb(self) -> bool: + return True + class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): def __init__(
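
Note on the mechanism these patches plumb through: EPLB remaps each logical
expert ID chosen by the router to one of several physical replicas and
accumulates per-replica load so the balancer can later redistribute experts.
The sketch below is a minimal standalone illustration of that remapping under
assumed tensor layouts; eplb_route and its round-robin replica policy are
hypothetical, not vLLM's actual implementation inside FusedMoE.select_experts.

import torch

def eplb_route(
    topk_ids: torch.Tensor,                 # (num_tokens, top_k) logical expert ids
    logical_to_physical_map: torch.Tensor,  # (num_logical_experts, max_replicas)
    logical_replica_count: torch.Tensor,    # (num_logical_experts,)
    expert_load_view: torch.Tensor,         # (num_physical_experts,) running load
) -> torch.Tensor:
    # Spread tokens over a logical expert's replicas round-robin by flat
    # token index (an assumed policy; the real selection may differ).
    ids = topk_ids.long()
    flat_idx = torch.arange(ids.numel(), device=ids.device)
    replica_idx = flat_idx.view_as(ids) % logical_replica_count[ids]
    physical_ids = logical_to_physical_map[ids, replica_idx]
    # Accumulate per-physical-expert load; this mirrors the role of the
    # expert_load_view argument asserted non-None in PATCH 1.
    expert_load_view.scatter_add_(
        0,
        physical_ids.flatten(),
        torch.ones(physical_ids.numel(), dtype=expert_load_view.dtype,
                   device=expert_load_view.device),
    )
    return physical_ids

# Example: logical expert 0 has two replicas (physical experts 0 and 4);
# single-replica experts pad their replica slot with their own id.
l2p = torch.tensor([[0, 4], [1, 1], [2, 2], [3, 3]])
counts = torch.tensor([2, 1, 1, 1])
load = torch.zeros(5, dtype=torch.long)
topk = torch.tensor([[0, 1], [0, 2]])
print(eplb_route(topk, l2p, counts, load))  # physical id per (token, slot)
print(load)                                 # per-physical-expert token counts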