195195
196196import torch
197197from tqdm import tqdm
198+ import numpy as np
198199
199200import vllm .envs as envs
200201from vllm import _custom_ops as ops
@@ -845,15 +846,7 @@ def build(
845846 None ,
846847 self .dcp_local_block_size ,
847848 )
848- # Note(qcs): The max local context lengths
849- # padded to `dcp_local_block_size`.
850- local_context_lens_cpu = (
851- cdiv (
852- context_lens_cpu ,
853- self .dcp_virtual_block_size ,
854- )
855- * self .dcp_local_block_size
856- )
849+ local_context_lens_cpu = local_context_lens_allrank [:, self .dcp_rank ]
857850 # Note(hc): The above max_context_chunk already enforces
858851 # block_size alignment; DCP only needs block_size to be
859852 # divisible by dcp_world_size, because DCP uses
@@ -989,7 +982,7 @@ def reorg_kvcache(
989982 local_context_lens_allrank : list [list [int ]],
990983 sum_seq_len : int ,
991984 max_seq_len : int ,
992- toks : int ,
985+ local_context_lens_sum : list [ int ] ,
993986) -> tuple [torch .Tensor , torch .Tensor ]:
994987 """
995988 Reorganize the kvcache after the CP local gather into the TP layout for the attn kernel.
@@ -1000,31 +993,35 @@ def reorg_kvcache(
1000993 local_context_lens_allrank: local context lengths on each CP rank.
1001994 sum_seq_len: the sum of cp_chunk_seq_lens_lst.
1002995 max_seq_len: the max value of cp_chunk_seq_lens_lst.
1003- toks: the number of tokens for local gather cache.
996+ local_context_lens_sum: the total number of context tokens, summed
997+ over all requests, on each CP rank.
1004998 """
1005999 kv_c_segments = []
10061000 k_pe_segments = []
10071001 src_token_idx = 0
10081002 max_seq_len_check = 0
1003+
10091004 for local_chunk_seq_len , local_context_lens in zip (
10101005 local_chunk_seq_lens_lst , local_context_lens_allrank
10111006 ):
10121007 cur_seq_len = 0
1008+ context_len_across_rank = 0
10131009 for rank , local_context_len in enumerate (local_context_lens ):
10141010 if local_context_len != 0 :
10151011 kv_c_segment = allgatered_kv_c_normed [
1016- rank * toks + src_token_idx : rank * toks
1012+ context_len_across_rank + src_token_idx : context_len_across_rank
10171013 + src_token_idx
10181014 + local_context_len
10191015 ]
10201016 k_pe_segment = allgatered_k_pe [
1021- rank * toks + src_token_idx : rank * toks
1017+ context_len_across_rank + src_token_idx : context_len_across_rank
10221018 + src_token_idx
10231019 + local_context_len
10241020 ]
10251021 kv_c_segments .append (kv_c_segment )
10261022 k_pe_segments .append (k_pe_segment )
10271023 cur_seq_len += local_context_len
1024+ context_len_across_rank += local_context_lens_sum [rank ]
10281025 max_seq_len_check = max (max_seq_len_check , cur_seq_len )
10291026 src_token_idx += local_chunk_seq_len
10301027 reorganized_kv_c_normed = torch .cat (kv_c_segments , dim = 0 )
@@ -1613,11 +1610,21 @@ def _context_parallel_compute_prefill_context(
16131610 cur_allgather_workspace = workspace [
16141611 allgather_offset : allgather_offset * (1 + dcp_world_size )
16151612 ]
1613+ local_context_lens_allrank = (
1614+ prefill_metadata .chunked_context .local_context_lens_allrank
1615+ )
1616+ local_context_lens_sum = np .sum (local_context_lens_allrank , axis = 0 ).tolist ()
16161617 assert toks * dcp_world_size <= cur_allgather_workspace .shape [0 ]
1617- cur_allgather_kvcache = cur_allgather_workspace [: toks * dcp_world_size ]
1618+ cur_allgather_kvcache = cur_allgather_workspace [: sum (local_context_lens_sum )]
1619+
16181620 cur_allgather_kvcache .copy_ (
1619- get_dcp_group ().all_gather (local_gathered_kvcache , dim = 0 )
1621+ get_dcp_group ().all_gatherv (
1622+ local_gathered_kvcache ,
1623+ dim = 0 ,
1624+ sizes = local_context_lens_sum
1625+ )
16201626 )
1627+
16211628 assert (
16221629 cur_allgather_kvcache .shape [- 1 ]
16231630 == self .kv_lora_rank + self .qk_rope_head_dim
@@ -1632,10 +1639,11 @@ def _context_parallel_compute_prefill_context(
16321639 local_chunk_seq_lens_lst = prefill_metadata .chunked_context .local_chunk_seq_lens [
16331640 i
16341641 ],
1635- local_context_lens_allrank = prefill_metadata .chunked_context .local_context_lens_allrank ,
1642+ local_context_lens_allrank =
1643+ prefill_metadata .chunked_context .local_context_lens_allrank ,
16361644 sum_seq_len = prefill_metadata .chunked_context .cu_seq_lens_lst [i ][- 1 ],
16371645 max_seq_len = prefill_metadata .chunked_context .max_seq_lens [i ],
1638- toks = toks ,
1646+ local_context_lens_sum = local_context_lens_sum ,
16391647 )
16401648
16411649 kv_nope = self .kv_b_proj (kv_c_normed )[0 ].view (
0 commit comments