Motivation.
This is a follow-up to add Prefill Context Parallel (PCP) after DCP (PR #23734) to optimize TTFT.
In DCP, the KV cache is partitioned along the sequence dimension into several segments (interleave style), achieving context parallelism during the decode phase.
Similarly, PCP splits the entire request along the sequence dimension during the prefill phase and further partitions the KV cache along the sequence dimension.
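For illustration, here is a minimal sketch of how an interleave-style split assigns token positions to ranks in a context-parallel group. Token-level interleaving is assumed purely for clarity; the actual granularity (token- or block-level) follows the DCP implementation.

```python
# Minimal sketch: interleave-style partitioning of a sequence across CP ranks.
# Token-level interleaving is assumed here for illustration only.

def interleave_partition(num_tokens: int, cp_rank: int, cp_size: int) -> list[int]:
    """Return the global token positions owned by `cp_rank` under interleaved sharding."""
    return list(range(cp_rank, num_tokens, cp_size))

# Example: a 10-token sequence sharded across a CP group of size 2.
assert interleave_partition(10, 0, 2) == [0, 2, 4, 6, 8]
assert interleave_partition(10, 1, 2) == [1, 3, 5, 7, 9]
```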
Proposed Change.
Aside from the attention module, almost all other modules do not involve contextual dependencies and can naturally adapt to context parallelism.
In the attention module, we adopt the following strategy to implement PCP (chunked prefill is not considered for now); the affected steps are marked in red in the schematic diagram.
First, we perform an AllGather on the K/V along the sequence dimension within the PCP group to obtain the complete KV values. The KV cache is then stored according to the slot_mapping.
For the attention computation, since the complete KV is now available, we only need to carefully design a custom mask and perform standard attention.
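A minimal sketch of this prefill path, assuming `pcp_group` is an initialized `torch.distributed` process group, equal-sized K/V shards per rank, and illustrative tensor shapes; the function and argument names are hypothetical, not vLLM's actual API (the KV-cache write via slot_mapping is omitted):

```python
import torch
import torch.distributed as dist


def pcp_prefill_attention(q, k_local, v_local, q_positions, kv_positions_local, pcp_group):
    """
    q:                  [q_len_local, num_heads, head_dim] query shard on this rank
    k_local, v_local:   [kv_len_local, num_heads, head_dim] K/V shard on this rank
    q_positions:        [q_len_local] global positions of the local query tokens
    kv_positions_local: [kv_len_local] global positions of the local K/V tokens
    """
    pcp_size = dist.get_world_size(pcp_group)

    # 1. AllGather K/V (and their global positions) along the sequence
    #    dimension within the PCP group to recover the complete KV.
    def gather(t):
        parts = [torch.empty_like(t) for _ in range(pcp_size)]
        dist.all_gather(parts, t, group=pcp_group)
        return torch.cat(parts, dim=0)

    k_full = gather(k_local)
    v_full = gather(v_local)
    kv_positions = gather(kv_positions_local)

    # 2. Custom causal mask in terms of *global* positions: a query token at
    #    position p may only attend to KV tokens at positions <= p.
    mask = q_positions.unsqueeze(1) >= kv_positions.unsqueeze(0)  # [q_len, kv_len]

    # 3. Standard masked attention over the full K/V.
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("qhd,khd->hqk", q, k_full) * scale
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v_full)
```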
During the decode phase, modules other than attention perform redundant computation across the PCP group (since the number of new tokens is 1).
In the attention module, we first execute the original DCP computation logic on each rank of the PCP group.
Then, before correcting the attention output with the lse, we perform an AllGather within the PCP group to obtain the complete sequence information; the subsequent steps remain unchanged.
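A minimal sketch of the lse-based merge that happens after this AllGather, assuming each rank has produced a partial attention output and the corresponding log-sum-exp over its local KV shard (shapes and names are illustrative):

```python
import torch


def merge_partial_attn(partial_outs: torch.Tensor, partial_lses: torch.Tensor) -> torch.Tensor:
    """
    partial_outs: [cp_size, num_heads, head_dim] per-rank partial attention outputs
    partial_lses: [cp_size, num_heads]           per-rank log-sum-exp of the scores
    """
    # The global softmax denominator is the logsumexp over all shards.
    global_lse = torch.logsumexp(partial_lses, dim=0)             # [num_heads]
    # Rescale each locally normalized output by exp(lse_i - lse_global) and sum.
    weights = torch.exp(partial_lses - global_lse).unsqueeze(-1)  # [cp_size, num_heads, 1]
    return (weights * partial_outs).sum(dim=0)                    # [num_heads, head_dim]
```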
Summary
- The creation of Prefill Context Parallel communication domains. Unlike DCP communication domains, which are subdivisions within the TP domain, PCP communication domains stand alongside DP, PP, and TP and affect the total device-count allocation (see the sketch after this list).
- Modification of the slot_mapping calculation. Building upon the virtual-block concept introduced by DCP, PCP continues to use virtual blocks within each PCP group; each rank in the PCP group is responsible for storing the portion of the KV cache that corresponds to its assigned segment of the sequence.
- Modification of parameter calculations in the ModelRunner. The calculation of parameters such as sequence length and token count must be updated, as these values are now scaled by the PCP size.
- Modification of the attention backend computation logic. The core attention calculation requires changes to handle the KV cache distributed across the PCP group. The specific algorithmic updates are detailed in the red markings within the provided schematic diagram.
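Regarding the first item above, a minimal sketch of the device-count bookkeeping, assuming PCP multiplies into the world size the same way DP, PP, and TP do (the exact rank-ordering convention used by vLLM is not shown):

```python
def required_world_size(dp: int, pp: int, tp: int, pcp: int) -> int:
    # DCP ranks are carved out of the TP group, so it does not change the
    # world size; PCP is an independent dimension, so it multiplies in.
    return dp * pp * tp * pcp


# Example: DP=2, PP=1, TP=4, PCP=2 requires 16 devices instead of 8.
assert required_world_size(2, 1, 4, 2) == 16
```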
Feedback Period.
We will complete the above modifications and submit a PR within 1-2 weeks.
CC List.
@youkaichao @njhill @WoosukKwon @ruisearch42 @youzhedian
Any Other Things.
Here are some related issues and PRs:
DCP: PR #23734, PR #24864, PR #25438
PCP: RFC #22693, RFC #7519
Roadmap
- Basic feature; PR [Feature] Support Prefill Context Parallel (PCP) for GQA flashinfer #26864
- Make the FlashInfer PCP backend compatible with DCP after PR [DCP] Support Decode Context Parallel (DCP) for GQA with Flashinfer #25438 is merged;
- Make block-level interleaved KV cache storage compatible after PR [DCP] Support dcp kv_cache interleave size > 1 #26696 is merged;
- PCP support for PIECEWISE CUDAGraph; 1598b45
- PCP support for chunked-prefill and prefix caching features;
Future work (these items will be tackled in follow-up PRs; community contributions are warmly welcomed):
- PCP support for MLA and other backends;
- PCP support for MTP;
- PCP support for CUDAFullGraph;
- Ring-CP-style attention backend algorithm; see RFC [RFC]: Support Context Parallelism with Fully Sharded KV Cache and Ring Attention #26133.