
Conversation


@qiruiyangmeta qiruiyangmeta commented Oct 1, 2025

Add utility functions to enable pass-kv prefill and allgather decode. For a more in-depth understanding of context parallelism in LLM inference, including partial attention, read the MLSys paper available at https://arxiv.org/pdf/2411.01783.

Purpose

Within the model, attention is the only component with a dependency on the sequence dimension, since each token must attend to all previous tokens in the same sequence. In contrast, FFN and element-wise operations are performed independently per token. An efficient context-parallelism implementation in vLLM therefore has to account for this dependency while minimizing synchronization overhead.

During the prefill phase, both the query (Q) and key-value (KV) tensors are sharded across GPUs. To ensure that each Q token can attend to all preceding KV tokens, the relevant Q or KV shards must be exchanged among GPUs. To reduce synchronization overhead, these data transfers are overlapped with partial attention computations, with the goal of fully hiding the transfer latency.

The choice between passing KV or Q shards depends on the relative sizes of the Q and KV tensors. For full prefill, passing KV shards is generally preferred, since the number of query heads per KV head exceeds two in most models. Conversely, for chunked prefill, passing Q shards can be more efficient when the KV cache is significantly longer than the number of new Q tokens. The following figure shows an example of prefill with CP=2; a sketch of the heuristic follows the figure.
[Figure: example of pass-KV prefill with CP=2]
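
To make the trade-off concrete, here is a minimal sketch of choosing between pass-KV and pass-Q by comparing the amount of data each option would move per rank. This is illustrative only and not this PR's implementation; the function name, arguments, and byte accounting are assumptions.

```python
def choose_prefill_pass_mode(
    num_q_tokens: int,
    num_kv_tokens: int,
    num_q_heads: int,
    num_kv_heads: int,
) -> str:
    """Pick which shards to exchange between CP ranks during prefill.

    Passing KV moves both K and V (2 * num_kv_tokens * num_kv_heads),
    while passing Q moves num_q_tokens * num_q_heads. Prefer the cheaper
    transfer; for full prefill (num_q_tokens == num_kv_tokens) this is
    pass-KV whenever there are more than ~2 query heads per KV head.
    """
    kv_volume = 2 * num_kv_tokens * num_kv_heads
    q_volume = num_q_tokens * num_q_heads
    return "pass_kv" if kv_volume <= q_volume else "pass_q"


# Full prefill with GQA (32 Q heads, 4 KV heads): pass KV.
assert choose_prefill_pass_mode(4096, 4096, 32, 4) == "pass_kv"
# Chunked prefill with a long KV cache and few new Q tokens: pass Q.
assert choose_prefill_pass_mode(256, 65536, 32, 4) == "pass_q"
```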

For decode, the KV cache is distributed across CP ranks in a round-robin manner. The figure below illustrates how decode works when CP=2; a sketch of the partial-attention merge used to combine per-rank results follows the figure.
[Figure: allgather decode with CP=2]
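
The building block behind both phases is merging partial attention results computed over disjoint KV shards. Below is a minimal PyTorch sketch of that merge for a single query, assuming each rank returns its partial output together with the log-sum-exp (LSE) of its local attention scores; the names, shapes, and single-query simplification are illustrative rather than this PR's API.

```python
import torch


def partial_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """Attention of one query (d,) over a local KV shard (t_local, d).

    Returns the shard-local output plus the log-sum-exp of the scores so
    that shards can later be merged exactly.
    """
    scores = k @ q / q.shape[-1] ** 0.5          # (t_local,)
    lse = torch.logsumexp(scores, dim=-1)        # scalar
    out = torch.softmax(scores, dim=-1) @ v      # (d,)
    return out, lse


def merge_partials(outs: list[torch.Tensor], lses: list[torch.Tensor]):
    """Combine per-rank partial outputs (e.g. after an all-gather)."""
    weights = torch.softmax(torch.stack(lses), dim=0)  # per-rank weights
    return sum(w * o for w, o in zip(weights, outs))   # (d,)


# Sanity check: merging two round-robin KV shards matches full attention.
torch.manual_seed(0)
d, t = 8, 6
q, k, v = torch.randn(d), torch.randn(t, d), torch.randn(t, d)
o_full, _ = partial_attention(q, k, v)
o0, lse0 = partial_attention(q, k[0::2], v[0::2])  # shard on rank 0
o1, lse1 = partial_attention(q, k[1::2], v[1::2])  # shard on rank 1
assert torch.allclose(merge_partials([o0, o1], [lse0, lse1]), o_full, atol=1e-5)
```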

Test Plan

Unit tests and e2e tests will be added in follow-up diffs.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces utility functions and configuration changes to support context parallelism in vLLM, specifically the "pass-kv prefill" and "allgather decode" strategies. The changes span configuration, distributed state management, worker components, and new attention utilities. While the core logic for context parallelism seems well thought out, I've identified a critical issue in the creation of expert parallel groups when context parallelism is enabled, as well as a bug in one of the new tests.

Comment on lines 1265 to 1266
```python
group_ranks = (all_ranks.transpose(1, 2).reshape(
    -1, data_parallel_size * tensor_model_parallel_size).unbind(0))
```

critical

The logic for creating expert parallel (EP) groups is incorrect when context parallelism (CP) is enabled (context_parallel_size > 1). The all_ranks tensor has dimensions (ext_dp, dp, pp, cp, tp), so after transpose(1, 2) its shape becomes (ext_dp, pp, dp, cp, tp). Reshaping that to (-1, dp*tp) flattens the (dp, cp, tp) dimensions together, so whenever cp > 1 each resulting EP group mixes ranks from different CP groups rather than grouping all (dp, tp) ranks that share a CP index.

To correctly form EP groups of size dp*tp for each (pp, cp) pair, you should permute the dimensions to make the tensor contiguous with dp and tp as the last dimensions before reshaping.

Suggested change

```diff
-group_ranks = (all_ranks.transpose(1, 2).reshape(
-    -1, data_parallel_size * tensor_model_parallel_size).unbind(0))
+group_ranks = (all_ranks.permute(0, 2, 3, 1, 4).contiguous().view(
+    -1, data_parallel_size * tensor_model_parallel_size).unbind(0))
```
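
For reference, the two formulations can be compared on a toy rank layout; the sizes below are made up purely for illustration, assuming all_ranks is arranged as (ext_dp, dp, pp, cp, tp).

```python
import torch

ext_dp, dp, pp, cp, tp = 1, 2, 1, 2, 2
all_ranks = torch.arange(ext_dp * dp * pp * cp * tp).reshape(ext_dp, dp, pp, cp, tp)

# Original formulation: groups [[0, 1, 2, 3], [4, 5, 6, 7]] -- each EP group
# spans both CP ranks of one DP rank instead of spanning DP ranks within a
# single CP rank.
print(all_ranks.transpose(1, 2).reshape(-1, dp * tp).tolist())

# Suggested formulation: groups [[0, 1, 4, 5], [2, 3, 6, 7]] -- the CP index
# is fixed within each EP group of size dp * tp, as intended.
print(all_ranks.permute(0, 2, 3, 1, 4).contiguous().view(-1, dp * tp).tolist())
```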

Comment on lines +216 to +218
```python
assert num_comp_local == [
    num_computed_tokens[0][-1] // 2, [num_computed_tokens[1][-1] // 2]
]
```

high

The assertion for num_comp_local is incorrect. The second element in the expected list is [num_computed_tokens[1][-1] // 2], which is a list containing an integer. However, num_comp_local is a flat list of integers. This type mismatch will cause the test to fail.

Suggested change

```diff
-assert num_comp_local == [
-    num_computed_tokens[0][-1] // 2, [num_computed_tokens[1][-1] // 2]
-]
+assert num_comp_local == [
+    num_computed_tokens[0][-1] // 2, num_computed_tokens[1][-1] // 2
+]
```

Signed-off-by: Qirui Yang <[email protected]>

mergify bot commented Oct 7, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @qiruiyangmeta.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 7, 2025
Member

hmellor commented Oct 8, 2025

These conflicts are caused by our migration to ruff. Please see https://vllm-dev.slack.com/archives/C07R5Q1Q2BB/p1759663228844749 which contains detailed instructions to make updating your branch as painless as possible.
