2 changes: 2 additions & 0 deletions vllm/model_executor/layers/fused_moe/__init__.py
@@ -15,6 +15,7 @@
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
from vllm.triton_utils import HAS_TRITON

@@ -42,6 +43,7 @@ def get_config() -> Optional[dict[str, Any]]:
"FusedMoEPermuteExpertsUnpermute",
"FusedMoEActivationFormat",
"FusedMoEPrepareAndFinalize",
"SharedFusedMoE",
"activation_without_mul",
"override_config",
"get_config",
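For context, a minimal import sketch (not part of the diff itself): with this re-export in place, and with the standalone vllm.model_executor.layers.shared_fused_moe package deleted later in this diff, callers pull the class from the fused_moe package:

from vllm.model_executor.layers.fused_moe import SharedFusedMoE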
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
@@ -18,13 +18,21 @@ class SharedFusedMoE(FusedMoE):

def __init__(
self,
shared_experts: torch.nn.Module,
shared_experts: Optional[torch.nn.Module],
use_overlapped: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self._shared_experts = shared_experts
self.use_overlapped = use_overlapped
# Disable shared expert overlap if EP is disabled or we are not using
# flashinfer + DP since there is nothing to be gained in this case.
# Disabling the overlap optimization also prevents the shared experts
# from being hidden from torch.compile.
self.use_overlapped = (
use_overlapped
and not (self.use_ep or self.use_flashinfer_cutlass_kernels)
and self._shared_experts is not None
)

@property
def shared_experts(self) -> Optional[torch.nn.Module]:
@@ -36,16 +44,19 @@ def forward(
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
if not self.use_overlapped:
shared_out = self._shared_experts(hidden_states)

# Reduce outputs if necessary, since the MLP should
# have been created with reduce_results=False.
if (
self.reduce_results
and self.tp_size > 1
and self.must_reduce_shared_expert_outputs()
):
shared_out = tensor_model_parallel_all_reduce(shared_out)
if self._shared_experts is not None:
shared_out = self._shared_experts(hidden_states)

# Reduce shared expert outputs if necessary, since the MLP
# should have been created with reduce_results=False.
if (
self.reduce_results
and self.tp_size > 1
and self.must_reduce_shared_expert_outputs()
):
shared_out = tensor_model_parallel_all_reduce(shared_out)
else:
shared_out = None

fused_out = super().forward(
hidden_states=hidden_states,
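To make the calling convention concrete, here is a self-contained toy sketch (plain PyTorch; the class, sizes, and router shape are invented for illustration and are not vLLM's implementation): the layer owns the optional shared experts, and its forward returns a (shared_output, routed_output) pair that the model combines afterwards, rather than the model running the shared experts itself.

from typing import Optional

import torch
import torch.nn as nn


class ToySharedFusedMoE(nn.Module):
    """Toy stand-in for SharedFusedMoE: owns the shared experts, returns a pair."""

    def __init__(self, hidden_size: int, shared_experts: Optional[nn.Module]):
        super().__init__()
        self.shared_experts = shared_experts
        # Stand-in for the routed (sparse) experts computed by FusedMoE.
        self.routed = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        routed_out = self.routed(hidden_states)
        if self.shared_experts is None:
            return routed_out
        # Shared and routed outputs travel back together as one tuple.
        return self.shared_experts(hidden_states), routed_out


layer = ToySharedFusedMoE(16, shared_experts=nn.Linear(16, 16))
x = torch.randn(4, 16)
shared_out, routed_out = layer(x, router_logits=torch.randn(4, 8))
out = shared_out + routed_out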
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -741,6 +741,8 @@ def create_weights(
layer.w13_input_scale = None
layer.w2_input_scale = None

self.rocm_aiter_moe_enabled = False

def process_weights_after_loading(self, layer: Module) -> None:
# Lazy import to avoid importing triton too early.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
5 changes: 0 additions & 5 deletions vllm/model_executor/layers/shared_fused_moe/__init__.py

This file was deleted.

28 changes: 15 additions & 13 deletions vllm/model_executor/models/aria.py
@@ -13,7 +13,7 @@
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -206,7 +206,7 @@ def forward(
return out


class AriaFusedMoE(FusedMoE):
class AriaFusedMoE(SharedFusedMoE):
def weight_loader(
self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str
) -> None:
@@ -260,7 +260,16 @@ def __init__(
torch.empty((self.config.moe_num_experts, self.config.hidden_size))
)

self.shared_experts = LlamaMLP(
config.hidden_size,
config.intermediate_size * config.moe_num_shared_experts,
"silu",
quant_config=quant_config,
bias=config.mlp_bias,
)

self.experts = AriaFusedMoE(
shared_experts=self.shared_experts,
num_experts=config.moe_num_experts,
top_k=config.moe_topk,
hidden_size=config.hidden_size,
@@ -269,13 +278,6 @@ def __init__(
reduce_results=True,
prefix=f"{prefix}.experts",
)
self.shared_experts = LlamaMLP(
config.hidden_size,
config.intermediate_size * config.moe_num_shared_experts,
"silu",
quant_config=quant_config,
bias=config.mlp_bias,
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
@@ -291,12 +293,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

router_output = torch.nn.functional.linear(hidden_states, self.router_weight)

hidden_states_copy = hidden_states.clone()
# NOTE: hidden_states will be modified inplace by `FusedMoE`
sparse_expert_output = self.experts(hidden_states, router_output)
shared_expert_output = self.shared_experts(hidden_states_copy)

return sparse_expert_output + shared_expert_output
if self.shared_experts is not None:
return sparse_expert_output[0] + sparse_expert_output[1]
else:
return sparse_expert_output


class AriaTextDecoderLayer(LlamaDecoderLayer):
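As an aside on the NOTE comment above, a tiny toy sketch (bare tensors, not vLLM code) of the in-place hazard the old hidden_states.clone() guarded against; with the shared experts now run inside SharedFusedMoE, the caller no longer needs to keep its own pristine copy of the activations.

import torch

x = torch.randn(4, 16)
x_copy = x.clone()   # old pattern: preserve the activations for the shared MLP
x.mul_(0.0)          # stand-in for an in-place update by the fused-expert kernel
assert not torch.equal(x, x_copy)  # the clone still holds the original values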
49 changes: 28 additions & 21 deletions vllm/model_executor/models/bailing_moe.py
@@ -43,7 +43,7 @@
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
@@ -276,22 +276,6 @@ def __init__(
# default value for scoring_func
self.score_function = "softmax"

self.experts = FusedMoE(
num_experts=self.num_experts,
top_k=self.top_k,
hidden_size=self.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=self.norm_expert_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
scoring_func=self.score_function,
e_score_correction_bias=self.gate.expert_bias,
num_expert_group=self.n_group,
topk_group=self.topk_group,
use_grouped_topk=self.use_grouped_topk,
)

if self.num_shared_experts > 0:
if hasattr(config, "moe_shared_expert_intermediate_size"):
intermediate_size = config.moe_shared_expert_intermediate_size
@@ -308,11 +292,29 @@ def __init__(
else:
self.shared_experts = None

self.experts = SharedFusedMoE(
shared_experts=self.shared_experts,
fused_output_scaling_factor=self.routed_scaling_factor,
shared_output_scaling_factor=1.0,
num_experts=self.num_experts,
top_k=self.top_k,
hidden_size=self.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=self.norm_expert_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
scoring_func=self.score_function,
e_score_correction_bias=self.gate.expert_bias,
num_expert_group=self.n_group,
topk_group=self.topk_group,
use_grouped_topk=self.use_grouped_topk,
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_size)
if self.shared_experts:
shared_output = self.shared_experts(hidden_states)

# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states.to(self.router_dtype))
router_logits = router_logits.to(hidden_states.dtype)
Expand All @@ -321,9 +323,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states=hidden_states, router_logits=router_logits
)

if self.shared_experts is not None:
shared_output, final_hidden_states = final_hidden_states
else:
shared_output = None

final_hidden_states *= self.routed_scaling_factor

if self.shared_experts:
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output

if self.tp_size > 1:
@@ -475,7 +482,7 @@ def forward(
return hidden_states

def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return FusedMoE.make_expert_params_mapping(
return SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
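A small arithmetic sketch (toy tensors, not vLLM code) of how this forward combines the unpacked outputs: the routed output is scaled by routed_scaling_factor, the shared output is added unscaled, and with tensor parallelism the sum would then be all-reduced across ranks.

import torch

routed_scaling_factor = 2.5               # illustrative value
shared_output = torch.randn(8, 32)        # from the shared experts
final_hidden_states = torch.randn(8, 32)  # from the routed experts

final_hidden_states = final_hidden_states * routed_scaling_factor
final_hidden_states = final_hidden_states + shared_output
# with tp_size > 1 this would be followed by
# tensor_model_parallel_all_reduce(final_hidden_states)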
69 changes: 24 additions & 45 deletions vllm/model_executor/models/deepseek_v2.py
@@ -49,7 +49,7 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@@ -64,7 +64,6 @@
per_token_group_quant_fp8,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
@@ -205,26 +204,6 @@ def __init__(
)

if config.n_shared_experts is None:
self.experts = FusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor=1.0,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
)
self.shared_experts = None
else:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
@@ -239,27 +218,27 @@ def __init__(
prefix=f"{prefix}.shared_experts",
)

self.experts = SharedFusedMoE(
shared_experts=self.shared_experts,
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor=1.0,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
)
self.experts = SharedFusedMoE(
shared_experts=self.shared_experts,
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor=1.0,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
@@ -1293,7 +1272,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace
self.num_expert_groups = config.n_group

self.moe_layers: list[FusedMoE] = []
self.moe_layers: list[SharedFusedMoE] = []
example_moe = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
@@ -1381,7 +1360,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:

# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
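To summarize the structural change in DeepseekV2MoE, a toy construction sketch (invented class, not vLLM code): the shared experts are built only when the config defines them, and a single fused layer is then constructed unconditionally and handed the possibly-None module, instead of branching between FusedMoE and SharedFusedMoE.

from typing import Optional

import torch.nn as nn


class ToyMoEBlock(nn.Module):
    """Toy stand-in for DeepseekV2MoE: one construction path, shared experts optional."""

    def __init__(self, hidden_size: int, n_shared_experts: Optional[int]):
        super().__init__()
        if n_shared_experts is None:
            self.shared_experts = None
        else:
            self.shared_experts = nn.Linear(hidden_size, hidden_size * n_shared_experts)
        # Stand-in for SharedFusedMoE(shared_experts=self.shared_experts, ...):
        # built on every path; a None shared_experts is handled inside the layer.
        self.experts = nn.Linear(hidden_size, hidden_size)


block_with_shared = ToyMoEBlock(16, n_shared_experts=2)
block_routed_only = ToyMoEBlock(16, n_shared_experts=None)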