Conversation


@zhyajie zhyajie commented Oct 30, 2025

Purpose

This PR fixes the loading errors for the DeepSeek MTP weights when VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled (the default setting). During model loading, a KeyError is thrown for the parameter 'model.layers.61.mtp_block.mlp.shared_experts.down_proj.weight_scale_inv'.

Root Cause: PR vllm-project#24097 added the fused shared experts optimization for ROCm but did not adapt it for the DeepSeek MTP model architecture. As a result, the shared-expert parameters of the MTP blocks are missing from the model's parameter dict, and weight loading raises a KeyError for the corresponding checkpoint entries.
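For context, this is a minimal sketch of the generic weight-loading loop where the failure surfaces (simplified, not the literal deepseek_mtp.py code); `weights` stands in for the checkpoint's (name, tensor) iterator:

import torch

def load_weights_sketch(model: torch.nn.Module, weights):
    """Simplified stand-in for the per-model load_weights loop."""
    params_dict = dict(model.named_parameters())
    for name, loaded_weight in weights:
        # With fused shared experts enabled, the MTP block no longer exposes a
        # standalone `...mlp.shared_experts...` parameter, so the unmapped
        # checkpoint key (e.g. '...shared_experts.down_proj.weight_scale_inv')
        # is absent from params_dict and this lookup raises the KeyError.
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", None)
        if weight_loader is not None:
            weight_loader(param, loaded_weight)
        else:
            param.data.copy_(loaded_weight)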
The fix follows the changes that vllm-project#24097 made to vllm/model_executor/models/deepseek_v2.py, applying the same adaptation to the MTP blocks.
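Roughly, the adaptation amounts to remapping the checkpoint's shared-expert weight names onto the fused experts parameter. The helper below is only an illustrative sketch; its name and the assumed convention that the shared expert is appended after the routed experts are assumptions, not the literal vLLM code:

def remap_fused_shared_expert(name: str, n_routed_experts: int,
                              fuse_shared_experts: bool) -> str:
    """Map a checkpoint name for a shared-expert weight onto the fused
    experts parameter when fused shared experts are enabled (illustrative)."""
    if not fuse_shared_experts or ".mlp.shared_experts." not in name:
        return name
    # Assumed convention: the shared expert is stored as one extra expert
    # index after the routed experts, so
    # '...mlp.shared_experts.down_proj.weight' maps to
    # '...mlp.experts.<n_routed_experts>.down_proj.weight'.
    return name.replace(".mlp.shared_experts.",
                        f".mlp.experts.{n_routed_experts}.")

# Example with a hypothetical expert count:
# remap_fused_shared_expert(
#     "model.layers.61.mtp_block.mlp.shared_experts.down_proj.weight_scale_inv",
#     n_routed_experts=256, fuse_shared_experts=True)
# -> "model.layers.61.mtp_block.mlp.experts.256.down_proj.weight_scale_inv"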

Test Plan

The following tests validate DeepSeek models by collecting benchmark metrics and performing correctness tests through lm_eval.

vLLM server launch command:

AITER_ENABLE_VSKIP=0 \
VLLM_USE_V1=1 \
VLLM_ROCM_USE_AITER=1 \
vllm serve $MODEL \
--tensor-parallel-size 8 \
--disable-log-requests \
--compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \
--trust-remote-code \
--speculative-config='{"method": "deepseek_mtp", "num_speculative_tokens": 1}' \
--block-size 1

lm_eval command:

lm_eval --model local-completions --tasks gsm8k --model_args model=${model_name},base_url=http://localhost:8000/v1/completions,num_concurrent=128,max_retries=3,tokenized_requests=False

Test Result

Before this PR, the worker fails during model loading with the following traceback:

(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627] WorkerProc failed to start.
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627] Traceback (most recent call last):
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]   File "/home/yajizhan/code/vllm/vllm/v1/executor/multiproc_executor.py", line 601, in worker_main
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]     worker = WorkerProc(*args, **kwargs)
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]   File "/home/yajizhan/code/vllm/vllm/v1/executor/multiproc_executor.py", line 456, in __init__
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]     self.worker.load_model()
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]   File "/home/yajizhan/code/vllm/vllm/v1/worker/gpu_worker.py", line 233, in load_model
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]     self.model_runner.load_model(eep_scale_up=eep_scale_up)
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]   File "/home/yajizhan/code/vllm/vllm/v1/worker/gpu_model_runner.py", line 2895, in load_model
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]     self.drafter.load_model(self.model)
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]   File "/home/yajizhan/code/vllm/vllm/v1/spec_decode/eagle.py", line 930, in load_model
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]     self.model = get_model(
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]   File "/home/yajizhan/code/vllm/vllm/model_executor/model_loader/__init__.py", line 130, in get_model
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]     return loader.load_model(vllm_config=vllm_config, model_config=model_config)
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]   File "/home/yajizhan/code/vllm/vllm/model_executor/model_loader/base_loader.py", line 55, in load_model
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]     self.load_weights(model, model_config)
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]   File "/home/yajizhan/code/vllm/vllm/model_executor/model_loader/default_loader.py", line 300, in load_weights
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]     loaded_weights = model.load_weights(
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]   File "/home/yajizhan/code/vllm/vllm/model_executor/models/deepseek_mtp.py", line 296, in load_weights
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]     param = params_dict[name]
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627] KeyError: 'model.layers.61.mtp_block.mlp.shared_experts.down_proj.weight_scale_inv'

After this PR, the service starts normally, the MTP weights load correctly, and the gsm8k results and MTP acceptance-rate metrics are as follows.

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9530|±  |0.0058|
|     |       |strict-match    |     5|exact_match|↑  |0.9522|±  |0.0059|
INFO 10-27 08:36:51 [metrics.py:100] SpecDecoding metrics: Mean acceptance length: 1.92, Accepted throughput: 76.00 tokens/s, Drafted throughput: 82.70 tokens/s, Accepted: 760 tokens, Drafted: 827 tokens, Per-position acceptance rate: 0.919, Avg Draft acceptance rate: 91.9%
INFO 10-27 08:37:01 [metrics.py:100] SpecDecoding metrics: Mean acceptance length: 1.94, Accepted throughput: 80.39 tokens/s, Drafted throughput: 85.59 tokens/s, Accepted: 804 tokens, Drafted: 856 tokens, Per-position acceptance rate: 0.939, Avg Draft acceptance rate: 93.9%

@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, a small and essential subset of CI tests that quickly catches errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@wuhuikx wuhuikx merged commit 41bc643 into ROCm:dev/perf Oct 30, 2025
3 of 4 checks passed