@sunfish2010 sunfish2010 commented Oct 31, 2025

Purpose

Add MetaShufflingMoE as an optional MoE backend for Llama4 models.
This PR adds only TP support for bf16 models; it includes basic skeleton code to prepare for future EP and fp8 support.
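For context, a minimal usage sketch for opting into the new backend. The two environment variables are the ones exercised in the test plan below; the `vllm serve` invocation and the implied CUTLASS default are assumptions for illustration, not documentation from this PR.

```bash
# Hypothetical example: serve Llama-4 Scout with the MetaShufflingMoE backend.
# VLLM_META_SHUFFLING_GEMM_BACKEND selects the grouped-GEMM kernels
# ("cutlass" appears to be the default; "triton" picks the Triton path).
VLLM_USE_META_SHUFFLING_MOE=1 \
VLLM_META_SHUFFLING_GEMM_BACKEND=triton \
vllm serve meta-llama/Llama-4-Scout-17B-16E-Instruct --tensor-parallel-size 8
```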

MetaShufflingMoE shows better TTFT than the existing fused MoE backend at 1k and 2k input lengths:

`--request-rate 4 --random-input-len 1024 --random-output-len 128 --num-prompts 1000`

| Serving Benchmark Result | Fused MoE | MetaShuffling MoE (CUTLASS grouped GEMM) | Δ vs. fused | MetaShuffling MoE (Triton grouped GEMM) | Δ vs. fused |
|---|---:|---:|---:|---:|---:|
| Successful requests | 1000 | 1000 | | 1000 | |
| Failed requests | 0 | 0 | | 0 | |
| Request rate configured (RPS) | 4 | 4 | | 4 | |
| Benchmark duration (s) | 251.21 | 251.1 | | 251.15 | |
| Total input tokens | 1023000 | 1023000 | | 1023000 | |
| Total generated tokens | 115584 | 115798 | | 115586 | |
| Request throughput (req/s) | 3.98 | 3.98 | | 3.98 | |
| Output token throughput (tok/s) | 460.1 | 461.16 | | 460.22 | |
| Peak output token throughput (tok/s) | 988 | 1018 | | 1048 | |
| Peak concurrent requests | 25 | 26 | | 25 | |
| Total token throughput (tok/s) | 4532.34 | 4535.18 | | 4533.43 | |
| **Time to First Token** | | | | | |
| Mean TTFT (ms) | 78.29 | 69.2 | -11.61% | 71.68 | -8.44% |
| Median TTFT (ms) | 74.84 | 66.12 | -11.65% | 69.01 | -7.79% |
| P99 TTFT (ms) | 118.97 | 107.02 | -10.04% | 108.11 | -9.13% |
| **Time per Output Token (excl. 1st token)** | | | | | |
| Mean TPOT (ms) | 14.66 | 14.22 | -3.00% | 14.2 | -3.14% |
| Median TPOT (ms) | 14.63 | 14.21 | -2.87% | 14.26 | -2.53% |
| P99 TPOT (ms) | 19.6 | 18.61 | -5.05% | 19.07 | -2.70% |
| **Inter-token Latency** | | | | | |
| Mean ITL (ms) | 14.66 | 14.22 | -3.00% | 14.18 | -3.27% |
| Median ITL (ms) | 11.89 | 11.86 | -0.25% | 11.81 | -0.67% |
| P99 ITL (ms) | 56.53 | 49.5 | -12.44% | 51.87 | -8.24% |

`--request-rate 4 --random-input-len 2048 --random-output-len 128 --num-prompts 1000`

| Serving Benchmark Result | Fused MoE | MetaShuffling MoE (CUTLASS grouped GEMM) | Δ vs. fused | MetaShuffling MoE (Triton grouped GEMM) | Δ vs. fused |
|---|---:|---:|---:|---:|---:|
| Successful requests | 1000 | 1000 | | 1000 | |
| Failed requests | 0 | 0 | | 0 | |
| Request rate configured (RPS) | 4 | 4 | | 4 | |
| Benchmark duration (s) | 251.4 | 251.38 | | 251.32 | |
| Total input tokens | 2047000 | 2047000 | | 2047000 | |
| Total generated tokens | 109075 | 108784 | -0.27% | 108163 | |
| Request throughput (req/s) | 3.98 | 3.98 | 0.00% | 3.98 | |
| Output token throughput (tok/s) | 433.86 | 432.75 | -0.26% | 430.37 | |
| Peak output token throughput (tok/s) | 1018 | 1030 | 1.18% | 1079 | |
| Peak concurrent requests | 26 | 24 | -7.69% | 22 | |
| Total token throughput (tok/s) | 8576.12 | 8575.88 | 0.00% | 8575.25 | |
| **Time to First Token** | | | | | |
| Mean TTFT (ms) | 84.43 | 82.5 | -2.29% | 79.47 | -5.87% |
| Median TTFT (ms) | 78.56 | 76.59 | -2.51% | 74.12 | -5.65% |
| P99 TTFT (ms) | 149.37 | 143.42 | -3.98% | 144.85 | -3.03% |
| **Time per Output Token (excl. 1st token)** | | | | | |
| Mean TPOT (ms) | 14.64 | 14.69 | 0.34% | 13.95 | -4.71% |
| Median TPOT (ms) | 14.49 | 14.47 | -0.14% | 13.78 | -4.90% |
| P99 TPOT (ms) | 21.56 | 20.61 | -4.41% | 19.34 | -10.30% |
| **Inter-token Latency** | | | | | |
| Mean ITL (ms) | 14.52 | 14.61 | 0.62% | 13.88 | -4.41% |
| Median ITL (ms) | 11.73 | 11.82 | 0.77% | 11.5 | -1.96% |
| P99 ITL (ms) | 55.74 | 54.77 | -1.74% | 51.98 | -6.75% |

Test Plan

Running lm_eval

```bash
HF_HUB_DISABLE_XET=1 with-proxy VLLM_USE_MODELSCOPE=False lm_eval --model vllm --model_args "pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,trust_remote_code=True,tensor_parallel_size=8,max_model_len=32768" --tasks gsm8k --num_fewshot 8 --batch_size 128
```

Baseline results

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     8|exact_match|↑  |0.9219|±  |0.0074|
|     |       |strict-match    |     8|exact_match|↑  |0.9098|±  |0.0079|

Test Results

Cutlass Backend

```bash
HF_HUB_DISABLE_XET=1 with-proxy VLLM_USE_MODELSCOPE=False VLLM_USE_META_SHUFFLING_MOE=1 lm_eval --model vllm --model_args "pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,trust_remote_code=True,tensor_parallel_size=8,max_model_len=32768" --tasks gsm8k --num_fewshot 8 --batch_size 128
```

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     8|exact_match|↑  |0.9204|±  |0.0075|
|     |       |strict-match    |     8|exact_match|↑  |0.9105|±  |0.0079|

Triton Backend

```bash
HF_HUB_DISABLE_XET=1 with-proxy VLLM_USE_MODELSCOPE=False VLLM_USE_META_SHUFFLING_MOE=1 VLLM_META_SHUFFLING_GEMM_BACKEND=triton lm_eval --model vllm --model_args "pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,trust_remote_code=True,tensor_parallel_size=8,max_model_len=32768" --tasks gsm8k --num_fewshot 8 --batch_size 128
```

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     8|exact_match|↑  |0.9204|±  |0.0075|
|     |       |strict-match    |     8|exact_match|↑  |0.9105|±  |0.0079|

Essential Elements of an Effective PR Description Checklist

- [x] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
- [x] The test plan, such as providing test command.
- [x] The test results, such as pasting the results comparison before and after, or e2e results.
- (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
- (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the llama (Related to Llama models) label on Oct 31, 2025

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces MetaShufflingMoE as a new, optional Mixture-of-Experts backend for Llama4 models, aimed at improving performance, particularly Time to First Token (TTFT). The changes include adding the necessary environment variables, the MetaShufflingMoE layer implementation, and its integration into the Llama4 model.

My review has identified a few critical issues in the new MetaShufflingMoE implementation that could lead to runtime crashes, especially in scenarios not covered by the current tests (e.g., when used without a shared expert). I've provided code suggestions to fix these issues. Overall, the feature is a valuable addition, and with these fixes, it should be more robust.

Comment on lines +77 to +84
```python
if envs.VLLM_META_SHUFFLING_GEMM_BACKEND == "cutlass":
    scatter_add_dense_tokens(
        out_tokens=shared_out,
        in_tokens=routed_out,
        token_indices=route_info.token_indices,
        valid_token_count=route_info.num_routed_tokens,
    )
    return shared_out
```

critical

When shared_out is None and the cutlass backend is used, scatter_add_dense_tokens is called with out_tokens=None, which will likely cause a crash. Subsequently, return shared_out would return None, leading to a None.view() call in the caller, which will also crash. This can happen when MetaShufflingMoE is used without a shared expert.

To fix this, we should create a zero tensor for shared_out if it's None. The shape can be derived from scores and routed_out.

Suggested change:

```diff
 if envs.VLLM_META_SHUFFLING_GEMM_BACKEND == "cutlass":
+    if shared_out is None:
+        shared_out = torch.zeros(
+            scores.shape[0],
+            routed_out.shape[-1],
+            dtype=routed_out.dtype,
+            device=routed_out.device,
+        )
     scatter_add_dense_tokens(
         out_tokens=shared_out,
         in_tokens=routed_out,
         token_indices=route_info.token_indices,
         valid_token_count=route_info.num_routed_tokens,
     )
     return shared_out
```

@mergify

mergify bot commented Nov 3, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sunfish2010.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 3, 2025
@sunfish2010 sunfish2010 force-pushed the meta_shuffling_integration branch from 6acfc0e to 4609478 on November 3, 2025 18:47