Skip to content

Conversation

@PerryZhang01
Copy link

@PerryZhang01 PerryZhang01 commented Oct 29, 2025

Purpose

This PR supports EPLB for the ROCm backend, achieving feature parity with the existing CUDA implementation. The implementation validated on DeepSeekR1.

Test Plan

we try to enable EPLB on DeepSeekR1 on MI355 with the following parameters.

server:
vllm serve $model_path \
--tensor-parallel-size 8 \
--max-num-batched-tokens 32768 \
--trust-remote-code \
--no-enable-prefix-caching \
--disable-log-requests \
--compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \
--gpu_memory_utilization 0.8 \
--block-size 1 \
--enable-expert-parallel \
--enable-eplb \
--num-redundant-experts 8 \
--eplb-log-balancedness \
--eplb-window-size 3000 \
--eplb-step-interval 1000
client:
python -m vllm.entrypoints.cli.main bench serve \
    --host localhost \
    --port 8000 \
    --model ${model_path} \
    --dataset-name random \
    --random-input-len 1024 \
    --random-output-len 1024 \
    --max-concurrency 64 \
    --num-prompts 128 \
    --seed 123 \
    --percentile-metrics ttft,tpot,itl,e2el \
    --ignore-eos

Test Result

Benchmark Result:
image
After using eplb, the balancedness metric increased from 0.55 to 0.65 with the random data. However, the avg_tokens is not enough in decode phase, so that the balancedness metric dropped to 0.3.

Besides, we use lm_eval to validate the accuracy of EPLB on gsm8k datasets, the result as below:

    |Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
    |-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
    |gsm8k|      3|flexible-extract|     5|exact_match||0.9492|±  |0.0060|
    |     |       |strict-match    |     5|exact_match||0.9492|±  |0.0060|

Copy link

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for Expert Parallelism Load Balancing (EPLB) on the ROCm backend, achieving feature parity with the CUDA implementation. The changes are logical and well-contained, primarily enabling the feature for ROCm and updating the relevant checks and method calls. I have one point of feedback regarding the removal of a tensor contiguity assertion, which could potentially lead to issues with distributed communication. I've suggested ensuring tensor contiguity to maintain correctness and performance.

@PerryZhang01
Copy link
Author

@abmfy could u please help review this PR

Copy link
Collaborator

@tjtanaa tjtanaa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Member

@hmellor hmellor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks much cleaner now, just one thing I'm not sure about

Comment on lines +1957 to +2087
assert all(
weight.is_contiguous()
for name, weight in weights
if not name.startswith("_shared_experts.")
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about this change. @abmfy could this cause issues for other EPLB use cases?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the shared expert is currently not continuous, possibly because the later MR performed a stride operation, however, we think that EPLB only apply to routed experts(this function only returns routed experts), so we have canceled the contiguous check on shared expert, and shared expert may have other stride operations in the future and should not be asserted here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hmellor @abmfy can someone confirm the change here?

Copy link
Author

@PerryZhang01 PerryZhang01 Nov 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

I have rebased the latest code and validated the accuracy of EPLB on gsm8k datasets again.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hmellor should be good here since EPLB is only on shared experts

@mergify
Copy link

mergify bot commented Nov 6, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @PerryZhang01.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 6, 2025
@PerryZhang01
Copy link
Author

@hmellor @abmfy @BowenBao could u help review this PR?

Copy link
Member

@abmfy abmfy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.
Thanks for the contribution!

Comment on lines +1957 to +2087
assert all(
weight.is_contiguous()
for name, weight in weights
if not name.startswith("_shared_experts.")
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hmellor should be good here since EPLB is only on shared experts

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

rocm Related to AMD ROCm

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants