
Commit 29f99f7

zhyajie authored and committed
[Bugfix][Rocm] Fix shared expert weight loading failure in DeepSeek-MTP
Signed-off-by: zhyajie <[email protected]>
1 parent cbd5e07 · commit 29f99f7

1 file changed: +117 -42 lines changed

vllm/model_executor/models/deepseek_mtp.py

Lines changed: 117 additions & 42 deletions
@@ -1,14 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable

 import torch
 import torch.nn as nn
 from transformers import PretrainedConfig

 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    is_rocm_aiter_fusion_shared_expert_enabled,
+)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -212,11 +216,16 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
         ]

-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts,
+            num_experts=self.config.n_routed_experts
+            + (
+                self.config.n_shared_experts
+                if is_rocm_aiter_fusion_shared_expert_enabled()
+                else 0
+            ),
         )

         params_dict = dict(self.named_parameters())
@@ -227,6 +236,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
             if spec_layer is None:
                 continue
+            is_fuse_shared_experts_layer = (
+                is_rocm_aiter_fusion_shared_expert_enabled()
+                and ("mlp.shared_experts" in name)
+            )
             name = self._rewrite_spec_layer_name(spec_layer, name)
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
@@ -240,6 +253,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                 if ("mlp.experts." in name) and name not in params_dict:
                     continue
+                if is_fuse_shared_experts_layer:
+                    continue
                 name_mapped = name.replace(weight_name, param_name)

                 # QKV fusion is optional, fall back to normal
@@ -260,45 +275,105 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
-                    )
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-
-                    name = maybe_remap_kv_scale_name(name, params_dict)
-                    if name is None:
-                        continue
-
-                    # According to DeepSeek-V3 Technical Report, MTP modules
-                    # shares embedding layer. We only load the first weights.
-                    if (
-                        spec_layer != self.model.mtp_start_layer_idx
-                        and ".layers" not in name
-                    ):
-                        continue
-
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
+                # Special handling: when AITER fusion_shared_experts is enabled,
+                # checkpoints may provide a single widened shared_experts tensor
+                # without explicit expert indices
+                # (e.g. ...mlp.shared_experts.gate_proj.weight).
+                # For models with multiple shared experts, split that tensor
+                # evenly into per-shared-expert slices and load them into
+                # appended expert slots mlp.experts.{n_routed_experts + j}.*
+                # accordingly.
+                num_chunks = 1
+                if is_fuse_shared_experts_layer:
+                    num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
+                    # Determine split axis based on op type
+                    # gate/up: ColumnParallel → split along dim 0
+                    # down: RowParallel → split along dim 1
+                    split_dim = 1 if "down_proj.weight" in name else 0
+                    total = loaded_weight.shape[split_dim]
+                    assert total % num_chunks == 0, (
+                        f"Shared expert weight dim {total} "
+                        f"not divisible by num_chunks {num_chunks}"
                     )
-                    weight_loader(param, loaded_weight)
-                loaded_params.add(name)
+                    chunk_size = total // num_chunks
+
+                for j in range(num_chunks):
+                    chunk_name = name
+                    weight_to_load = loaded_weight
+
+                    if is_fuse_shared_experts_layer:
+                        if split_dim == 0:
+                            weight_to_load = loaded_weight[
+                                j * chunk_size : (j + 1) * chunk_size, :
+                            ]
+                        else:
+                            weight_to_load = loaded_weight[
+                                :, j * chunk_size : (j + 1) * chunk_size
+                            ]
+                        # Synthesize an expert-style name so expert mapping
+                        # can route it
+                        chunk_name = name.replace(
+                            "mlp.shared_experts",
+                            f"mlp.experts.{self.config.n_routed_experts + j}",
+                        )
+
+                    # Use expert_params_mapping to locate the destination
+                    # param and delegate to its expert-aware weight_loader
+                    # with expert_id.
+                    for mapping in expert_params_mapping:
+                        param_name, weight_name, expert_id, shard_id = mapping
+                        if weight_name not in chunk_name:
+                            continue
+
+                        # Do not modify `name` since the loop may continue here
+                        # Instead, create a new variable
+                        name_mapped = chunk_name.replace(weight_name, param_name)
+
+                        param = params_dict[name_mapped]
+                        # We should ask the weight loader to return success or
+                        # not here since otherwise we may skip experts with
+                        # other available replicas.
+                        weight_loader = typing.cast(
+                            Callable[..., bool], param.weight_loader
+                        )
+                        success = weight_loader(
+                            param,
+                            weight_to_load,
+                            name_mapped,
+                            shard_id=shard_id,
+                            expert_id=expert_id,
+                            return_success=True,
+                        )
+                        if success:
+                            if not is_fuse_shared_experts_layer:
+                                name = name_mapped
+                            else:
+                                loaded_params.add(name_mapped)
+                            break
+                    else:
+                        # Skip loading extra bias for GPTQ models.
+                        if name.endswith(".bias") and name not in params_dict:
+                            continue
+
+                        name = maybe_remap_kv_scale_name(name, params_dict)
+                        if name is None:
+                            continue
+
+                        # According to DeepSeek-V3 Technical Report, MTP modules
+                        # shares embedding layer. We only load the first weights.
+                        if (
+                            spec_layer != self.model.mtp_start_layer_idx
+                            and ".layers" not in name
+                        ):
+                            continue
+
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+            if not is_fuse_shared_experts_layer:
+                loaded_params.add(name)
         return loaded_params

     def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
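
For reference, a minimal standalone sketch (not part of the commit) of the splitting scheme the new code applies: a single widened shared_experts tensor is cut into per-expert chunks along dim 0 for gate/up projections (column-parallel) or dim 1 for down projections (row-parallel), and each chunk is renamed into an appended expert slot mlp.experts.{n_routed_experts + j}, matching the expert mapping built with num_experts = n_routed_experts + n_shared_experts. The layer name, expert counts, and tensor sizes below are invented for illustration; only the split/rename logic mirrors the diff above.

import torch

n_routed_experts = 8       # stands in for config.n_routed_experts (illustrative)
n_shared_experts = 2       # stands in for config.n_shared_experts (illustrative)
hidden, intermediate = 16, 4

# Hypothetical widened checkpoint tensor covering all shared experts,
# published without explicit expert indices.
name = "model.layers.0.mlp.shared_experts.gate_proj.weight"
loaded_weight = torch.randn(n_shared_experts * intermediate, hidden)

# gate/up: split rows (dim 0); down: split columns (dim 1).
split_dim = 1 if "down_proj.weight" in name else 0
num_chunks = n_shared_experts
assert loaded_weight.shape[split_dim] % num_chunks == 0
chunk_size = loaded_weight.shape[split_dim] // num_chunks

for j in range(num_chunks):
    # Slice out the j-th shared expert's weights.
    weight_to_load = loaded_weight.narrow(split_dim, j * chunk_size, chunk_size)
    # Synthesize an expert-style name targeting the appended expert slot.
    chunk_name = name.replace(
        "mlp.shared_experts", f"mlp.experts.{n_routed_experts + j}"
    )
    print(chunk_name, tuple(weight_to_load.shape))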
