@@ -49,14 +49,20 @@ def get_kv_cache_stride_order() -> tuple[int, ...]:


 @dataclass
-class DeepseekV32IndexerPrefillMetadata:
+class DeepseekV32IndexerPrefillChunkMetadata:
     block_table: torch.Tensor
-    query_start_loc: torch.Tensor
-    max_query_len: int
     cu_seqlen_ks: torch.Tensor
     cu_seqlen_ke: torch.Tensor
     cu_seq_lens: torch.Tensor
     total_seq_lens: int
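+    # New fields for chunked prefill: each chunk covers the query-token
+    # range [token_start, token_end) spanning `num_reqs` consecutive
+    # requests.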
+    token_start: int
+    token_end: int
+    num_reqs: int
+
+
+@dataclass
+class DeepseekV32IndexerPrefillMetadata:
+    chunks: list[DeepseekV32IndexerPrefillChunkMetadata]


 @dataclass
@@ -98,8 +104,8 @@ class DeepseekV32IndexerMetadata:

 # TODO (zyongye) optimize this, this is now vibe coded
 def kv_spans_from_batches(
-        start_seq_loc: torch.Tensor,
-        seq_len_per_batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor,
+        device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Args:
         start_seq_loc: 1D long tensor [B+1], cumulative counts of
@@ -122,15 +128,14 @@ def kv_spans_from_batches(
         are the **last** `counts[i]` positions of that sequence.
     """
     q = start_seq_loc.to(dtype=torch.long)
-    L = seq_len_per_batch.to(dtype=torch.long, device=q.device)
+    L = seq_len_per_batch.to(dtype=torch.long)
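+    # Intermediate math stays on the caller's (CPU) tensors; only the
+    # final spans are moved to `device` in the return below.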
     assert q.dim() == 1 and L.dim() == 1
     assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"

     # Selected tokens per batch and totals
     counts = q[1:] - q[:-1]  # [B]
     N = int(q[-1].item())  # total selected tokens
     B = L.numel()
-    device = L.device

     if N == 0:
         return (torch.empty(0, dtype=torch.long, device=device),
@@ -140,8 +145,7 @@ def kv_spans_from_batches(
     kv_starts_per_batch = torch.cumsum(L, dim=0) - L  # [B]

     # For each selected token, which batch does it belong to?
-    batch_id = torch.repeat_interleave(torch.arange(B, device=device),
-                                       counts)  # [N]
+    batch_id = torch.repeat_interleave(torch.arange(B), counts)  # [N]

     # Map batch KV start to each token
     start_tensor = kv_starts_per_batch[batch_id]  # [N]
@@ -151,22 +155,51 @@ def kv_spans_from_batches(
     L_expand = torch.repeat_interleave(L, counts)  # [N]
     m_expand = torch.repeat_interleave(counts, counts)  # [N]
     # position within the selected block: 1..counts[b]
-    pos_within = (torch.arange(N, device=device, dtype=torch.long) -
+    pos_within = (torch.arange(N, dtype=torch.long) -
                   torch.repeat_interleave(q[:-1], counts) + 1)

     local_pos = L_expand - m_expand + pos_within  # [N], 1-based
     end_location = start_tensor + local_pos  # exclusive end

-    return start_tensor.int(), end_location.int()
+    return start_tensor.int().to(device), end_location.int().to(device)


 def get_max_prefill_buffer_size(vllm_config: VllmConfig):
     max_model_len = vllm_config.model_config.max_model_len
-    # max_num_batched_tokens = \
-    #     vllm_config.scheduler_config.max_num_batched_tokens
-    max_num_seq = vllm_config.scheduler_config.max_num_seqs
-    # NOTE(Chen): an estimated max size of flattened_kv. Need to double check.
-    return max_model_len * max_num_seq
+    # NOTE(Chen): 2 is a magic number for controlling the prefill buffer size.
+    # May be tuned later.
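+    # e.g. max_model_len=32768 bounds the flattened KV buffer at 65536 tokens.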
+    return max_model_len * 2
+
+
+def split_prefill_chunks(seq_lens_cpu: torch.Tensor,
+                         max_prefill_buffer_size: int,
+                         reqs_start: int) -> list[tuple[int, int]]:
+    """
+    Split the prefill requests into chunks, returned as a list of
+    (reqs_start, reqs_end) index ranges, such that the total sequence
+    length of each chunk does not exceed the maximum prefill buffer size.
+
+    Args:
+        seq_lens_cpu: The sequence lengths of the prefill requests.
+        max_prefill_buffer_size: The maximum prefill buffer size.
+        reqs_start: The start index of the prefill requests.
+
+    Returns:
+        A list of (reqs_start, reqs_end) tuples.
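+
+    Example (illustrative): seq_lens_cpu=[5, 6, 4] with
+    max_prefill_buffer_size=10 and reqs_start=0 yields [(0, 1), (1, 3)];
+    raising the buffer size to 11 yields [(0, 2), (2, 3)].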
189+ """
190+ chunk_seq_ids = []
191+ total_seq_lens = 0
192+ for i in range (reqs_start , len (seq_lens_cpu )):
193+ cur_seq_len = seq_lens_cpu [i ].item ()
194+ assert cur_seq_len <= max_prefill_buffer_size
195+ total_seq_lens += cur_seq_len
196+ if total_seq_lens > max_prefill_buffer_size :
197+ chunk_seq_ids .append ((reqs_start , i ))
198+ reqs_start = i
199+ total_seq_lens = cur_seq_len
200+ if total_seq_lens > 0 :
201+ chunk_seq_ids .append ((reqs_start , len (seq_lens_cpu )))
202+ return chunk_seq_ids


 class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
@@ -201,6 +234,33 @@ def __init__(self, *args, **kwargs):
             dtype=torch.int32,
             device=self.device)

+    def build_one_prefill_chunk(self, reqs_start, reqs_end,
+                                query_start_loc_cpu, seq_lens_cpu,
+                                block_table):
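+        # Re-base the chunk's query_start_loc so that offsets are relative
+        # to the first request in the chunk.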
+        prefill_query_start_loc = query_start_loc_cpu[
+            reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start]
+        cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
+            prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end],
+            self.device)
+        token_start = query_start_loc_cpu[reqs_start].item()
+        token_end = query_start_loc_cpu[reqs_end].item()
+        total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
+        assert total_seq_lens <= self.max_prefill_buffer_size
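+        # Cumulative sequence lengths for this chunk only, built on CPU
+        # and then moved to the target device.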
+        cu_seq_lens = torch.cat([
+            torch.zeros(1, dtype=torch.int32),
+            seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0)
+        ]).to(torch.int32).to(self.device)
+        return DeepseekV32IndexerPrefillChunkMetadata(
+            cu_seqlen_ks=cu_seqlen_ks,
+            cu_seqlen_ke=cu_seqlen_ke,
+            cu_seq_lens=cu_seq_lens,
+            total_seq_lens=total_seq_lens,
+            block_table=block_table[reqs_start:reqs_end],
+            token_start=token_start,
+            token_end=token_end,
+            num_reqs=reqs_end - reqs_start,
+        )
+
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
@@ -209,11 +269,7 @@ def build(self,
         num_reqs = common_attn_metadata.num_reqs
         num_tokens = common_attn_metadata.num_actual_tokens

-        device = self.device
-        block_table_tensor = common_attn_metadata.block_table_tensor
-
-        query_start_loc = common_attn_metadata.query_start_loc
-
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
             split_decodes_and_prefills(
                 common_attn_metadata,
@@ -224,27 +280,20 @@ def build(self,

         prefill_metadata = None
         if num_prefills > 0:
-            reqs_start = num_decodes
-            prefill_query_start_loc = query_start_loc[
-                reqs_start:] - query_start_loc[reqs_start]
-            cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
-                prefill_query_start_loc,
-                common_attn_metadata.seq_lens[reqs_start:])
-            total_seq_lens = common_attn_metadata.seq_lens[reqs_start:].sum()
-            assert total_seq_lens < self.max_prefill_buffer_size
-            cu_seq_lens = torch.cat([
-                torch.zeros(1, dtype=torch.int32, device=device),
-                common_attn_metadata.seq_lens[reqs_start:].cumsum(dim=0)
-            ]).to(torch.int32).cuda()
-            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
-                block_table=block_table_tensor[reqs_start:, ...],
-                query_start_loc=prefill_query_start_loc,
-                max_query_len=common_attn_metadata.max_query_len,
-                cu_seqlen_ks=cu_seqlen_ks,
-                cu_seqlen_ke=cu_seqlen_ke,
-                cu_seq_lens=cu_seq_lens,
-                total_seq_lens=total_seq_lens,
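+            # Prefill requests follow the decode requests in the batch,
+            # so chunking starts at index num_decodes.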
+            chunk_seq_ids = split_prefill_chunks(
+                common_attn_metadata.seq_lens_cpu,
+                self.max_prefill_buffer_size,
+                num_decodes,
             )
+            chunks = [
+                self.build_one_prefill_chunk(
+                    reqs_start, reqs_end, query_start_loc_cpu,
+                    common_attn_metadata.seq_lens_cpu,
+                    common_attn_metadata.block_table_tensor)
+                for reqs_start, reqs_end in chunk_seq_ids
+            ]
+            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
+                chunks=chunks, )

         decode_metadata = None
         if num_decodes > 0: