diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py
index b19c8beeae3d..61bcd15e06a8 100644
--- a/vllm/config/parallel.py
+++ b/vllm/config/parallel.py
@@ -278,10 +278,10 @@ def _validate_parallel_config(self) -> Self:
             )
 
         if self.enable_eplb:
-            if not current_platform.is_cuda():
+            if not current_platform.is_cuda_alike():
                 raise ValueError(
                     "Expert parallelism load balancing is only supported on "
-                    "CUDA devices now."
+                    "CUDA devices or ROCm devices now."
                 )
             if not self.enable_expert_parallel:
                 raise ValueError("enable_expert_parallel must be True to use EPLB.")
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 615da58eeda2..3bd7c54c520c 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1218,7 +1218,11 @@ def load_weights(
 
     def get_expert_weights(self) -> Iterable[torch.Tensor]:
         weights = list(self.named_parameters())
-        assert all(weight.is_contiguous() for _, weight in weights)
+        assert all(
+            weight.is_contiguous()
+            for name, weight in weights
+            if not name.startswith("_shared_experts.")
+        )
 
         # Filter out the non-expert weights.
         # `e_score_correction_bias` is a bias for each logical expert,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py
index e69de29bb2d1..6655f8913623 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index f1050c15f79e..bda94cee9e42 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -1019,9 +1019,10 @@ def apply(
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `CompressedTensorsW8A8Fp8MoEMethod` yet."
-            )
+            assert expert_load_view is not None
+            assert logical_to_physical_map is not None
+            assert logical_replica_count is not None
+            assert isinstance(layer, FusedMoE)
 
         topk_weights, topk_ids, _ = FusedMoE.select_experts(
             hidden_states=x,
@@ -1037,6 +1038,11 @@
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype,
             num_fused_shared_experts=layer.num_fused_shared_experts,
+            enable_eplb=enable_eplb,
+            expert_map=expert_map,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count,
         )
 
         per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
@@ -1145,6 +1151,10 @@ def apply(
             quant_config=self.moe_quant_config,
         )
 
+    @property
+    def supports_eplb(self) -> bool:
+        return True
+
 
 class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(