Add utility functions to enable pass-kv prefill and allgather decode. #26059
base: main
Conversation
Code Review
This pull request introduces utility functions and configuration changes to support context parallelism in vLLM, specifically the "pass-kv prefill" and "allgather decode" strategies. The changes span configuration, distributed state management, worker components, and new attention utilities. While the core logic for context parallelism seems well thought out, I've identified a critical issue in the creation of expert parallel groups when context parallelism is enabled, which would produce incorrect groups. There is also a bug in one of the new tests.
vllm/distributed/parallel_state.py
Outdated
```python
group_ranks = (all_ranks.transpose(1, 2).reshape(
    -1, data_parallel_size * tensor_model_parallel_size).unbind(0))
```
The logic for creating expert parallel (EP) groups is incorrect when context parallelism (CP) is enabled (context_parallel_size > 1). The all_ranks tensor has dimensions (ext_dp, dp, pp, cp, tp); after transpose(1, 2) its shape becomes (ext_dp, pp, dp, cp, tp). Reshaping that layout to (-1, dp * tp) packs ranks from different cp indices into the same group whenever cp > 1, so the resulting EP groups are wrong.
To correctly form EP groups of size dp * tp for each (ext_dp, pp, cp) combination, permute the dimensions so that dp and tp are the trailing, contiguous dimensions before flattening.
```diff
-group_ranks = (all_ranks.transpose(1, 2).reshape(
-    -1, data_parallel_size * tensor_model_parallel_size).unbind(0))
+group_ranks = (all_ranks.permute(0, 2, 3, 1, 4).contiguous().view(
+    -1, data_parallel_size * tensor_model_parallel_size).unbind(0))
```
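To see the difference concretely, here is a toy sanity check (hypothetical small sizes chosen for illustration only; the real values come from the parallel configuration):

```python
import torch

# Toy layout with hypothetical sizes: (ext_dp, dp, pp, cp, tp) = (1, 2, 1, 2, 2).
ext_dp, dp, pp, cp, tp = 1, 2, 1, 2, 2
all_ranks = torch.arange(ext_dp * dp * pp * cp * tp).reshape(ext_dp, dp, pp, cp, tp)

# Current code: cp ranks end up mixed inside each EP group.
wrong = all_ranks.transpose(1, 2).reshape(-1, dp * tp).unbind(0)
# wrong == (tensor([0, 1, 2, 3]), tensor([4, 5, 6, 7]))  -> each group spans both cp indices

# Suggested code: dp and tp become the trailing dims for every (ext_dp, pp, cp).
right = all_ranks.permute(0, 2, 3, 1, 4).contiguous().view(-1, dp * tp).unbind(0)
# right == (tensor([0, 1, 4, 5]), tensor([2, 3, 6, 7]))  -> each group stays within one cp index
```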
```python
assert num_comp_local == [
    num_computed_tokens[0][-1] // 2, [num_computed_tokens[1][-1] // 2]
]
```
The assertion for num_comp_local is incorrect. The second element in the expected list is [num_computed_tokens[1][-1] // 2], which is a list containing an integer. However, num_comp_local is a flat list of integers. This type mismatch will cause the test to fail.
```diff
-assert num_comp_local == [
-    num_computed_tokens[0][-1] // 2, [num_computed_tokens[1][-1] // 2]
-]
+assert num_comp_local == [
+    num_computed_tokens[0][-1] // 2, num_computed_tokens[1][-1] // 2
+]
```
Signed-off-by: Qirui Yang <[email protected]>
This pull request has merge conflicts that must be resolved before it can be merged. These conflicts are caused by our migration to ...
Add utility functions to enable pass-kv prefill and allgather decode. For a more in-depth understanding of context parallelism in LLM inference, including partial attention, read the MLSys paper available at https://arxiv.org/pdf/2411.01783.
Purpose
Within the model, attention is the only component that depends on the sequence dimension, since each token must attend to all previous tokens in the same sequence. In contrast, FFN and element-wise operations are performed independently for each token. To implement efficient context parallelism in vLLM, the design needs to be aware of these dependencies to minimize synchronization overhead.
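To make the token-independence concrete, here is a small illustrative PyTorch check (not vLLM code): a token-wise layer applied shard-by-shard matches the unsharded result, whereas attention would need Q or KV from other shards.

```python
import torch

torch.manual_seed(0)
seq = torch.randn(8, 16)           # (num_tokens, hidden_size), toy sizes
ffn = torch.nn.Linear(16, 16)

# FFN / element-wise ops act per token: computing each CP shard separately
# and concatenating reproduces the full-sequence result.
full = ffn(seq)
sharded = torch.cat([ffn(seq[:4]), ffn(seq[4:])], dim=0)
assert torch.allclose(full, sharded)

# Attention cannot be split this way: token i needs the K/V of all earlier
# tokens, which is why CP must exchange KV (pass-KV) or Q shards.
```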

During the prefill phase, both the query (Q) and key-value (KV) tensors are sharded across GPUs. To ensure that each Q token can attend to all preceding KV tokens, it is necessary to exchange the relevant Q or KV shards among GPUs. To reduce synchronization overhead, data transfers are overlapped with partial attention computations, with the goal of fully hiding data transfer latency.
The choice between passing KV or Q shards depends on the relative sizes of the Q and KV tensors. For full prefill, passing KV shards is generally preferred, as the number of queries per KV head typically exceeds two in most models. Conversely, for chunked prefill, passing Q shards may be more efficient if the KV cache length is significantly greater than the number of Q tokens. The following figure shows an example of prefill with CP=2.
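As a rough way to think about this trade-off, here is an illustrative heuristic (the function name and the simple byte-count model are assumptions for illustration, not the PR's actual selection logic):

```python
def prefer_pass_kv(num_q_tokens: int, num_kv_tokens: int,
                   num_q_heads: int, num_kv_heads: int) -> bool:
    """Illustrative only: pass whichever shard is cheaper to send.

    Pass-KV moves K and V (two tensors) for the KV tokens; pass-Q moves Q
    for the query tokens. With GQA, num_q_heads / num_kv_heads > 2 makes the
    KV side comparatively small for full prefill (q_tokens == kv_tokens),
    while chunked prefill against a long KV cache can favor pass-Q.
    """
    kv_volume = 2 * num_kv_tokens * num_kv_heads   # per unit of head_dim
    q_volume = num_q_tokens * num_q_heads
    return kv_volume <= q_volume

# Full prefill, 32 Q heads and 8 KV heads: pass-KV is cheaper.
assert prefer_pass_kv(4096, 4096, 32, 8)
# Chunked prefill, 512 new Q tokens attending to a 64K-token KV cache: pass-Q wins.
assert not prefer_pass_kv(512, 65536, 32, 8)
```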
For decode, KV is stored across CP ranks in a round-robin manner. The figure below illustrates how decode works with CP=2.

Test Plan
Unit tests and e2e tests will be added in follow-up PRs.
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.