4 changes: 2 additions & 2 deletions vllm/config/parallel.py
@@ -278,10 +278,10 @@ def _validate_parallel_config(self) -> Self:
             )
 
         if self.enable_eplb:
-            if not current_platform.is_cuda():
+            if not current_platform.is_cuda_alike():
                 raise ValueError(
                     "Expert parallelism load balancing is only supported on "
-                    "CUDA devices now."
+                    "CUDA devices or ROCm devices now."
                 )
             if not self.enable_expert_parallel:
                 raise ValueError("enable_expert_parallel must be True to use EPLB.")
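For context, `current_platform.is_cuda_alike()` reports True for both the CUDA and ROCm backends, which is what lets the EPLB validation pass on AMD GPUs. A minimal sketch of the relaxed gate, assuming the `vllm.platforms` helper behaves as described; the standalone `validate_eplb_platform` function is illustrative and not part of the PR:

```python
# Illustrative sketch only: is_cuda_alike() is assumed to cover CUDA and ROCm.
from vllm.platforms import current_platform


def validate_eplb_platform(enable_eplb: bool) -> None:
    if not enable_eplb:
        return
    # The old check used current_platform.is_cuda(), which rejected ROCm
    # devices even though they can run the EPLB path.
    if not current_platform.is_cuda_alike():
        raise ValueError(
            "Expert parallelism load balancing is only supported on "
            "CUDA devices or ROCm devices now."
        )
```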
6 changes: 5 additions & 1 deletion vllm/model_executor/layers/fused_moe/layer.py
@@ -1218,7 +1218,11 @@ def load_weights(
 
     def get_expert_weights(self) -> Iterable[torch.Tensor]:
         weights = list(self.named_parameters())
-        assert all(weight.is_contiguous() for _, weight in weights)
+        assert all(
+            weight.is_contiguous()
+            for name, weight in weights
+            if not name.startswith("_shared_experts.")
+        )
 
         # Filter out the non-expert weights.
         # `e_score_correction_bias` is a bias for each logical expert,
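The relaxed assertion only requires contiguity from routed-expert parameters; tensors under the fused `_shared_experts.` prefix are exempt from the check. A toy sketch of the same name-based filter; the `ToyMoELayer` module and its parameter shapes are made up for illustration:

```python
# Toy illustration of filtering named_parameters() by prefix; not vLLM code.
import torch
import torch.nn as nn


class ToyMoELayer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Routed-expert weights: these must stay contiguous for EPLB weight shuffles.
        self.w13_weight = nn.Parameter(torch.randn(8, 16, 32))
        # Fused shared experts: may hold a non-contiguous (transposed) view,
        # so they are skipped by the contiguity assertion.
        self._shared_experts = nn.Linear(32, 32)
        self._shared_experts.weight = nn.Parameter(torch.randn(32, 32).t())

    def check_expert_weights(self) -> None:
        weights = list(self.named_parameters())
        # Without the prefix filter this would fail on the transposed view above.
        assert all(
            weight.is_contiguous()
            for name, weight in weights
            if not name.startswith("_shared_experts.")
        )


ToyMoELayer().check_expert_weights()  # passes: only routed weights are checked
```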
@@ -0,0 +1,3 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
@@ -1019,9 +1019,10 @@ def apply(
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `CompressedTensorsW8A8Fp8MoEMethod` yet."
-            )
+            assert expert_load_view is not None
+            assert logical_to_physical_map is not None
+            assert logical_replica_count is not None
+            assert isinstance(layer, FusedMoE)
 
         topk_weights, topk_ids, _ = FusedMoE.select_experts(
             hidden_states=x,
@@ -1037,6 +1038,11 @@
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype,
             num_fused_shared_experts=layer.num_fused_shared_experts,
+            enable_eplb=enable_eplb,
+            expert_map=expert_map,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count,
         )
 
         per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
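Since this diff mostly plumbs the EPLB arguments through to `FusedMoE.select_experts`, a short, self-contained sketch of what that metadata represents may help. The shapes, the round-robin replica choice, and the bincount-style load counter below are assumptions for illustration, not the vLLM implementation:

```python
import torch

# Conceptual sketch of EPLB metadata: each logical expert may have several
# physical replicas, and routed tokens are spread across them.

# logical_to_physical_map[l, r]: physical expert id of replica r of logical
# expert l, or -1 when the replica slot is unused.
logical_to_physical_map = torch.tensor([[0, 4], [1, -1], [2, 5], [3, -1]])
logical_replica_count = torch.tensor([2, 1, 2, 1])  # replicas per logical expert

# Two tokens, top-2 routing over *logical* experts.
topk_ids_logical = torch.tensor([[0, 2], [1, 0]])

# Pick one replica per selection (simple round-robin over a running counter).
counter = torch.arange(topk_ids_logical.numel()).view_as(topk_ids_logical)
replica_idx = counter % logical_replica_count[topk_ids_logical]
topk_ids_physical = logical_to_physical_map[topk_ids_logical, replica_idx]

# expert_load_view-style counter: selections landing on each physical expert.
expert_load_view = torch.bincount(topk_ids_physical.flatten(), minlength=6)
print(topk_ids_physical)  # tensor([[0, 5], [1, 4]])
print(expert_load_view)   # tensor([1, 1, 0, 0, 1, 1])
```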
@@ -1145,6 +1151,10 @@ def apply(
             quant_config=self.moe_quant_config,
         )
 
+    @property
+    def supports_eplb(self) -> bool:
+        return True
+
 
 class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
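The new `supports_eplb` property presumably lets callers gate EPLB per quantization method instead of hitting a NotImplementedError inside `apply`. A hypothetical sketch of that pattern; the base-class default and the `check_eplb` helper are assumptions, not code from this PR:

```python
# Hypothetical gating pattern; class and helper names are made up.
class FakeMoEMethodBase:
    @property
    def supports_eplb(self) -> bool:
        return False  # assumed default: quant methods opt in explicitly


class FakeW8A8Fp8MoEMethod(FakeMoEMethodBase):
    @property
    def supports_eplb(self) -> bool:
        return True  # mirrors the property added in this diff


def check_eplb(method: FakeMoEMethodBase, enable_eplb: bool) -> None:
    if enable_eplb and not method.supports_eplb:
        raise NotImplementedError(
            f"EPLB not supported for {type(method).__name__} yet."
        )


check_eplb(FakeW8A8Fp8MoEMethod(), enable_eplb=True)  # no error: method opts in
```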