Skip to content

Commit 8f8659c

Browse files
gjc0824pisceskkk authored
and committed
[Refactor] all gather the accurate context lengths
Co-authored-by: gaojc <[email protected]>
Co-authored-by: QiuChunshuo <[email protected]>
Signed-off-by: gaojc <[email protected]>
Signed-off-by: QiuChunshuo <[email protected]>
1 parent 1f8ffde commit 8f8659c

File tree

1 file changed

+25
-17
lines changed

1 file changed

+25
-17
lines changed

vllm/v1/attention/backends/mla/common.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@
195195

196196
import torch
197197
from tqdm import tqdm
198+
import numpy as np
198199

199200
import vllm.envs as envs
200201
from vllm import _custom_ops as ops
@@ -845,15 +846,7 @@ def build(
845846
None,
846847
self.dcp_local_block_size,
847848
)
848-
# Note(qcs): The max local context lengths
849-
# padded to `dcp_local_block_size`.
850-
local_context_lens_cpu = (
851-
cdiv(
852-
context_lens_cpu,
853-
self.dcp_virtual_block_size,
854-
)
855-
* self.dcp_local_block_size
856-
)
849+
local_context_lens_cpu = local_context_lens_allrank[:, self.dcp_rank]
857850
# Note(hc): The above max_context_chunk already enforces
858851
# block_size alignment, DCP just need the block_size can
859852
# be divisible by dcp_world_size, because DCP use
@@ -989,7 +982,7 @@ def reorg_kvcache(
989982
local_context_lens_allrank: list[list[int]],
990983
sum_seq_len: int,
991984
max_seq_len: int,
992-
toks: int,
985+
local_context_lens_sum: list[int],
993986
) -> tuple[torch.Tensor, torch.Tensor]:
994987
"""
995988
reorg kvcache after cp local gather to tp layout for attn kernel.
@@ -1000,31 +993,35 @@ def reorg_kvcache(
1000993
local_context_lens_allrank: local context lengths on each CP rank.
1001994
sum_seq_len: the sum of cp_chunk_seq_lens_lst.
1002995
max_seq_len: the max value of cp_chunk_seq_lens_lst.
1003-
toks: the number of tokens for local gather cache.
996+
local_context_lens_sum: the total context tokens of all requests
997+
on each CP rank.
1004998
"""
1005999
kv_c_segments = []
10061000
k_pe_segments = []
10071001
src_token_idx = 0
10081002
max_seq_len_check = 0
1003+
10091004
for local_chunk_seq_len, local_context_lens in zip(
10101005
local_chunk_seq_lens_lst, local_context_lens_allrank
10111006
):
10121007
cur_seq_len = 0
1008+
context_len_across_rank = 0
10131009
for rank, local_context_len in enumerate(local_context_lens):
10141010
if local_context_len != 0:
10151011
kv_c_segment = allgatered_kv_c_normed[
1016-
rank * toks + src_token_idx : rank * toks
1012+
context_len_across_rank + src_token_idx : context_len_across_rank
10171013
+ src_token_idx
10181014
+ local_context_len
10191015
]
10201016
k_pe_segment = allgatered_k_pe[
1021-
rank * toks + src_token_idx : rank * toks
1017+
context_len_across_rank + src_token_idx : context_len_across_rank
10221018
+ src_token_idx
10231019
+ local_context_len
10241020
]
10251021
kv_c_segments.append(kv_c_segment)
10261022
k_pe_segments.append(k_pe_segment)
10271023
cur_seq_len += local_context_len
1024+
context_len_across_rank += local_context_lens_sum[rank]
10281025
max_seq_len_check = max(max_seq_len_check, cur_seq_len)
10291026
src_token_idx += local_chunk_seq_len
10301027
reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
@@ -1613,11 +1610,21 @@ def _context_parallel_compute_prefill_context(
16131610
cur_allgather_workspace = workspace[
16141611
allgather_offset : allgather_offset * (1 + dcp_world_size)
16151612
]
1613+
local_context_lens_allrank = (
1614+
prefill_metadata.chunked_context.local_context_lens_allrank
1615+
)
1616+
local_context_lens_sum = np.sum(local_context_lens_allrank, axis=0).tolist()
16161617
assert toks * dcp_world_size <= cur_allgather_workspace.shape[0]
1617-
cur_allgather_kvcache = cur_allgather_workspace[: toks * dcp_world_size]
1618+
cur_allgather_kvcache = cur_allgather_workspace[: sum(local_context_lens_sum)]
1619+
16181620
cur_allgather_kvcache.copy_(
1619-
get_dcp_group().all_gather(local_gathered_kvcache, dim=0)
1621+
get_dcp_group().all_gatherv(
1622+
local_gathered_kvcache,
1623+
dim=0,
1624+
sizes=local_context_lens_sum
1625+
)
16201626
)
1627+
16211628
assert (
16221629
cur_allgather_kvcache.shape[-1]
16231630
== self.kv_lora_rank + self.qk_rope_head_dim
@@ -1632,10 +1639,11 @@ def _context_parallel_compute_prefill_context(
16321639
local_chunk_seq_lens_lst=prefill_metadata.chunked_context.local_chunk_seq_lens[
16331640
i
16341641
],
1635-
local_context_lens_allrank=prefill_metadata.chunked_context.local_context_lens_allrank,
1642+
local_context_lens_allrank=
1643+
prefill_metadata.chunked_context.local_context_lens_allrank,
16361644
sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1],
16371645
max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
1638-
toks=toks,
1646+
local_context_lens_sum=local_context_lens_sum,
16391647
)
16401648

16411649
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(

0 commit comments

Comments
 (0)