Skip to content

Commit fd1c710

Browse files
fix(v1/kv_cache): resolve async KV transfer bug in cascade attention
* Replace ref_cnt-based common prefix detection with running request tracking
* Update get_num_common_prefix_blocks() to accept running_request_ids set
* Fix FullAttentionManager to count actual references from running requests
* Prevent incorrect cascade attention when async KV offloading delays cleanup

This resolves a bug where completed requests with pending async transfers still contributed to ref_cnt, causing incorrect cascade attention decisions.

Signed-off-by: Ayush Satyam <[email protected]>
1 parent d3d649e commit fd1c710

File tree

4 files changed

+92
-75
lines changed

4 files changed

+92
-75
lines changed

vllm/v1/core/kv_cache_coordinator.py

Lines changed: 15 additions & 12 deletions
Original file line number · Diff line number · Diff line change
@@ -138,27 +138,29 @@ def free(self, request_id: str) -> None:
138138
for manager in self.single_type_managers:
139139
manager.free(request_id)
140140

141-
def get_num_common_prefix_blocks(self, request_id: str,
142-
num_running_requests: int) -> list[int]:
141+
def get_num_common_prefix_blocks(
142+
self, running_request_id: str, running_request_ids: list[str],
143+
transfering_request_ids: list[str]) -> list[int]:
143144
"""
144145
Get the number of common prefix blocks for all requests in the RUNNING
145-
state for each kv cache group.
146+
and TRANSFERING state for each kv cache group.
146147
147148
Args:
148-
request_id: The request ID.
149-
num_running_requests: The total number of requests in the RUNNING
150-
state.
149+
running_request_id: The request ID of the running request.
150+
running_request_ids: List of all request IDs in the RUNNING state.
151+
transfering_request_ids: List of request IDs in
152+
WAITING_FOR_REMOTE_KVS state.
151153
152154
Returns:
153155
list[int]: The number of common prefix blocks for all requests in
154156
the RUNNING state for each kv cache group.
155157
"""
156-
num_blocks_per_group = [
157-
manager.get_num_common_prefix_blocks(request_id,
158-
num_running_requests)
158+
return [
159+
manager.get_num_common_prefix_blocks(running_request_id,
160+
running_request_ids,
161+
transfering_request_ids)
159162
for manager in self.single_type_managers
160163
]
161-
return num_blocks_per_group
162164

163165
def remove_skipped_blocks(self, request_id: str,
164166
num_computed_tokens: int) -> None:
@@ -209,8 +211,9 @@ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
209211
dcp_world_size=dcp_world_size)
210212
self.num_single_type_manager = len(self.single_type_managers)
211213

212-
def get_num_common_prefix_blocks(self, request_id: str,
213-
num_running_requests: int) -> list[int]:
214+
def get_num_common_prefix_blocks(
215+
self, running_request_id: str, running_request_ids: list[str],
216+
transfering_request_ids: list[str]) -> list[int]:
214217
return [0] * self.num_single_type_manager
215218

216219
def find_longest_cache_hit(

vllm/v1/core/kv_cache_manager.py

Lines changed: 15 additions & 31 deletions
Original file line number · Diff line number · Diff line change
@@ -10,7 +10,7 @@
1010
from vllm.v1.core.kv_cache_utils import KVCacheBlock
1111
from vllm.v1.kv_cache_interface import KVCacheConfig
1212
from vllm.v1.metrics.stats import PrefixCacheStats
13-
from vllm.v1.request import Request, RequestStatus
13+
from vllm.v1.request import Request
1414

1515
logger = init_logger(__name__)
1616

@@ -339,46 +339,30 @@ def reset_prefix_cache(self) -> bool:
339339

340340
def get_num_common_prefix_blocks(
341341
self,
342-
request: Request,
343-
num_running_requests: int,
342+
running_request_id: str,
343+
running_request_ids: list[str],
344+
transfering_request_ids: list[str],
344345
) -> list[int]:
345346
"""Calculate the number of common prefix blocks shared by all requests
346-
in the RUNNING state for each kv cache group.
347-
348-
The function determines this by selecting any request and iterating
349-
through its blocks. A block is considered a common prefix block if its
350-
`ref_cnt` equals the total number of requests in the RUNNING state.
351-
352-
NOTE(woosuk): The number of requests in the RUNNING state is **greater
353-
than or equal to** the number of requests scheduled in the current step.
354-
This is because the RUNNING state only indicates that:
355-
1. The request has not yet finished, and
356-
2. The request holds its blocks unfreed.
357-
358-
While all scheduled requests must be in the RUNNING state, the inverse
359-
is not necessarily true. There may be RUNNING requests that are not
360-
scheduled in the current step.
347+
in the RUNNING state for each kv cache group. A block is considered a
348+
common prefix block if it is referenced by ALL currently running
349+
requests.
361350
362-
This can result in an edge case where the number of common prefix blocks
363-
is 0, even though all scheduled requests share a common prefix. This
364-
occurs because there may be unscheduled RUNNING requests that do not
365-
share the common prefix. Currently, this case cannot be easily detected,
366-
so the function returns 0 in such cases.
351+
This approach correctly handles async KV offloading scenarios where
352+
completed requests may still hold block references while no longer
353+
being in the RUNNING state.
367354
368355
Args:
369-
request: Any request in the RUNNING state, used to identify the
370-
common prefix blocks.
371-
num_running_requests: The total number of requests in the RUNNING
372-
state. This can be different from the number of scheduled
373-
requests in the current step.
356+
running_request_id: The request ID of the running request.
357+
running_request_ids: List of all request IDs in the RUNNING state.
358+
transfering_request_ids: List of request IDs in transfer state.
374359
375360
Returns:
376361
list[int]: The number of common prefix blocks for each kv cache
377362
group.
378363
"""
379-
assert request.status == RequestStatus.RUNNING
380364
return self.coordinator.get_num_common_prefix_blocks(
381-
request.request_id, num_running_requests)
365+
running_request_id, running_request_ids, transfering_request_ids)
382366

383367
def take_events(self) -> list[KVCacheEvent]:
384368
"""Take the KV cache events from the block pool.
@@ -404,4 +388,4 @@ def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
404388
def create_empty_block_list(self) -> KVCacheBlocks:
405389
"""Creates a new KVCacheBlocks instance with no blocks."""
406390
return KVCacheBlocks(tuple([]
407-
for _ in range(self.num_kv_cache_groups)))
391+
for _ in range(self.num_kv_cache_groups)))

vllm/v1/core/sched/scheduler.py

Lines changed: 8 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -560,9 +560,16 @@ def schedule(self) -> SchedulerOutput:
560560
self.kv_cache_config.kv_cache_groups)
561561
if self.running:
562562
any_request = self.running[0]
563+
running_request_ids = {req.request_id for req in self.running}
564+
565+
transferring_request_ids = [
566+
req_id for req_id, request in self.requests.items()
567+
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
568+
]
563569
num_common_prefix_blocks = (
564570
self.kv_cache_manager.get_num_common_prefix_blocks(
565-
any_request, len(self.running)))
571+
any_request.request_id, list(running_request_ids),
572+
transferring_request_ids))
566573

567574
# Construct the scheduler output.
568575
new_reqs_data = [

vllm/v1/core/single_type_kv_cache_manager.py

Lines changed: 54 additions & 31 deletions
Original file line number · Diff line number · Diff line change
@@ -174,22 +174,10 @@ def free(self, request_id: str) -> None:
174174
self.num_cached_block.pop(request_id, None)
175175

176176
@abstractmethod
177-
def get_num_common_prefix_blocks(self, request_id: str,
178-
num_running_requests: int) -> int:
179-
"""
180-
Get the number of common prefix blocks for all requests in the RUNNING
181-
state.
182-
183-
Args:
184-
request_id: The request ID.
185-
num_running_requests: The total number of requests in the RUNNING
186-
state.
187-
188-
Returns:
189-
The number of common prefix blocks for all requests in the RUNNING
190-
state.
191-
"""
192-
177+
def get_num_common_prefix_blocks(
178+
self, running_request_id: str, running_request_ids: list[str],
179+
transfering_request_ids: list[str]) -> int:
180+
"""Get the number of common prefix blocks for all running requests."""
193181
raise NotImplementedError
194182

195183
@classmethod
@@ -292,15 +280,37 @@ def remove_skipped_blocks(self, request_id: str,
292280
# No need to remove blocks for full attention.
293281
pass
294282

295-
def get_num_common_prefix_blocks(self, request_id: str,
296-
num_running_requests: int) -> int:
297-
blocks = self.req_to_blocks[request_id]
283+
def get_num_common_prefix_blocks(
284+
self, running_request_id: str, running_request_ids: list[str],
285+
transfering_request_ids: list[str]) -> int:
286+
"""Get common prefix blocks shared by all running and transferring
287+
requests."""
288+
if running_request_id not in self.req_to_blocks:
289+
return 0
290+
291+
reference_blocks = self.req_to_blocks[running_request_id]
292+
total_requests = len(running_request_ids) + len(
293+
transfering_request_ids)
294+
295+
transferring_blocks = [
296+
self.req_to_blocks[req_id] for req_id in transfering_request_ids
297+
if req_id in self.req_to_blocks
298+
]
299+
298300
num_common_blocks = 0
299-
for block in blocks:
300-
if block.ref_cnt == num_running_requests:
301-
num_common_blocks += 1
302-
else:
301+
for i, ref_block in enumerate(reference_blocks):
302+
303+
if ref_block.ref_cnt < total_requests:
304+
break
305+
306+
transferring_has_block = sum(
307+
1 for blocks in transferring_blocks if i < len(blocks)
308+
and blocks[i].block_id == ref_block.block_id)
309+
310+
if transferring_has_block != len(transfering_request_ids):
303311
break
312+
num_common_blocks += 1
313+
304314
return num_common_blocks
305315

306316

@@ -393,8 +403,12 @@ def remove_skipped_blocks(self, request_id: str,
393403
blocks[i] = self._null_block
394404
self.block_pool.free_blocks(removed_blocks)
395405

396-
def get_num_common_prefix_blocks(self, request_id: str,
397-
num_running_requests: int) -> int:
406+
def get_num_common_prefix_blocks(
407+
self,
408+
running_request_id: str,
409+
running_request_ids: list[str],
410+
transfering_request_ids: list[str],
411+
) -> int:
398412
"""
399413
NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
400414
So it's not correct to count ref_cnt like FullAttentionManager. Return
@@ -521,8 +535,12 @@ def remove_skipped_blocks(self, request_id: str,
521535
blocks[i] = self._null_block
522536
self.block_pool.free_blocks(removed_blocks)
523537

524-
def get_num_common_prefix_blocks(self, request_id: str,
525-
num_running_requests: int) -> int:
538+
def get_num_common_prefix_blocks(
539+
self,
540+
running_request_id: str,
541+
running_request_ids: list[str],
542+
transfering_request_ids: list[str],
543+
) -> int:
526544
"""
527545
cascade attention is not supported by chunked local attention.
528546
"""
@@ -573,8 +591,12 @@ def remove_skipped_blocks(self, request_id: str,
573591
# (for which find_longest_cache_hit returns block_pool.null_block)
574592
pass
575593

576-
def get_num_common_prefix_blocks(self, request_id: str,
577-
num_running_requests: int) -> int:
594+
def get_num_common_prefix_blocks(
595+
self,
596+
running_request_id: str,
597+
running_request_ids: list[str],
598+
transfering_request_ids: list[str],
599+
) -> int:
578600
"""
579601
cascade attention is not supported by mamba
580602
"""
@@ -618,8 +640,9 @@ def cache_blocks(self, request: Request, num_tokens: int) -> None:
618640
# requests, so this method is not relevant.
619641
raise ValueError("Should not be called as prefix caching is disabled.")
620642

621-
def get_num_common_prefix_blocks(self, request_id: str,
622-
num_running_requests: int) -> int:
643+
def get_num_common_prefix_blocks(
644+
self, running_request_id: str, running_request_ids: list[str],
645+
transfering_request_ids: list[str]) -> int:
623646
# Cross-attention blocks contain request-specific encoder states
624647
# and are not shared between different requests
625648
return 0

0 commit comments

Comments
 (0)