@@ -174,22 +174,10 @@ def free(self, request_id: str) -> None:
174174 self .num_cached_block .pop (request_id , None )
175175
176176 @abstractmethod
177- def get_num_common_prefix_blocks (self , request_id : str ,
178- num_running_requests : int ) -> int :
179- """
180- Get the number of common prefix blocks for all requests in the RUNNING
181- state.
182-
183- Args:
184- request_id: The request ID.
185- num_running_requests: The total number of requests in the RUNNING
186- state.
187-
188- Returns:
189- The number of common prefix blocks for all requests in the RUNNING
190- state.
191- """
192-
177+ def get_num_common_prefix_blocks (
178+ self , running_request_id : str , running_request_ids : list [str ],
179+ transfering_request_ids : list [str ]) -> int :
180+ """Get the number of common prefix blocks for all running requests."""
193181 raise NotImplementedError
194182
195183 @classmethod
@@ -292,15 +280,37 @@ def remove_skipped_blocks(self, request_id: str,
292280 # No need to remove blocks for full attention.
293281 pass
294282
295- def get_num_common_prefix_blocks (self , request_id : str ,
296- num_running_requests : int ) -> int :
297- blocks = self .req_to_blocks [request_id ]
283+ def get_num_common_prefix_blocks (
284+ self , running_request_id : str , running_request_ids : list [str ],
285+ transfering_request_ids : list [str ]) -> int :
286+ """Get common prefix blocks shared by all running and transferring
287+ requests."""
288+ if running_request_id not in self .req_to_blocks :
289+ return 0
290+
291+ reference_blocks = self .req_to_blocks [running_request_id ]
292+ total_requests = len (running_request_ids ) + len (
293+ transfering_request_ids )
294+
295+ transferring_blocks = [
296+ self .req_to_blocks [req_id ] for req_id in transfering_request_ids
297+ if req_id in self .req_to_blocks
298+ ]
299+
298300 num_common_blocks = 0
299- for block in blocks :
300- if block .ref_cnt == num_running_requests :
301- num_common_blocks += 1
302- else :
301+ for i , ref_block in enumerate (reference_blocks ):
302+
303+ if ref_block .ref_cnt < total_requests :
304+ break
305+
306+ transferring_has_block = sum (
307+ 1 for blocks in transferring_blocks if i < len (blocks )
308+ and blocks [i ].block_id == ref_block .block_id )
309+
310+ if transferring_has_block != len (transfering_request_ids ):
303311 break
312+ num_common_blocks += 1
313+
304314 return num_common_blocks
305315
306316
@@ -393,8 +403,12 @@ def remove_skipped_blocks(self, request_id: str,
393403 blocks [i ] = self ._null_block
394404 self .block_pool .free_blocks (removed_blocks )
395405
396- def get_num_common_prefix_blocks (self , request_id : str ,
397- num_running_requests : int ) -> int :
406+ def get_num_common_prefix_blocks (
407+ self ,
408+ running_request_id : str ,
409+ running_request_ids : list [str ],
410+ transfering_request_ids : list [str ],
411+ ) -> int :
398412 """
399413 NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
400414 So it's not correct to count ref_cnt like FullAttentionManager. Return
@@ -521,8 +535,12 @@ def remove_skipped_blocks(self, request_id: str,
521535 blocks [i ] = self ._null_block
522536 self .block_pool .free_blocks (removed_blocks )
523537
524- def get_num_common_prefix_blocks (self , request_id : str ,
525- num_running_requests : int ) -> int :
538+ def get_num_common_prefix_blocks (
539+ self ,
540+ running_request_id : str ,
541+ running_request_ids : list [str ],
542+ transfering_request_ids : list [str ],
543+ ) -> int :
526544 """
527545 cascade attention is not supported by chunked local attention.
528546 """
@@ -573,8 +591,12 @@ def remove_skipped_blocks(self, request_id: str,
573591 # (for which find_longest_cache_hit returns block_pool.null_block)
574592 pass
575593
576- def get_num_common_prefix_blocks (self , request_id : str ,
577- num_running_requests : int ) -> int :
594+ def get_num_common_prefix_blocks (
595+ self ,
596+ running_request_id : str ,
597+ running_request_ids : list [str ],
598+ transfering_request_ids : list [str ],
599+ ) -> int :
578600 """
579601 cascade attention is not supported by mamba
580602 """
@@ -618,8 +640,9 @@ def cache_blocks(self, request: Request, num_tokens: int) -> None:
618640 # requests, so this method is not relevant.
619641 raise ValueError ("Should not be called as prefix caching is disabled." )
620642
621- def get_num_common_prefix_blocks (self , request_id : str ,
622- num_running_requests : int ) -> int :
643+ def get_num_common_prefix_blocks (
644+ self , running_request_id : str , running_request_ids : list [str ],
645+ transfering_request_ids : list [str ]) -> int :
623646 # Cross-attention blocks contain request-specific encoder states
624647 # and are not shared between different requests
625648 return 0
0 commit comments