Skip to content

Commit 4b68c4a

Browse files
authored
[Core][Perf] Only invoke save_new_computed_blocks when computed blocks are not empty (#27799)
Signed-off-by: Jialin Ouyang <[email protected]>
1 parent a8141fa commit 4b68c4a

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

vllm/v1/core/kv_cache_manager.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -306,11 +306,12 @@ def allocate_slots(
306306
"Computed blocks should be empty when prefix caching is disabled"
307307
)
308308

309-
# Append the new computed blocks to the request blocks until now to
310-
# avoid the case where the new blocks cannot be allocated.
311-
self.coordinator.save_new_computed_blocks(
312-
request.request_id, new_computed_block_list
313-
)
309+
if new_computed_block_list is not self.empty_kv_cache_blocks.blocks:
310+
# Append the new computed blocks to the request blocks until now to
311+
# avoid the case where the new blocks cannot be allocated.
312+
self.coordinator.save_new_computed_blocks(
313+
request.request_id, new_computed_block_list
314+
)
314315

315316
new_blocks = self.coordinator.allocate_new_blocks(
316317
request.request_id, num_tokens_need_slot, num_encoder_tokens

vllm/v1/core/single_type_kv_cache_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def cache_blocks(self, request: Request, num_tokens: int) -> None:
151151
num_tokens: The total number of tokens that need to be cached
152152
(including tokens that are already cached).
153153
"""
154-
num_cached_blocks = self.num_cached_block[request.request_id]
154+
num_cached_blocks = self.num_cached_block.get(request.request_id, 0)
155155
num_full_blocks = num_tokens // self.block_size
156156

157157
if num_cached_blocks >= num_full_blocks:

0 commit comments

Comments (0)