@@ -1,14 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable
 
 import torch
 import torch.nn as nn
 from transformers import PretrainedConfig
 
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    is_rocm_aiter_fusion_shared_expert_enabled,
+)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -212,11 +216,16 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
         ]
 
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts,
+            num_experts=self.config.n_routed_experts
+            + (
+                self.config.n_shared_experts
+                if is_rocm_aiter_fusion_shared_expert_enabled()
+                else 0
+            ),
         )
 
         params_dict = dict(self.named_parameters())
@@ -227,6 +236,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
             if spec_layer is None:
                 continue
+            is_fuse_shared_experts_layer = (
+                is_rocm_aiter_fusion_shared_expert_enabled()
+                and ("mlp.shared_experts" in name)
+            )
             name = self._rewrite_spec_layer_name(spec_layer, name)
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
@@ -240,6 +253,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                 if ("mlp.experts." in name) and name not in params_dict:
                     continue
+                if is_fuse_shared_experts_layer:
+                    continue
                 name_mapped = name.replace(weight_name, param_name)
 
                 # QKV fusion is optional, fall back to normal
@@ -260,45 +275,105 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                     weight_loader(param, loaded_weight, shard_id)
                     break
             else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
-                    )
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-
-                    name = maybe_remap_kv_scale_name(name, params_dict)
-                    if name is None:
-                        continue
-
-                    # According to DeepSeek-V3 Technical Report, MTP modules
-                    # shares embedding layer. We only load the first weights.
-                    if (
-                        spec_layer != self.model.mtp_start_layer_idx
-                        and ".layers" not in name
-                    ):
-                        continue
-
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
+                # Special handling: when AITER fusion_shared_experts is enabled,
+                # checkpoints may provide a single widened shared_experts tensor
+                # without explicit expert indices
+                # (e.g. ...mlp.shared_experts.gate_proj.weight).
+                # For models with multiple shared experts, split that tensor
+                # evenly into per-shared-expert slices and load them into
+                # appended expert slots mlp.experts.{n_routed_experts + j}.*
+                # accordingly.
+                num_chunks = 1
+                if is_fuse_shared_experts_layer:
+                    num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
+                    # Determine split axis based on op type
+                    # gate/up: ColumnParallel → split along dim 0
+                    # down: RowParallel → split along dim 1
+                    split_dim = 1 if "down_proj.weight" in name else 0
+                    total = loaded_weight.shape[split_dim]
+                    assert total % num_chunks == 0, (
+                        f"Shared expert weight dim {total} "
+                        f"not divisible by num_chunks {num_chunks}"
                     )
-                    weight_loader(param, loaded_weight)
-                    loaded_params.add(name)
+                    chunk_size = total // num_chunks
+
+                for j in range(num_chunks):
+                    chunk_name = name
+                    weight_to_load = loaded_weight
+
+                    if is_fuse_shared_experts_layer:
+                        if split_dim == 0:
+                            weight_to_load = loaded_weight[
+                                j * chunk_size : (j + 1) * chunk_size, :
+                            ]
+                        else:
+                            weight_to_load = loaded_weight[
+                                :, j * chunk_size : (j + 1) * chunk_size
+                            ]
+                        # Synthesize an expert-style name so expert mapping
+                        # can route it
+                        chunk_name = name.replace(
+                            "mlp.shared_experts",
+                            f"mlp.experts.{self.config.n_routed_experts + j}",
+                        )
+
+                    # Use expert_params_mapping to locate the destination
+                    # param and delegate to its expert-aware weight_loader
+                    # with expert_id.
+                    for mapping in expert_params_mapping:
+                        param_name, weight_name, expert_id, shard_id = mapping
+                        if weight_name not in chunk_name:
+                            continue
+
+                        # Do not modify `name` since the loop may continue here
+                        # Instead, create a new variable
+                        name_mapped = chunk_name.replace(weight_name, param_name)
+
+                        param = params_dict[name_mapped]
+                        # We should ask the weight loader to return success or
+                        # not here since otherwise we may skip experts with
+                        # other available replicas.
+                        weight_loader = typing.cast(
+                            Callable[..., bool], param.weight_loader
+                        )
+                        success = weight_loader(
+                            param,
+                            weight_to_load,
+                            name_mapped,
+                            shard_id=shard_id,
+                            expert_id=expert_id,
+                            return_success=True,
+                        )
+                        if success:
+                            if not is_fuse_shared_experts_layer:
+                                name = name_mapped
+                            else:
+                                loaded_params.add(name_mapped)
+                            break
+                    else:
+                        # Skip loading extra bias for GPTQ models.
+                        if name.endswith(".bias") and name not in params_dict:
+                            continue
+
+                        name = maybe_remap_kv_scale_name(name, params_dict)
+                        if name is None:
+                            continue
+
+                        # According to DeepSeek-V3 Technical Report, MTP modules
+                        # shares embedding layer. We only load the first weights.
+                        if (
+                            spec_layer != self.model.mtp_start_layer_idx
+                            and ".layers" not in name
+                        ):
+                            continue
+
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+                        if not is_fuse_shared_experts_layer:
+                            loaded_params.add(name)
         return loaded_params
 
     def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
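
A minimal standalone sketch of the shared-expert splitting that the new load_weights branch performs when AITER shared-expert fusion is enabled: a single widened mlp.shared_experts tensor is cut into n_shared_experts equal slices (dim 0 for gate/up projections, dim 1 for down projections), and each slice is renamed to an appended expert slot mlp.experts.{n_routed_experts + j}. The helper below is hypothetical and illustrative only; it mirrors the slicing/renaming logic of the diff without calling into vLLM.

# Standalone sketch (not part of the diff above); assumes only torch.
import torch


def split_fused_shared_expert(
    name: str,
    weight: torch.Tensor,
    n_routed_experts: int,
    n_shared_experts: int,
) -> list[tuple[str, torch.Tensor]]:
    """Split a fused shared-experts weight into per-expert (name, slice) pairs."""
    # gate/up projections are column-parallel -> experts stacked along dim 0;
    # down projections are row-parallel -> experts stacked along dim 1.
    split_dim = 1 if "down_proj.weight" in name else 0
    total = weight.shape[split_dim]
    assert total % n_shared_experts == 0
    chunk = total // n_shared_experts

    out = []
    for j in range(n_shared_experts):
        # Slice out the j-th shared expert along the split dimension.
        sl = weight.narrow(split_dim, j * chunk, chunk)
        # Synthesize an expert-style name so an expert-aware loader can route it.
        new_name = name.replace(
            "mlp.shared_experts", f"mlp.experts.{n_routed_experts + j}"
        )
        out.append((new_name, sl))
    return out


if __name__ == "__main__":
    # Two shared experts fused into one gate_proj weight of shape (2 * 128, 64).
    w = torch.randn(256, 64)
    pairs = split_fused_shared_expert(
        "model.layers.0.mlp.shared_experts.gate_proj.weight",
        w,
        n_routed_experts=64,
        n_shared_experts=2,
    )
    for n, t in pairs:
        print(n, tuple(t.shape))
    # -> model.layers.0.mlp.experts.64.gate_proj.weight (128, 64)
    #    model.layers.0.mlp.experts.65.gate_proj.weight (128, 64)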