From c6a2d255afb1cab2c419692c3d4e65593736cab2 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 1 Apr 2025 05:33:08 -0700 Subject: [PATCH 01/34] copy manager code Signed-off-by: Chen Zhang --- vllm/v1/core/hybrid_kv_cache_manager.py | 376 ++++++++++++++++++++++++ 1 file changed, 376 insertions(+) create mode 100644 vllm/v1/core/hybrid_kv_cache_manager.py diff --git a/vllm/v1/core/hybrid_kv_cache_manager.py b/vllm/v1/core/hybrid_kv_cache_manager.py new file mode 100644 index 000000000000..c0f7715209d1 --- /dev/null +++ b/vllm/v1/core/hybrid_kv_cache_manager.py @@ -0,0 +1,376 @@ +# SPDX-License-Identifier: Apache-2.0 + +from collections import defaultdict +from collections.abc import Iterable +from typing import Optional + +from vllm.logger import init_logger +from vllm.utils import cdiv, sha256 +from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, + hash_request_tokens) +from vllm.v1.core.specialized_manager import get_specialized_manager +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.metrics.stats import PrefixCacheStats +from vllm.v1.request import Request, RequestStatus + +logger = init_logger(__name__) + + +class KVCacheManager: + + def __init__( + self, + kv_cache_config: KVCacheConfig, + max_model_len: int, + enable_caching: bool = True, + caching_hash_algo: str = "builtin", + num_preallocate_tokens: int = 64, + log_stats: bool = False, + ) -> None: + assert len(kv_cache_config.kv_cache_groups) == 1, ( + "KVCacheManager does not support hybrid models with more than 1 " + "kv cache group") + kv_cache_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec + self.block_size = kv_cache_spec.block_size + self.num_gpu_blocks = kv_cache_config.num_blocks + self.max_model_len = max_model_len + self.max_num_blocks_per_req = cdiv(max_model_len, self.block_size) + + self.enable_caching = enable_caching + self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash + # FIXME: make prefix cache stats conditional on log_stats + self.log_stats = log_stats + # NOTE(woosuk): To avoid frequent block allocation, we preallocate some + # blocks for each request. For example, when a request reaches the end + # of its block table, we preallocate N blocks in advance. This way, we + # reduce the overhead of updating free_block_ids and ref_cnts for each + # request every step (at the cost of some memory waste). + # NOTE(woosuk): This is different from the "lookahead" slots since this + # does not guarantee that the request always has N empty blocks. After + # the request gets N empty blocks, it starts to use the blocks without + # further allocation. When it uses up all the N empty blocks, it gets + # N new empty blocks. + self.num_preallocate_tokens = num_preallocate_tokens + self.num_preallocate_blocks = cdiv(num_preallocate_tokens, + self.block_size) + + self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching) + + self.specialized_manager = get_specialized_manager( + kv_cache_spec=kv_cache_spec, + block_pool=self.block_pool, + ) + + # Mapping from request ID to blocks to track the blocks allocated + # for each request, so that we can free the blocks when the request + # is finished. + self.req_to_blocks: defaultdict[str, + list[KVCacheBlock]] = defaultdict(list) + + # Mapping from request ID to kv block hashes. + # This is to avoid recomputing the block hashes for each call of + # `get_computed_blocks` or `allocate_slots`. 
+ self.req_to_block_hashes: defaultdict[ + str, list[BlockHashType]] = defaultdict(list) + + # {req_id: The number of cached blocks for this given request} + # This is used to track the number of cached blocks for each request. + # This is only used to track the RUNNING requests, we do not track the + # data for reempted ones. + self.num_cached_block: dict[str, int] = {} + self.prefix_cache_stats = PrefixCacheStats() + + @property + def usage(self) -> float: + """Get the KV cache usage. + + Returns: + The KV cache usage (between 0.0 and 1.0). + """ + return self.block_pool.get_usage() + + def make_prefix_cache_stats(self) -> PrefixCacheStats: + """Get (and reset) the prefix cache stats. + + Returns: + The current prefix caching stats. + """ + stats = self.prefix_cache_stats + self.prefix_cache_stats = PrefixCacheStats() + return stats + + def get_computed_blocks( + self, request: Request) -> tuple[list[KVCacheBlock], int]: + """Get the computed (cached) blocks for the request. + Note that the computed blocks must be full. + + Args: + request: The request to get the computed blocks. + + Returns: + A tuple containing: + - A list of blocks that are computed for the request. + - The number of computed tokens. + """ + if not self.enable_caching: + # Prefix caching is disabled. + return [], 0 + + # The block hashes for the request may already be computed + # if the scheduler has tried to schedule the request before. + block_hashes = self.req_to_block_hashes[request.request_id] + if not block_hashes: + block_hashes = hash_request_tokens(self.caching_hash_fn, + self.block_size, request) + self.req_to_block_hashes[request.request_id] = block_hashes + + self.prefix_cache_stats.requests += 1 + if request.sampling_params.prompt_logprobs is None: + if len(block_hashes) * self.block_size == request.num_tokens: + # When prompt length is divisible by the block size and all + # blocks are cached, we need to recompute the last token. This + # have to be achieved by re-computing an entire block because + # allocate_slots() assumes num_computed_tokens is always a + # multiple of the block size. To achieve this, remove the last + # block hash from the block_hashes for find_longest_cache_hit + # This limitation can potentially be removed in the future to + # slightly improve the performance. + last_block_hash = block_hashes.pop() + else: + last_block_hash = None + + computed_blocks = ( + self.specialized_manager.find_longest_cache_hit(block_hashes)) + + if last_block_hash is not None: + # Add back the last block hash if it was removed. + block_hashes.append(last_block_hash) + + self.prefix_cache_stats.queries += len(block_hashes) + self.prefix_cache_stats.hits += len(computed_blocks) + + # NOTE(woosuk): Since incomplete blocks are not eligible for + # sharing, `num_computed_tokens` is always a multiple of + # `block_size`. + num_computed_tokens = len(computed_blocks) * self.block_size + return computed_blocks, num_computed_tokens + else: + # Skip cache hits for prompt logprobs + return [], 0 + + def allocate_slots( + self, + request: Request, + num_tokens: int, + new_computed_blocks: Optional[list[KVCacheBlock]] = None + ) -> Optional[list[KVCacheBlock]]: + """Add slots for a request with new tokens to append. + + Args: + request: The request to allocate slots. + num_tokens: The number of tokens to allocate. Note that this does + not include the tokens that have already been computed. + new_computed_blocks: A list of new computed blocks just hitting the + prefix caching. 
+ + Blocks layout: + ----------------------------------------------------------------------- + | < computed > | < new computed > | < new > | < pre-allocated > | + ----------------------------------------------------------------------- + | < required > | + -------------------------------------------------- + | < full > | + ------------------------------------------------ + | | + -------------- + The following *_blocks are illustrated in this layout. + + Returns: + A list of new allocated blocks. + """ + if num_tokens == 0: + raise ValueError("num_tokens must be greater than 0") + + new_computed_blocks = new_computed_blocks or [] + + req_blocks = self.req_to_blocks[request.request_id] + + # Free the blocks that are skipped during the attention computation + # (e.g., tokens outside the sliding window). + # We can do this even if we cannot schedule this request due to + # insufficient free blocks. + # Should call this function before allocating new blocks to reduce + # the number of evicted blocks. + removed_blocks = self.specialized_manager.remove_skipped_blocks( + req_blocks, request.num_computed_tokens) + self.block_pool.free_blocks(removed_blocks) + + # The number of computed tokens is the number of computed tokens plus + # the new prefix caching hits + num_computed_tokens = (request.num_computed_tokens + + len(new_computed_blocks) * self.block_size) + num_required_blocks = cdiv(num_computed_tokens + num_tokens, + self.block_size) + num_new_blocks = (num_required_blocks - len(req_blocks) - + len(new_computed_blocks)) + + # If a computed block of a request is an eviction candidate (in the + # free queue and ref_cnt == 0), it cannot be counted as a free block + # when allocating this request. + num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks + if blk.ref_cnt == 0) + if (num_new_blocks > self.block_pool.get_num_free_blocks() - + num_evictable_computed_blocks): + # Cannot allocate new blocks + return None + + # Touch the computed blocks to make sure they won't be evicted. + if self.enable_caching: + self.block_pool.touch(new_computed_blocks) + else: + assert not new_computed_blocks, ( + "Computed blocks should be empty when " + "prefix caching is disabled") + + # Append the new computed blocks to the request blocks until now to + # avoid the case where the new blocks cannot be allocated. + req_blocks.extend(new_computed_blocks) + + # Start to handle new blocks + + if num_new_blocks <= 0: + # No new block is needed. + new_blocks = [] + else: + # Get new blocks from the free block pool considering + # preallocated blocks. + num_new_blocks = min( + num_new_blocks + self.num_preallocate_blocks, + self.block_pool.get_num_free_blocks(), + # Should not exceed the maximum number of blocks per request. + # This is especially because the block table has the shape + # [..., max_num_blocks_per_req]. + self.max_num_blocks_per_req - len(req_blocks), + ) + assert num_new_blocks > 0 + + # Concatenate the computed block IDs and the new block IDs. + new_blocks = self.block_pool.get_new_blocks(num_new_blocks) + req_blocks.extend(new_blocks) + + if not self.enable_caching: + return new_blocks + + # Use `new_computed_blocks` for a new request, and `num_cached_block` + # for a running request. + num_cached_blocks = self.num_cached_block.get(request.request_id, + len(new_computed_blocks)) + # Speculated tokens might be rejected in the future, so we does + # not cache any speculated tokens. We only cache blocks with + # generated (accepted) tokens. 
+ num_full_blocks_after_append = (num_computed_tokens + num_tokens - len( + request.spec_token_ids)) // self.block_size + + self.block_pool.cache_full_blocks( + request=request, + blocks=req_blocks, + block_hashes=self.req_to_block_hashes[request.request_id], + num_cached_blocks=num_cached_blocks, + num_full_blocks=num_full_blocks_after_append, + block_size=self.block_size, + hash_fn=self.caching_hash_fn, + ) + + self.num_cached_block[ + request.request_id] = num_full_blocks_after_append + return new_blocks + + def free(self, request: Request) -> None: + """Free the blocks allocated for the request. + When caching is enabled, we free the blocks in reverse order so that + the tail blocks are evicted first. + + Args: + request: The request to free the blocks. + """ + # Default to [] in case a request is freed (aborted) before alloc. + blocks = self.req_to_blocks.pop(request.request_id, []) + ordered_blocks: Iterable[KVCacheBlock] = blocks + if self.enable_caching: + # Free blocks in reverse order so that the tail blocks are + # freed first. + ordered_blocks = reversed(blocks) + + self.block_pool.free_blocks(ordered_blocks) + self.num_cached_block.pop(request.request_id, None) + + def reset_prefix_cache(self) -> bool: + """Reset prefix cache. This function may be used in RLHF + flows to invalid prefix caching after the weights are updated, + or used for resetting prefix caching status for benchmarking. + + Returns: + bool: True if the prefix cache is successfully reset, + False otherwise. + """ + if self.block_pool.reset_prefix_cache(): + self.prefix_cache_stats.reset = True + return True + return False + + def get_num_common_prefix_blocks( + self, + request: Request, + num_running_requests: int, + ) -> int: + """Calculate the number of common prefix blocks shared by all requests + in the RUNNING state. + + The function determines this by selecting any request and iterating + through its blocks. A block is considered a common prefix block if its + `ref_cnt` equals the total number of requests in the RUNNING state. + + NOTE(woosuk): The number of requests in the RUNNING state is **greater + than or equal to** the number of requests scheduled in the current step. + This is because the RUNNING state only indicates that: + 1. The request has not yet finished, and + 2. The request holds its blocks unfreed. + + While all scheduled requests must be in the RUNNING state, the inverse + is not necessarily true. There may be RUNNING requests that are not + scheduled in the current step. + + This can result in an edge case where the number of common prefix blocks + is 0, even though all scheduled requests share a common prefix. This + occurs because there may be unscheduled RUNNING requests that do not + share the common prefix. Currently, this case cannot be easily detected, + so the function returns 0 in such cases. + + Args: + request: Any request in the RUNNING state, used to identify the + common prefix blocks. + num_running_requests: The total number of requests in the RUNNING + state. This can be different from the number of scheduled + requests in the current step. + + Returns: + int: The number of common prefix blocks. + """ + assert request.status == RequestStatus.RUNNING + blocks = self.req_to_blocks[request.request_id] + num_common_blocks = 0 + for block in blocks: + if block.ref_cnt == num_running_requests: + num_common_blocks += 1 + else: + break + return num_common_blocks + + def free_block_hashes(self, request: Request) -> None: + """Discard the block hashes for the request. 
+ + NOTE: Unlike `free`, this method should be called only when the request + is finished, not when it is preempted. + """ + self.req_to_block_hashes.pop(request.request_id, None) From 4b27c8214cc382216b1d2fdd580decca0f05d11a Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 5 Apr 2025 00:39:58 -0700 Subject: [PATCH 02/34] save Signed-off-by: Chen Zhang --- vllm/attention/layer.py | 8 + vllm/config.py | 4 + vllm/engine/arg_utils.py | 8 + vllm/forward_context.py | 11 +- vllm/v1/attention/backends/flash_attn.py | 8 +- vllm/v1/core/block_pool.py | 6 +- vllm/v1/core/hybrid_kv_cache_manager.py | 406 ++++++++++++++--------- vllm/v1/core/kv_cache_manager.py | 38 ++- vllm/v1/core/kv_cache_utils.py | 125 ++++++- vllm/v1/core/sched/output.py | 10 +- vllm/v1/core/sched/scheduler.py | 35 +- vllm/v1/core/specialized_manager.py | 100 ++++-- vllm/v1/kv_cache_interface.py | 91 ++++- vllm/v1/worker/block_table.py | 83 +++++ vllm/v1/worker/gpu_input_batch.py | 13 +- vllm/v1/worker/gpu_model_runner.py | 211 +++++++----- 16 files changed, 842 insertions(+), 315 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index dbf4723ee1bd..1043654ef978 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -206,6 +206,8 @@ def forward( if self.use_direct_call: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata + # if isinstance(attn_metadata, dict): + # attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] self.impl.forward(self, query, @@ -222,6 +224,8 @@ def forward( if self.use_direct_call: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] return self.impl.forward(self, query, key, value, self_kv_cache, attn_metadata) @@ -337,6 +341,8 @@ def unified_attention( ) -> torch.Tensor: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] return self.impl.forward(self, query, key, value, kv_cache, attn_metadata) @@ -369,6 +375,8 @@ def unified_attention_with_output( ) -> None: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] self.impl.forward(self, diff --git a/vllm/config.py b/vllm/config.py index c82c9763ccdc..038364fbeae7 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1119,6 +1119,8 @@ class CacheConfig: sliding_window: Sliding window size for the KV cache. enable_prefix_caching: Whether to enable prefix caching. cpu_offload_gb: Size of the CPU offload buffer in GiB. + disable_hybrid_allocator: Whether to disable the hybrid allocator (Only + affects v1). 
""" def compute_hash(self) -> str: @@ -1153,6 +1155,7 @@ def __init__( prefix_caching_hash_algo: str = "builtin", cpu_offload_gb: float = 0, calculate_kv_scales: Optional[bool] = None, + disable_hybrid_allocator: bool = False, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization @@ -1165,6 +1168,7 @@ def __init__( self.prefix_caching_hash_algo = prefix_caching_hash_algo self.cpu_offload_gb = cpu_offload_gb self.calculate_kv_scales = calculate_kv_scales + self.disable_hybrid_allocator = disable_hybrid_allocator self._verify_args() self._verify_cache_dtype() self._verify_prefix_caching() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ecdcab50e452..5dd116c4cab5 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -205,6 +205,7 @@ class EngineArgs: model_impl: str = "auto" calculate_kv_scales: Optional[bool] = None + disable_hybrid_allocator: bool = False additional_config: Optional[Dict[str, Any]] = None enable_reasoning: Optional[bool] = None @@ -964,6 +965,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'be loaded from the model checkpoint if available. ' 'Otherwise, the scales will default to 1.0.') + parser.add_argument( + "--disable-hybrid-allocator", + action="store_true", + default=False, + help="Disable the hybrid allocator. This only affects v1.") + parser.add_argument( "--additional-config", type=json.loads, @@ -1173,6 +1180,7 @@ def create_engine_config( prefix_caching_hash_algo=self.prefix_caching_hash_algo, cpu_offload_gb=self.cpu_offload_gb, calculate_kv_scales=self.calculate_kv_scales, + disable_hybrid_allocator=self.disable_hybrid_allocator, ) # Get the current placement group if Ray is initialized and diff --git a/vllm/forward_context.py b/vllm/forward_context.py index e195a03c5cac..87489ac4ca32 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -4,7 +4,7 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union import torch import torch.distributed as dist @@ -34,8 +34,13 @@ class DPMetadata: class ForwardContext: # copy from vllm_config.compilation_config.static_forward_context no_compile_layers: dict[str, Any] - # TODO: extend to support per-layer dynamic forward context - attn_metadata: "AttentionMetadata" # set dynamically for each forward pass + """ + Type AttentionMetadata for v0, + Type Dict[str, AttentionMetadata] for v1, mapping from layer_name to + AttentionMetadata of that layer + set dynamically for each forward pass + """ + attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]] # TODO: remove after making all virtual_engines share the same kv cache virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 92e4ffd0371a..47af053a6837 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -11,6 +11,7 @@ AttentionMetadata, AttentionType, is_quantized_kv_cache) from vllm.attention.ops.triton_merge_attn_states import merge_attn_states +from vllm.core.block.block_table import BlockTable from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv @@ -107,14 +108,13 @@ def reorder_batch(self, input_batch: "InputBatch", return 
False def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int): + common_prefix_len: int, block_table: BlockTable): max_seq_len = self.runner.seq_lens_np[:num_reqs].max() query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( self.runner.device, non_blocking=True) seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device, non_blocking=True) - block_table = ( - self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) + block_table_tensor = block_table.get_device_tensor()[:num_reqs] slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( self.runner.device, non_blocking=True).long() @@ -142,7 +142,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, query_start_loc=query_start_loc, max_seq_len=max_seq_len, seq_lens=seq_lens, - block_table=block_table, + block_table=block_table_tensor, slot_mapping=slot_mapping, use_cascade=use_cascade, common_prefix_len=common_prefix_len, diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 43f30f7103c7..3707c00b5fbe 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -82,6 +82,7 @@ def cache_full_blocks( num_full_blocks: int, block_size: int, hash_fn: Callable, + kv_cache_group_id: int = -1, ) -> None: """Cache a list of full blocks for prefix caching. This function takes a list of blocks that will have their block hash @@ -101,6 +102,8 @@ def cache_full_blocks( be cached after this function. block_size: Number of tokens in each block. hash_fn: The hash function to use for block hashes. + kv_cache_group_id: The id of the kv cache group. -1 means no kv + cache group. """ if num_cached_blocks == num_full_blocks: return @@ -143,7 +146,8 @@ def cache_full_blocks( # we reach to this branch only when the block is completed with # generated tokens, we only need to consider the last mm input. extra_keys, _ = generate_block_hash_extra_keys( - request, start_token_idx, end_token_idx, -1) + request, start_token_idx, end_token_idx, -1, + kv_cache_group_id) # Compute the hash of the current block. block_hash = hash_block_tokens(hash_fn, prev_block_hash_value, diff --git a/vllm/v1/core/hybrid_kv_cache_manager.py b/vllm/v1/core/hybrid_kv_cache_manager.py index c0f7715209d1..9dbf1bfd3606 100644 --- a/vllm/v1/core/hybrid_kv_cache_manager.py +++ b/vllm/v1/core/hybrid_kv_cache_manager.py @@ -7,6 +7,7 @@ from vllm.logger import init_logger from vllm.utils import cdiv, sha256 from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, hash_request_tokens) from vllm.v1.core.specialized_manager import get_specialized_manager @@ -17,7 +18,12 @@ logger = init_logger(__name__) -class KVCacheManager: +class HybridKVCacheManager: + """ + The HybridKVCacheManager for models with multiple KV cache types + (e.g., Gemma-2) and thus multiple kv cache groups (Refer to class + `KVCacheConfig` for the meaning of kv cache groups). + """ def __init__( self, @@ -28,15 +34,15 @@ def __init__( num_preallocate_tokens: int = 64, log_stats: bool = False, ) -> None: - assert len(kv_cache_config.kv_cache_groups) == 1, ( - "KVCacheManager does not support hybrid models with more than 1 " - "kv cache group") - kv_cache_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec - self.block_size = kv_cache_spec.block_size + # TODO: adjust the name for item in one group, list of items in all + # groups, and reduced item for all groups. 
+ self.kv_cache_config = kv_cache_config self.num_gpu_blocks = kv_cache_config.num_blocks self.max_model_len = max_model_len - self.max_num_blocks_per_req = cdiv(max_model_len, self.block_size) - + self.max_num_blocks_per_req = [ + cdiv(max_model_len, g.kv_cache_spec.block_size) + for g in kv_cache_config.kv_cache_groups + ] self.enable_caching = enable_caching self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash # FIXME: make prefix cache stats conditional on log_stats @@ -51,57 +57,52 @@ def __init__( # the request gets N empty blocks, it starts to use the blocks without # further allocation. When it uses up all the N empty blocks, it gets # N new empty blocks. - self.num_preallocate_tokens = num_preallocate_tokens - self.num_preallocate_blocks = cdiv(num_preallocate_tokens, - self.block_size) + # NOTE(Chen): For simplicity, we keep the number of preallocated blocks + # the same for all layers, which will result in different + # preallocated tokens for different layers if their block sizes are + # different. + self.num_preallocate_blocks = cdiv( + num_preallocate_tokens, + max(g.kv_cache_spec.block_size + for g in kv_cache_config.kv_cache_groups)) self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching) - self.specialized_manager = get_specialized_manager( - kv_cache_spec=kv_cache_spec, - block_pool=self.block_pool, - ) + self.specialized_managers = [ + get_specialized_manager( + kv_cache_spec=g.kv_cache_spec, + block_pool=self.block_pool, + ) for g in kv_cache_config.kv_cache_groups + ] + + self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request # is finished. - self.req_to_blocks: defaultdict[str, - list[KVCacheBlock]] = defaultdict(list) + self.req_to_blocks: defaultdict[ + str, list[list[KVCacheBlock]]] = defaultdict( + lambda: [[] for _ in range(self.num_kv_cache_groups)]) # Mapping from request ID to kv block hashes. # This is to avoid recomputing the block hashes for each call of # `get_computed_blocks` or `allocate_slots`. self.req_to_block_hashes: defaultdict[ - str, list[BlockHashType]] = defaultdict(list) + str, list[list[BlockHashType]]] = defaultdict( + lambda: [[] for _ in range(self.num_kv_cache_groups)]) - # {req_id: The number of cached blocks for this given request} + # {req_id: The number of cached blocks for each kv cache group} # This is used to track the number of cached blocks for each request. # This is only used to track the RUNNING requests, we do not track the # data for reempted ones. - self.num_cached_block: dict[str, int] = {} + self.num_cached_block: dict[str, list[int]] = {} self.prefix_cache_stats = PrefixCacheStats() - @property - def usage(self) -> float: - """Get the KV cache usage. - - Returns: - The KV cache usage (between 0.0 and 1.0). - """ - return self.block_pool.get_usage() - - def make_prefix_cache_stats(self) -> PrefixCacheStats: - """Get (and reset) the prefix cache stats. - - Returns: - The current prefix caching stats. - """ - stats = self.prefix_cache_stats - self.prefix_cache_stats = PrefixCacheStats() - return stats + usage = KVCacheManager.usage + make_prefix_cache_stats = KVCacheManager.make_prefix_cache_stats def get_computed_blocks( - self, request: Request) -> tuple[list[KVCacheBlock], int]: + self, request: Request) -> tuple[list[list[KVCacheBlock]], int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. 
@@ -110,50 +111,49 @@ def get_computed_blocks( Returns: A tuple containing: - - A list of blocks that are computed for the request. + - A list of blocks that are computed for each kv cache group. - The number of computed tokens. """ if not self.enable_caching: # Prefix caching is disabled. - return [], 0 + return [[] for _ in range(self.num_kv_cache_groups)], 0 # The block hashes for the request may already be computed # if the scheduler has tried to schedule the request before. block_hashes = self.req_to_block_hashes[request.request_id] if not block_hashes: - block_hashes = hash_request_tokens(self.caching_hash_fn, - self.block_size, request) + block_hashes = [ + hash_request_tokens(self.caching_hash_fn, + g.kv_cache_spec.block_size, request, i) + for i, g in enumerate(self.kv_cache_config.kv_cache_groups) + ] self.req_to_block_hashes[request.request_id] = block_hashes self.prefix_cache_stats.requests += 1 if request.sampling_params.prompt_logprobs is None: - if len(block_hashes) * self.block_size == request.num_tokens: - # When prompt length is divisible by the block size and all - # blocks are cached, we need to recompute the last token. This - # have to be achieved by re-computing an entire block because - # allocate_slots() assumes num_computed_tokens is always a - # multiple of the block size. To achieve this, remove the last - # block hash from the block_hashes for find_longest_cache_hit - # This limitation can potentially be removed in the future to - # slightly improve the performance. - last_block_hash = block_hashes.pop() - else: - last_block_hash = None - - computed_blocks = ( - self.specialized_manager.find_longest_cache_hit(block_hashes)) - - if last_block_hash is not None: - # Add back the last block hash if it was removed. - block_hashes.append(last_block_hash) + # TODO: Fix last block problem + # if len(block_hashes) * self.block_size == request.num_tokens: + # # When prompt length is divisible by the block size and all + # # blocks are cached, we need to recompute the last token. This + # # have to be achieved by re-computing an entire block because + # # allocate_slots() assumes num_computed_tokens is always a + # # multiple of the block size. To achieve this, remove the last + # # block hash from the block_hashes for find_longest_cache_hit + # # This limitation can potentially be removed in the future to + # # slightly improve the performance. + # last_block_hash = block_hashes.pop() + # else: + # last_block_hash = None + + computed_blocks, num_computed_tokens = self.find_longest_cache_hit( + request.request_id, block_hashes) + + # if last_block_hash is not None: + # # Add back the last block hash if it was removed. + # block_hashes.append(last_block_hash) self.prefix_cache_stats.queries += len(block_hashes) self.prefix_cache_stats.hits += len(computed_blocks) - - # NOTE(woosuk): Since incomplete blocks are not eligible for - # sharing, `num_computed_tokens` is always a multiple of - # `block_size`. - num_computed_tokens = len(computed_blocks) * self.block_size return computed_blocks, num_computed_tokens else: # Skip cache hits for prompt logprobs @@ -163,7 +163,8 @@ def allocate_slots( self, request: Request, num_tokens: int, - new_computed_blocks: Optional[list[KVCacheBlock]] = None + new_computed_blocks: Optional[list[KVCacheBlock]] = None, + num_new_computed_tokens: int = 0, ) -> Optional[list[KVCacheBlock]]: """Add slots for a request with new tokens to append. @@ -173,6 +174,8 @@ def allocate_slots( not include the tokens that have already been computed. 
new_computed_blocks: A list of new computed blocks just hitting the prefix caching. + num_new_computed_tokens: The number of new computed tokens in the + new_computed_blocks. Blocks layout: ----------------------------------------------------------------------- @@ -192,7 +195,9 @@ def allocate_slots( if num_tokens == 0: raise ValueError("num_tokens must be greater than 0") - new_computed_blocks = new_computed_blocks or [] + new_computed_blocks = new_computed_blocks or [ + [] for _ in range(self.num_kv_cache_groups) + ] req_blocks = self.req_to_blocks[request.request_id] @@ -202,32 +207,47 @@ def allocate_slots( # insufficient free blocks. # Should call this function before allocating new blocks to reduce # the number of evicted blocks. - removed_blocks = self.specialized_manager.remove_skipped_blocks( - req_blocks, request.num_computed_tokens) - self.block_pool.free_blocks(removed_blocks) + removed_blocks = [ + manager.remove_skipped_blocks(req_blocks, + request.num_computed_tokens) + for manager in self.specialized_managers + ] + self._free_blocks(removed_blocks) # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits + num_computed_tokens = (request.num_computed_tokens + - len(new_computed_blocks) * self.block_size) - num_required_blocks = cdiv(num_computed_tokens + num_tokens, - self.block_size) - num_new_blocks = (num_required_blocks - len(req_blocks) - - len(new_computed_blocks)) + num_new_computed_tokens) + + num_new_blocks: list[int] = [] + for i in range(self.num_kv_cache_groups): + num_required_blocks_i = cdiv( + num_computed_tokens + num_tokens, + self.specialized_managers[i].block_size) + num_new_blocks.append((num_required_blocks_i - len(req_blocks[i]) - + len(new_computed_blocks[i]))) + total_num_new_blocks = sum(max(x, 0) for x in num_new_blocks) # If a computed block of a request is an eviction candidate (in the # free queue and ref_cnt == 0), it cannot be counted as a free block # when allocating this request. - num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks - if blk.ref_cnt == 0) - if (num_new_blocks > self.block_pool.get_num_free_blocks() - + num_evictable_computed_blocks = sum( + 1 for blk_one_layer in new_computed_blocks for blk in blk_one_layer + if blk.ref_cnt == 0) + if (total_num_new_blocks > self.block_pool.get_num_free_blocks() - num_evictable_computed_blocks): # Cannot allocate new blocks return None # Touch the computed blocks to make sure they won't be evicted. if self.enable_caching: - self.block_pool.touch(new_computed_blocks) + for blocks in new_computed_blocks: + self.block_pool.touch(blocks) + else: + assert all(len(blks) == 0 for blks in new_computed_blocks), ( + "Computed blocks should be empty when " + "prefix caching is disabled") else: assert not new_computed_blocks, ( "Computed blocks should be empty when " @@ -235,55 +255,77 @@ def allocate_slots( # Append the new computed blocks to the request blocks until now to # avoid the case where the new blocks cannot be allocated. - req_blocks.extend(new_computed_blocks) + for i in range(self.num_kv_cache_groups): + req_blocks[i].extend(new_computed_blocks[i]) # Start to handle new blocks - - if num_new_blocks <= 0: - # No new block is needed. - new_blocks = [] - else: - # Get new blocks from the free block pool considering - # preallocated blocks. - num_new_blocks = min( - num_new_blocks + self.num_preallocate_blocks, - self.block_pool.get_num_free_blocks(), - # Should not exceed the maximum number of blocks per request. 
- # This is especially because the block table has the shape - # [..., max_num_blocks_per_req]. - self.max_num_blocks_per_req - len(req_blocks), - ) - assert num_new_blocks > 0 - - # Concatenate the computed block IDs and the new block IDs. - new_blocks = self.block_pool.get_new_blocks(num_new_blocks) - req_blocks.extend(new_blocks) + new_blocks: list[list[KVCacheBlock]] = [] + # Truncate the number of pre-allocated blocks to ensure that we can + # have at least `num_new_blocks` free blocks for each layer. + num_preallocate_blocks = min( + self.num_preallocate_blocks, + (self.block_pool.get_num_free_blocks() - total_num_new_blocks) // + len(self.specialized_managers)) + + for i in range(self.num_kv_cache_groups): + if num_new_blocks[i] <= 0: + # No new block is needed. + new_blocks.append([]) + else: + # Get new blocks from the free block pool considering + # preallocated blocks. + num_block_to_allocate = min( + num_new_blocks[i] + num_preallocate_blocks, + # Should not exceed the maximum number of blocks per request + # This is especially because the block table has the shape + # [..., max_num_blocks_per_req]. + # TODO(woosuk): Check and reject requests if + # num_prompt_tokens + max_tokens > max_model_len. + # Don't need self.block_pool.get_num_free_blocks() as in + # KVCacheManager because we already considered it when + # calculating num_preallocate_blocks + self.max_num_blocks_per_req[i] - len(req_blocks[i]), + ) + + assert num_block_to_allocate > 0 + assert num_block_to_allocate <= \ + self.block_pool.get_num_free_blocks() + + # Concatenate the computed block IDs and the new block IDs. + new_blocks_this_layer = self.block_pool.get_new_blocks( + num_block_to_allocate) + new_blocks.append(new_blocks_this_layer) + req_blocks[i].extend(new_blocks_this_layer) if not self.enable_caching: return new_blocks # Use `new_computed_blocks` for a new request, and `num_cached_block` # for a running request. - num_cached_blocks = self.num_cached_block.get(request.request_id, - len(new_computed_blocks)) + num_cached_blocks = self.num_cached_block.get( + request.request_id, + [len(blocks) for blocks in new_computed_blocks]) # Speculated tokens might be rejected in the future, so we does # not cache any speculated tokens. We only cache blocks with # generated (accepted) tokens. 
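+        # Example (illustrative): with a group block_size of 16, 40
+        # computed-plus-new tokens, and 5 speculated tokens, only
+        # (40 - 5) // 16 = 2 blocks of that group are cached as full blocks.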
- num_full_blocks_after_append = (num_computed_tokens + num_tokens - len( - request.spec_token_ids)) // self.block_size - - self.block_pool.cache_full_blocks( - request=request, - blocks=req_blocks, - block_hashes=self.req_to_block_hashes[request.request_id], - num_cached_blocks=num_cached_blocks, - num_full_blocks=num_full_blocks_after_append, - block_size=self.block_size, - hash_fn=self.caching_hash_fn, - ) - - self.num_cached_block[ - request.request_id] = num_full_blocks_after_append + for i in range(self.num_kv_cache_groups): + num_full_blocks_after_append = ( + num_computed_tokens + num_tokens - len(request.spec_token_ids) + ) // self.specialized_managers[i].block_size + + self.block_pool.cache_full_blocks( + request=request, + blocks=req_blocks, + block_hashes=self.req_to_block_hashes[request.request_id], + num_cached_blocks=num_cached_blocks, + num_full_blocks=num_full_blocks_after_append, + block_size=self.specialized_managers[i].block_size, + hash_fn=self.caching_hash_fn, + kv_cache_group_id=i, + ) + num_cached_blocks[i] = num_full_blocks_after_append + + self.num_cached_block[request.request_id] = num_cached_blocks return new_blocks def free(self, request: Request) -> None: @@ -295,35 +337,21 @@ def free(self, request: Request) -> None: request: The request to free the blocks. """ # Default to [] in case a request is freed (aborted) before alloc. - blocks = self.req_to_blocks.pop(request.request_id, []) - ordered_blocks: Iterable[KVCacheBlock] = blocks - if self.enable_caching: - # Free blocks in reverse order so that the tail blocks are - # freed first. - ordered_blocks = reversed(blocks) + blocks = self.req_to_blocks.pop(request.request_id, None) + if blocks is not None: + # Reverse the blocks so that the tail blocks can have higher + # eviction priority. + self._free_blocks([list(reversed(blks)) for blks in blocks]) - self.block_pool.free_blocks(ordered_blocks) self.num_cached_block.pop(request.request_id, None) - def reset_prefix_cache(self) -> bool: - """Reset prefix cache. This function may be used in RLHF - flows to invalid prefix caching after the weights are updated, - or used for resetting prefix caching status for benchmarking. - - Returns: - bool: True if the prefix cache is successfully reset, - False otherwise. - """ - if self.block_pool.reset_prefix_cache(): - self.prefix_cache_stats.reset = True - return True - return False + reset_prefix_cache = KVCacheManager.reset_prefix_cache def get_num_common_prefix_blocks( self, request: Request, num_running_requests: int, - ) -> int: + ) -> list[int]: """Calculate the number of common prefix blocks shared by all requests in the RUNNING state. @@ -355,22 +383,96 @@ def get_num_common_prefix_blocks( requests in the current step. Returns: - int: The number of common prefix blocks. + list[int]: The number of common prefix blocks for each kv cache + group. """ assert request.status == RequestStatus.RUNNING blocks = self.req_to_blocks[request.request_id] - num_common_blocks = 0 - for block in blocks: - if block.ref_cnt == num_running_requests: - num_common_blocks += 1 - else: - break + num_common_blocks = [] + for i in range(self.num_kv_cache_groups): + num_common_blocks_i = 0 + for block in blocks[i]: + if block.ref_cnt == num_running_requests: + num_common_blocks_i += 1 + else: + break + num_common_blocks.append(num_common_blocks_i) return num_common_blocks - def free_block_hashes(self, request: Request) -> None: - """Discard the block hashes for the request. 
+ free_block_hashes = KVCacheManager.free_block_hashes - NOTE: Unlike `free`, this method should be called only when the request - is finished, not when it is preempted. + def find_longest_cache_hit( + self, request_id: int, block_hashes: list[list[BlockHashType]] + ) -> tuple[list[list[KVCacheBlock]], int]: + """Find the longest cache hit for each kv cache group. + TODO: add more notes """ - self.req_to_block_hashes.pop(request.request_id, None) + # TODO: accelerate by make full attention the first layer + # TODO: add note for the two magic number + num_computed_tokens = [self.max_model_len + 100] * len( + self.specialized_managers) + min_computed_tokens = self.max_model_len + + # Use copy to avoid modifying the original block_hashes + block_hashes = [block_hash.copy() for block_hash in block_hashes] + + while not max(num_computed_tokens) == min_computed_tokens: + for i, manager in enumerate(self.specialized_managers): + if num_computed_tokens[i] > min_computed_tokens: + del block_hashes[i][:min_computed_tokens // + manager.block_size] + computed_blocks_group_i = manager.find_longest_cache_hit( + request_id, block_hashes[i], return_const_list=True) + + num_computed_tokens[i] = len(computed_blocks_group_i) * \ + manager.block_size + min_computed_tokens = min(min_computed_tokens, + num_computed_tokens[i]) + + # Get the non-constlist computed blocks + computed_blocks = [ + manager.find_longest_cache_hit(request_id, + block_hashes[i], + return_const_list=False) + for i, manager in enumerate(self.specialized_managers) + ] + + assert all( + len(block) * manager.block_size == min_computed_tokens for block, + manager in zip(computed_blocks, self.specialized_managers)) + + return computed_blocks, min_computed_tokens + + def _merge_blocks_by_eviction_order( + self, blocks: list[list[KVCacheBlock]]) -> list[KVCacheBlock]: + """ + Merge the blocks of different layers to one list. The returned blocks + are sorted by eviction order, with the first block having the highest + eviction priority. + Args: + blocks: the blocks of each virtual layer, ordered by eviction + priority. + Returns: + A list of KVCacheBlocks sorted by eviction order. + """ + + if self.enable_caching: + # NOTE (Chen): A simple strategy that interleaves the blocks of + # each layer. We can investigate more advanced strategies + # in the future. 
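+            # Example (illustrative): blocks = [[a0, a1], [b0, b1, b2]]
+            # -> ordered_blocks = [a0, b0, a1, b1, b2], i.e. the i-th block
+            # of every group precedes the (i+1)-th block of any group.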
+ ordered_blocks = [] + max_len = max(len(blocks_one_layer) for blocks_one_layer in blocks) + for i in range(max_len): + for blocks_one_layer in blocks: + if i < len(blocks_one_layer): + ordered_blocks.append(blocks_one_layer[i]) + else: + ordered_blocks = [] + for blocks_one_layer in blocks: + ordered_blocks.extend(blocks_one_layer) + + return ordered_blocks + + def _free_blocks(self, blocks: list[list[KVCacheBlock]]) -> None: + ordered_blocks = self._merge_blocks_by_eviction_order(blocks) + self.block_pool.free_blocks(ordered_blocks) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index c0f7715209d1..bea01d399a6d 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -2,7 +2,7 @@ from collections import defaultdict from collections.abc import Iterable -from typing import Optional +from typing import TYPE_CHECKING, Optional, Union from vllm.logger import init_logger from vllm.utils import cdiv, sha256 @@ -13,6 +13,8 @@ from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request, RequestStatus +if TYPE_CHECKING: + from vllm.v1.core.hybrid_kv_cache_manager import HybridKVCacheManager logger = init_logger(__name__) @@ -140,8 +142,8 @@ def get_computed_blocks( else: last_block_hash = None - computed_blocks = ( - self.specialized_manager.find_longest_cache_hit(block_hashes)) + computed_blocks = (self.specialized_manager.find_longest_cache_hit( + request.request_id, block_hashes)) if last_block_hash is not None: # Add back the last block hash if it was removed. @@ -163,7 +165,8 @@ def allocate_slots( self, request: Request, num_tokens: int, - new_computed_blocks: Optional[list[KVCacheBlock]] = None + new_computed_blocks: Optional[list[KVCacheBlock]] = None, + num_new_computed_tokens: int = 0, ) -> Optional[list[KVCacheBlock]]: """Add slots for a request with new tokens to append. @@ -173,6 +176,8 @@ def allocate_slots( not include the tokens that have already been computed. new_computed_blocks: A list of new computed blocks just hitting the prefix caching. + num_new_computed_tokens: The number of new computed tokens in the + new_computed_blocks. Blocks layout: ----------------------------------------------------------------------- @@ -209,7 +214,7 @@ def allocate_slots( # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits num_computed_tokens = (request.num_computed_tokens + - len(new_computed_blocks) * self.block_size) + num_new_computed_tokens) num_required_blocks = cdiv(num_computed_tokens + num_tokens, self.block_size) num_new_blocks = (num_required_blocks - len(req_blocks) - @@ -374,3 +379,26 @@ def free_block_hashes(self, request: Request) -> None: is finished, not when it is preempted. 
""" self.req_to_block_hashes.pop(request.request_id, None) + + +def init_kv_cache_manager( + kv_cache_config: KVCacheConfig, + max_model_len: int, + enable_caching: bool = True, + num_preallocate_tokens: int = 64 +) -> Union[KVCacheManager, "HybridKVCacheManager"]: + from vllm.v1.core.hybrid_kv_cache_manager import HybridKVCacheManager + if len(kv_cache_config.kv_cache_groups) > 1: + return HybridKVCacheManager( + kv_cache_config=kv_cache_config, + max_model_len=max_model_len, + enable_caching=enable_caching, + num_preallocate_tokens=num_preallocate_tokens, + ) + else: + return KVCacheManager( + kv_cache_config=kv_cache_config, + max_model_len=max_model_len, + enable_caching=enable_caching, + num_preallocate_tokens=num_preallocate_tokens, + ) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 34bc9369b125..81dcd985878f 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 """KV-Cache Utilities.""" +import math import os -from collections import deque +from collections import defaultdict, deque from collections.abc import Sequence from dataclasses import dataclass from typing import Any, Callable, NamedTuple, Optional @@ -10,8 +11,9 @@ from vllm.logger import init_logger from vllm.utils import sha256 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheSpec, - KVCacheTensor, SlidingWindowSpec) + KVCacheGroupSpec, KVCacheNewTensor, + KVCacheReuseTensor, KVCacheSpec, + SlidingWindowSpec) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -263,11 +265,12 @@ def get_all_free_blocks(self) -> list[KVCacheBlock]: return ret -def need_extra_keys(request: Request) -> bool: +def need_extra_keys(request: Request, kv_cache_group_id: int) -> bool: """Check whether the blocks allocated to this request need extra hash keys. Args: request (Request): The request. + kv_cache_group_id (int): The id of the kv cache group. -1 means no kv cache group. Returns: bool: Whether blocks allocated to this request need extra hash keys. @@ -275,7 +278,9 @@ def need_extra_keys(request: Request) -> bool: # Multimodal requests need to include the MM hash. # LoRA requests need to include the LoRA ID. - return bool(request.mm_positions) or (request.lora_request is not None) + return bool(request.mm_positions) or (request.lora_request + is not None) or (kv_cache_group_id + != -1) def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, @@ -364,7 +369,8 @@ def _gen_lora_extra_hash_keys(request: Request) -> list[int]: def generate_block_hash_extra_keys( request: Request, start_token_idx: int, end_token_idx: int, - start_mm_idx: int) -> tuple[Optional[tuple[Any, ...]], int]: + start_mm_idx: int, + kv_cache_group_id: int) -> tuple[Optional[tuple[Any, ...]], int]: """Generate extra keys for the block hash. The extra keys can come from the multi-modal inputs and request specific metadata (e.g., LoRA ID). @@ -373,7 +379,8 @@ def generate_block_hash_extra_keys( start_token_idx: The start token index of the block. end_token_idx: The end token index of the block. start_mm_idx: The start multi-modal index of the block. - + kv_cache_group_id: The id of the kv cache group. -1 means no kv cache + group. Returns: A tuple of extra keys and the next multi-modal index. 
""" @@ -383,6 +390,8 @@ def generate_block_hash_extra_keys( lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request) extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + if kv_cache_group_id != -1: + extra_keys.append(kv_cache_group_id) if not extra_keys: return None, new_start_mm_idx @@ -401,10 +410,12 @@ def hash_block_tokens( hash values for the same block contents. Args: + hash_function: The function used for hash parent_block_hash: The hash of the parent block. None if this is the first block. curr_block_token_ids: A list of token ids in the current block. The current block is assumed to be full. + kv_cache_group_id: The id of the kv cache group. -1 means no kv cache group. extra_keys: Extra keys for the block. Returns: @@ -421,21 +432,24 @@ def hash_block_tokens( curr_block_token_ids_tuple, extra_keys) -def hash_request_tokens(hash_function: Any, block_size: int, - request: Request) -> list[BlockHashType]: +def hash_request_tokens(hash_function: Any, + block_size: int, + request: Request, + kv_cache_group_id: int = -1) -> list[BlockHashType]: """Computes hash values of a chain of blocks given a sequence of token IDs. The hash value is used for prefix caching. Args: block_size: The size of each block. request: The request object. + kv_cache_group_id: The id of the kv cache group. -1 means no kv cache group. Returns: The list of computed hash values. """ token_ids = request.all_token_ids - req_need_extra_keys = need_extra_keys(request) + req_need_extra_keys = need_extra_keys(request, kv_cache_group_id) req_extra_keys = None curr_mm_idx = 0 @@ -451,7 +465,7 @@ def hash_request_tokens(hash_function: Any, block_size: int, if req_need_extra_keys: # MM and LoRA requests need extra keys for block-hash computation. req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys( - request, start, end, curr_mm_idx) + request, start, end, curr_mm_idx, kv_cache_group_id) block_hash = hash_block_tokens(hash_function, parent_block_hash_value, block_token_ids, req_extra_keys) @@ -589,7 +603,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, kv_cache_config = KVCacheConfig( num_blocks=num_blocks, tensors={ - layer_name: KVCacheTensor(size=per_layer_size) + layer_name: KVCacheNewTensor(size=per_layer_size) for layer_name in kv_cache_spec }, kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec, @@ -598,6 +612,82 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, return kv_cache_config +def is_kv_cache_page_size_uniform( + kv_cache_spec: dict[str, KVCacheSpec]) -> bool: + """ + Whether all layers in the given KVCacheSpec have the same page size. + Args: + kv_cache_spec: The KVCacheSpec of each attention layer in the model + + Returns: + True if all layers have the same page size, False otherwise. + """ + + page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} + return len(page_sizes) == 1 + + +def _get_kv_cache_config_uniform_page_size( + vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int) -> KVCacheConfig: + """ + Generates the KV cache configuration for a model with one page size. + + Args: + vllm_config: The global VllmConfig + kv_cache_spec: The KVCacheSpec of each attention layer in the model + available_memory: Memory available for KV cache in bytes. + + Returns: + The generated KVCacheConfig + """ + # Group all layers by type_id. + # E.g., 2 full attention layers and 4 sliding window attention layers, + # -> (full.0, full.1), (sw.0, sw.1, sw.2, sw.3). 
+ same_type_layers: dict[str, list[str]] = defaultdict(list) + for layer_name, layer_spec in kv_cache_spec.items(): + same_type_layers[layer_spec.type_id].append(layer_name) + + # Split each group into smaller groups, to make the number of layers in + # each group identical. + # E.g., (full.0, full.1), (sw.0, sw.1, sw.2, sw.3), group_size_gcd is 2, + # split to 3 groups with 2 layers each: + # (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3). + group_size_gcd = math.gcd( + *[len(layers) for layers in same_type_layers.values()]) + grouped_layers = [] + for layers in same_type_layers.values(): + for i in range(0, len(layers), group_size_gcd): + grouped_layers.append(layers[i:i + group_size_gcd]) + + # Divide the available memory equally among all layers in the first group. + # The memory layout in the example will be: + # full.0: Tensor with size=available_memory//2 + # full.1: Tensor with size=available_memory//2 + kv_cache_spec_first_group = { + layer_name: kv_cache_spec[layer_name] + for layer_name in grouped_layers[0] + } + kv_cache_config = _get_kv_cache_config_uniform_type( + vllm_config, kv_cache_spec_first_group, available_memory) + + # Reuse the KV cache tensors of the first group for the other groups. + # The memory layout in the example will be: + # full.0, sw.0, sw.2: share a Tensor with size=available_memory//2 + # full.1, sw.1, sw.3: share another Tensor with size=available_memory//2 + # Layers of different groups have different block table, so they will + # use different parts of the shared Tensor. + for layers in grouped_layers[1:]: + for layer_name, layer_name_first_group in zip(layers, + grouped_layers[0]): + kv_cache_config.tensors[layer_name] = KVCacheReuseTensor( + reused_layer_name=layer_name_first_group) + + kv_cache_config.kv_cache_groups = create_kv_cache_group_specs( + kv_cache_spec, grouped_layers) + return kv_cache_config + + def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): """ Only models with one type of KV cache are supported yet. This function tries @@ -618,10 +708,12 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): if isinstance(spec, SlidingWindowSpec): kv_cache_spec[layer_name] = FullAttentionSpec( block_size=spec.block_size, + num_query_heads=spec.num_query_heads, num_kv_heads=spec.num_kv_heads, head_size=spec.head_size, dtype=spec.dtype, use_mla=spec.use_mla, + compute_as_sliding_window=True, ) @@ -641,14 +733,19 @@ def get_kv_cache_config(vllm_config: VllmConfig, The generated KVCacheConfigs """ check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory) - unify_hybrid_kv_cache_specs(kv_cache_spec) + if vllm_config.cache_config.disable_hybrid_allocator: + unify_hybrid_kv_cache_specs(kv_cache_spec) if is_kv_cache_type_uniform(kv_cache_spec): # KV cache of all layers are the same, which is true for # most models. Allocate the same amount of memory for # each layer. return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, available_memory) - + elif is_kv_cache_page_size_uniform(kv_cache_spec): + # KV cache of all layers have the same page size. 
TODO more notes + return _get_kv_cache_config_uniform_page_size(vllm_config, + kv_cache_spec, + available_memory) raise NotImplementedError diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index dc0d2d59fea7..c4aa691a7138 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -5,6 +5,8 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Optional +from vllm.v1.kv_cache_interface import MayMultiGroupBlockIDs + if TYPE_CHECKING: import numpy as np import numpy.typing as npt @@ -25,7 +27,7 @@ class NewRequestData: mm_hashes: list[str] mm_positions: list[PlaceholderRange] sampling_params: SamplingParams - block_ids: list[int] + block_ids: MayMultiGroupBlockIDs num_computed_tokens: int lora_request: Optional[LoRARequest] @@ -33,7 +35,7 @@ class NewRequestData: def from_request( cls, request: Request, - block_ids: list[int], + block_ids: MayMultiGroupBlockIDs, ) -> NewRequestData: return cls( req_id=request.request_id, @@ -58,7 +60,7 @@ class CachedRequestData: # request's block IDs instead of appending to the existing block IDs. resumed_from_preemption: bool new_token_ids: list[int] - new_block_ids: list[int] + new_block_ids: MayMultiGroupBlockIDs num_computed_tokens: int @classmethod @@ -67,7 +69,7 @@ def from_request( request: Request, resumed_from_preemption: bool, new_token_ids: list[int], - new_block_ids: list[int], + new_block_ids: MayMultiGroupBlockIDs, ) -> CachedRequestData: return cls( req_id=request.request_id, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 4d477567b9b6..c670fa3a618b 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -12,14 +12,15 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) -from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.core.kv_cache_manager import init_kv_cache_manager from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) from vllm.v1.core.sched.utils import check_stop from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs) -from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.kv_cache_interface import (BlockIDGenerator, KVCacheConfig, + MayMultiGroupBlockIDs) from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -62,13 +63,15 @@ def __init__( self.max_model_len = self.scheduler_config.max_model_len # Create the KV cache manager. - self.kv_cache_manager = KVCacheManager( + self.kv_cache_manager = init_kv_cache_manager( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, enable_caching=cache_config.enable_prefix_caching, caching_hash_algo=self.cache_config.prefix_caching_hash_algo, log_stats=self.log_stats) self.block_size = self.cache_config.block_size + BlockIDGenerator.num_kv_cache_groups = len( + self.kv_cache_config.kv_cache_groups) # req_id -> Request self.requests: dict[str, Request] = {} @@ -136,7 +139,7 @@ def schedule(self) -> SchedulerOutput: # uses structured decoding. 
structured_output_request_ids: dict[str, int] = {} - req_to_new_block_ids: dict[str, list[int]] = {} + req_to_new_block_ids: dict[str, MayMultiGroupBlockIDs] = {} num_scheduled_tokens: dict[str, int] = {} token_budget = self.max_num_scheduled_tokens # Encoder-related. @@ -222,9 +225,9 @@ def schedule(self) -> SchedulerOutput: # Therefore, we might introduce some additional # cycle to fill in the bitmask, which could be a big no-op. structured_output_request_ids[request.request_id] = req_index - req_to_new_block_ids[request.request_id] = [ - b.block_id for b in new_blocks - ] + req_to_new_block_ids[ + request.request_id] = BlockIDGenerator.from_kv_cache_blocks( + new_blocks) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens req_index += 1 @@ -320,7 +323,8 @@ def schedule(self) -> SchedulerOutput: new_encoder_budget = encoder_budget new_blocks = self.kv_cache_manager.allocate_slots( - request, num_new_tokens, computed_blocks) + request, num_new_tokens, computed_blocks, + num_computed_tokens) if new_blocks is None: # The request cannot be scheduled. break @@ -345,9 +349,10 @@ def schedule(self) -> SchedulerOutput: if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) - req_to_new_block_ids[request.request_id] = [ - b.block_id for b in computed_blocks + new_blocks - ] + req_to_new_block_ids[ + request.request_id] = BlockIDGenerator.generate( + computed_blocks) + BlockIDGenerator.generate( + new_blocks) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING @@ -379,7 +384,11 @@ def schedule(self) -> SchedulerOutput: # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. 
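A toy illustration of what chaining the two BlockIDGenerator.from_kv_cache_blocks calls amounts to in the multi-group case: block IDs from prefix-cache hits and from newly allocated blocks are concatenated group by group. Block here is a simplified stand-in for KVCacheBlock.

    class Block:
        def __init__(self, block_id: int) -> None:
            self.block_id = block_id

    def ids_per_group(groups: list[list[Block]]) -> list[list[int]]:
        return [[blk.block_id for blk in group] for group in groups]

    computed = [[Block(0), Block(1)], [Block(7)]]   # prefix-cache hits, per group
    new = [[Block(2)], [Block(8), Block(9)]]        # freshly allocated, per group
    merged = [c + n for c, n in zip(ids_per_group(computed), ids_per_group(new))]
    print(merged)   # [[0, 1, 2], [7, 8, 9]]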
- num_common_prefix_blocks = 0 + if len(self.kv_cache_config.kv_cache_groups) > 1: + num_common_prefix_blocks = [0] * len( + self.kv_cache_config.kv_cache_groups) + else: + num_common_prefix_blocks = 0 if self.running: any_request = self.running[0] num_common_prefix_blocks = ( @@ -453,7 +462,7 @@ def _make_cached_request_data( request: Request, num_scheduled_tokens: int, num_scheduled_spec_tokens: int, - new_block_ids: list[int], + new_block_ids: MayMultiGroupBlockIDs, resumed_from_preemption: bool, ) -> CachedRequestData: # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating diff --git a/vllm/v1/core/specialized_manager.py b/vllm/v1/core/specialized_manager.py index 7a8a98361c7e..b7e36f776cf8 100644 --- a/vllm/v1/core/specialized_manager.py +++ b/vllm/v1/core/specialized_manager.py @@ -1,11 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod +from collections import namedtuple +from typing import Optional, Union from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, SlidingWindowSpec) +from vllm.v1.utils import ConstantList class SpecializedManager(ABC): @@ -29,16 +32,56 @@ def __init__( self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec self.block_pool = block_pool + # for caching the intermediate states between multiple calls of + # find_longest_cache_hit + self.req_cached_blocks: dict[int, list[KVCacheBlock]] = {} - @abstractmethod def find_longest_cache_hit( - self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: + self, + request_id: int, + block_hashes: list[BlockHashType], + return_const_list: bool = False, + ) -> Union[list[KVCacheBlock], ConstantList[KVCacheBlock]]: """ - Get the longest cache hit prefix of the blocks. If no cache hit is + Find the longest cache hit prefix of the blocks. If no cache hit is found, return an empty list. Args: + request_id: The request id. block_hashes: The block hashes of the request. + return_const_list: Whether to return a ConstantList. + """ + + if req_cached_blocks := self.req_cached_blocks.pop(request_id, None): + assert len(req_cached_blocks) >= len(block_hashes) + + req_cached_blocks = self._find_longest_cache_hit( + block_hashes, req_cached_blocks) + + if return_const_list: + # TODO: add comment + self.req_cached_blocks[request_id] = req_cached_blocks + return ConstantList(req_cached_blocks) + else: + # TODO: add comment + return req_cached_blocks + + @abstractmethod + def _find_longest_cache_hit( + self, + block_hashes: list[BlockHashType], + computed_blocks: Optional[list[KVCacheBlock]], + ) -> list[KVCacheBlock]: + """ + # TODO: update comment for multiple calls + Get the longest cache hit prefix of the blocks. If no cache hit is + found, return an empty list. # TODO: add notes for computed_blocks + will not be longer than block_hashes. + + Args: + block_hashes: The block hashes of the request. + computed_blocks: The cached blocks for the request returned from + the previous call of this function. Returns: A list of cached blocks with skipped blocks replaced by null block. 
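The find_longest_cache_hit wrapper above memoizes the previous result per request so that a request the scheduler retries does not rescan its whole hash chain. A simplified, self-contained version of that pattern follows; it drops the return_const_list detail and uses a plain callable in place of the block pool.

    from typing import Callable, Optional

    class CachedHitFinder:
        def __init__(self, lookup: Callable[[int], Optional[int]]) -> None:
            self._lookup = lookup                    # block hash -> block (or None)
            self._memo: dict[str, list[int]] = {}    # request id -> previous result

        def find(self, request_id: str, block_hashes: list[int]) -> list[int]:
            prev = self._memo.pop(request_id, None)
            if prev is not None:
                # A previous call already scanned at least this many hashes;
                # just trim its result to the current hash list.
                hits = prev[:len(block_hashes)]
            else:
                hits = []
                for h in block_hashes:               # walk the prefix chain
                    blk = self._lookup(h)
                    if blk is None:
                        break                        # later blocks cannot be cached
                    hits.append(blk)
            self._memo[request_id] = hits            # reuse on the next attempt
            return hits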
For example, sliding window manager should return a list like @@ -68,17 +111,24 @@ def remove_skipped_blocks(self, blocks: list[KVCacheBlock], class FullAttentionManager(SpecializedManager): - def find_longest_cache_hit( - self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: - computed_blocks: list[KVCacheBlock] = [] - for block_hash in block_hashes: - # block_hashes is a chain of block hashes. If a block hash is not - # in the cached_block_hash_to_id, the following block hashes are - # not computed yet for sure. - if cached_block := self.block_pool.get_cached_block(block_hash): - computed_blocks.append(cached_block) - else: - break + def _find_longest_cache_hit( + self, block_hashes: list[BlockHashType], + computed_blocks: Optional[list[KVCacheBlock]] + ) -> list[KVCacheBlock]: + if computed_blocks is None: + computed_blocks: list[KVCacheBlock] = [] + for block_hash in block_hashes: + # block_hashes is a chain of block hashes. If a block hash is not + # in the cached_block_hash_to_id, the following block hashes are + # not computed yet for sure. + if cached_block := self.block_pool.get_cached_block( + block_hash): + computed_blocks.append(cached_block) + else: + break + else: + assert len(computed_blocks) >= len(block_hashes) + del computed_blocks[len(block_hashes):] return computed_blocks def remove_skipped_blocks(self, blocks: list[KVCacheBlock], @@ -99,18 +149,30 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, (kv_cache_spec.sliding_window - 1), self.block_size) self._null_block = block_pool.null_block - def find_longest_cache_hit( - self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: + def _find_longest_cache_hit( + self, block_hashes: list[BlockHashType], + computed_blocks: Optional[list[KVCacheBlock]] + ) -> list[KVCacheBlock]: # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to # optimize the time complexity from O(len(block_hashes)) to # O(len(block_hashes) / sliding_window_contiguous_blocks + # sliding_window_contiguous_blocks), # which is good for low cache hit rate scenarios. - computed_blocks = [self._null_block] * len(block_hashes) - num_contiguous_blocks = 0 + if computed_blocks is None: + num_contiguous_blocks = 0 + computed_blocks = [self._null_block] * len(block_hashes) + else: + if len(computed_blocks) == len(block_hashes): + return computed_blocks + # We are sure the last num_contiguous_blocks are not NULL and do + # not need to check again. + num_contiguous_blocks = max( + self.sliding_window_contiguous_blocks - + (len(computed_blocks) - len(block_hashes)), 0) + del computed_blocks[len(block_hashes):] # Search from right to left and early stop when a match is found. 
- for i in range(len(block_hashes) - 1, -1, -1): + for i in range(len(block_hashes) - num_contiguous_blocks - 1, -1, -1): if cached_block := self.block_pool.get_cached_block( block_hashes[i]): computed_blocks[i] = cached_block diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 4fc0844cd1f4..8af1130aa804 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -1,12 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass +from typing import TYPE_CHECKING, Union, cast, overload, Type import torch from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import cdiv, get_dtype_size +if TYPE_CHECKING: + from vllm.v1.core.kv_cache_utils import KVCacheBlock logger = init_logger(__name__) @@ -56,8 +59,9 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: @dataclass class AttentionSpec(KVCacheSpec): - num_kv_heads: int head_size: int + num_query_heads: int + num_kv_heads: int dtype: torch.dtype use_mla: bool @@ -71,6 +75,8 @@ def page_size_bytes(self) -> int: @dataclass class FullAttentionSpec(AttentionSpec): + # TODO: add note + compute_as_sliding_window: bool = False @property def type_id(self) -> str: @@ -112,15 +118,30 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: @dataclass -class KVCacheTensor: +class KVCacheTensorBase: """ A dataclass for specifying how the workers should initialize the KV cache - for a layer. Only contains the size of KV cache for that layer for now. Will - be extended to support multiple layers sharing the same memory pool. + for a layer. + """ + pass + + +@dataclass +class KVCacheNewTensor(KVCacheTensorBase): + """ + Initialize the KV cache with a tensor of `size` bytes. """ size: int # The size of KV cache Tensor in bytes +@dataclass +class KVCacheReuseTensor(KVCacheTensorBase): + """ + Reuse the KV cache tensor of `layer_name` for the current layer. + """ + reused_layer_name: str + + @dataclass class KVCacheGroupSpec: """ @@ -141,7 +162,7 @@ class KVCacheConfig: """The number of KV cache blocks""" num_blocks: int """layer_name -> how to initialize KV cache for that layer""" - tensors: dict[str, KVCacheTensor] + tensors: dict[str, KVCacheTensorBase] """ The kv cache groups of the model. The layers in the models are repeated with some patterns, e.g., a model @@ -164,3 +185,63 @@ class KVCacheConfig: there are 3 groups, each of which represents 10 layers in the model. 
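Going back to the sliding-window search at the top of this hunk, a close but simplified standalone paraphrase is below. The patch additionally threads the previous call's result through computed_blocks so a retried request can skip the already-verified contiguous tail; that refinement is omitted here, and the lookup is a plain callable.

    def sliding_window_hits(block_hashes, lookup, window_blocks, null_block=None):
        # Right-to-left scan: stop as soon as enough contiguous cached blocks
        # are found to cover one sliding window; skipped blocks stay "null".
        hits = [null_block] * len(block_hashes)
        contiguous = 0
        for i in range(len(block_hashes) - 1, -1, -1):
            blk = lookup(block_hashes[i])
            if blk is not None:
                hits[i] = blk
                contiguous += 1
                if contiguous >= window_blocks:
                    # The blocks after this contiguous run do not extend the
                    # usable cache hit, so the suffix can be dropped.
                    return hits[:i + contiguous]
            else:
                contiguous = 0
        return hits[:contiguous]   # only a (possibly empty) prefix was cached

    cache = {"h0": "b0", "h1": "b1", "h3": "b3"}
    print(sliding_window_hits(["h0", "h1", "h2", "h3"], cache.get, window_blocks=2))
    # ['b0', 'b1']: the isolated hit at h3 cannot cover a 2-block window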
""" kv_cache_groups: list[KVCacheGroupSpec] + + +@dataclass +class MultiGroupBlockIDs: + # A list of block IDs for each virtual layer + _block_ids: list[list[int]] + + def __init__(self, block_ids: list[list[int]]): + self._block_ids = block_ids + + @classmethod + def from_kv_cache_blocks(cls, kv_cache_blocks: list[list["KVCacheBlock"]]): + return cls( + block_ids=[[blk.block_id for blk in kv_cache_blocks_one_layer] + for kv_cache_blocks_one_layer in kv_cache_blocks]) + + def extend(self, new_block_ids: "MultiGroupBlockIDs") -> None: + for i, block_ids in enumerate(new_block_ids._block_ids): + self._block_ids[i].extend(block_ids) + + def __add__(self, other: "MultiGroupBlockIDs") -> "MultiGroupBlockIDs": + return MultiGroupBlockIDs(block_ids=[ + a + b for a, b in zip(self._block_ids, other._block_ids) + ]) + + def get_block_id_of_group(self, group_id: int) -> list[int]: + return self._block_ids[group_id] + + +MayMultiGroupBlockIDs = Union[MultiGroupBlockIDs, list[int]] +MayMultiGroupInt = Union[int, list[int]] + + +class BlockIDGenerator: + num_kv_cache_groups: int + + @overload + @classmethod + def from_kv_cache_blocks( + cls, kv_cache_blocks: list["KVCacheBlock"]) -> list[int]: + ... + + @overload + @classmethod + def from_kv_cache_blocks( + cls, + kv_cache_blocks: list[list["KVCacheBlock"]]) -> MultiGroupBlockIDs: + ... + + @classmethod + def from_kv_cache_blocks( + cls, kv_cache_blocks: Union[list["KVCacheBlock"], + list[list["KVCacheBlock"]]] + ) -> MayMultiGroupBlockIDs: + if cls.num_kv_cache_groups == 1: + kv_cache_blocks = cast(list["KVCacheBlock"], kv_cache_blocks) + return [blk.block_id for blk in kv_cache_blocks] + else: + kv_cache_blocks = cast(list[list["KVCacheBlock"]], kv_cache_blocks) + return MultiGroupBlockIDs.from_kv_cache_blocks(kv_cache_blocks) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 7d4082b73992..89daddfd3169 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Callable, Concatenate, ParamSpec, Union import numpy as np import torch from vllm.logger import init_logger +from vllm.utils import cdiv +from vllm.v1.kv_cache_interface import KVCacheConfig logger = init_logger(__name__) @@ -14,11 +17,13 @@ def __init__( self, max_num_reqs: int, max_num_blocks_per_req: int, + max_num_tokens: int, # TODO pin_memory: bool, device: torch.device, ): self.max_num_reqs = max_num_reqs self.max_num_blocks_per_req = max_num_blocks_per_req + self.max_num_tokens = max_num_tokens self.pin_memory = pin_memory self.device = device @@ -36,6 +41,12 @@ def __init__( self.block_table_np = self.block_table_cpu.numpy() self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) + self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.slot_mapping_np = self.slot_mapping_cpu.numpy() + def append_row( self, block_ids: list[int], @@ -85,3 +96,75 @@ def get_cpu_tensor(self) -> torch.Tensor: def get_numpy_array(self) -> np.ndarray: """Returns the numpy array of the block table.""" return self.block_table_np + + +P = ParamSpec("P") + + +class MultiLayerBlockTable: + move_row: Callable[P, None] + commit: Callable[P, None] + clear: Callable[P, None] + + append_row: Callable[Concatenate["MayMultiLayerBlockIDs", P], None] + add_row: Callable[Concatenate["MayMultiLayerBlockIDs", P], None] + + def __init__(self, max_num_reqs: int, max_num_blocks_per_req: list[int], + max_num_tokens: int, 
pin_memory: bool, device: torch.device, + kv_cache_config: KVCacheConfig) -> None: + self.block_tables = [ + BlockTable(max_num_reqs, max_num_blocks_per_req[i], max_num_tokens, + pin_memory, device) + for i in range(len(kv_cache_config.kv_cache_groups)) + ] + # For methods that just pass the arguments to each BlockTable. + for f_name in ("move_row", "swap_row", "commit", "clear"): + setattr(self, f_name, self._make_broadcast_func(f_name)) + # For methods that require a block_ids as the first argument. + for f_name in ("append_row", "add_row"): + setattr(self, f_name, + self._make_broadcast_func_with_block_ids(f_name)) + + def _make_broadcast_func(self, f_name: str) -> Callable[P, None]: + + def broadcast_func(*args: P.args, **kwargs: P.kwargs) -> None: + for block_table in self.block_tables: + getattr(block_table, f_name)(*args, **kwargs) + + return broadcast_func + + def _make_broadcast_func_with_block_ids( + self, f_name: str + ) -> Callable[Concatenate["MayMultiLayerBlockIDs", P], None]: + + def broadcast_func(block_ids: "MayMultiLayerBlockIDs", *args: P.args, + **kwargs: P.kwargs) -> None: + for i, block_table in enumerate(self.block_tables): + getattr(block_table, f_name)(block_ids.get_virtual_layer(i), + *args, **kwargs) + + return broadcast_func + + def __getitem__(self, idx: int) -> "BlockTable": + return self.block_tables[idx] + + +def initialize_block_table( + max_num_reqs: int, + max_model_len: int, + max_num_tokens: int, + pin_memory: bool, + device: torch.device, + kv_cache_config: KVCacheConfig, +) -> Union[BlockTable, MultiLayerBlockTable]: + max_num_blocks_per_req = [ + cdiv(max_model_len, g.kv_cache_spec.block_size) + for g in kv_cache_config.kv_cache_groups + ] + if len(kv_cache_config.kv_cache_groups) == 1: + return BlockTable(max_num_reqs, max_num_blocks_per_req[0], + max_num_tokens, pin_memory, device) + else: + return MultiLayerBlockTable(max_num_reqs, max_num_blocks_per_req, + max_num_tokens, pin_memory, device, + kv_cache_config) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 351b35815580..6b3e356fa591 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -11,10 +11,11 @@ from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import swap_dict_values +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.outputs import LogprobsTensors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.utils import copy_slice -from vllm.v1.worker.block_table import BlockTable +from vllm.v1.worker.block_table import BlockTable, initialize_block_table _SAMPLING_EPS = 1e-5 @@ -50,14 +51,14 @@ def __init__( self, max_num_reqs: int, max_model_len: int, - max_num_blocks_per_req: int, + max_num_tokens: int, device: torch.device, pin_memory: bool, vocab_size: int, + kv_cache_config: KVCacheConfig, ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len - self.max_num_blocks_per_req = max_num_blocks_per_req self.device = device self.pin_memory = pin_memory self.vocab_size = vocab_size @@ -89,11 +90,13 @@ def __init__( self.num_computed_tokens_cpu_tensor.numpy() # Block table. - self.block_table = BlockTable( + self.block_table = initialize_block_table( max_num_reqs=max_num_reqs, - max_num_blocks_per_req=max_num_blocks_per_req, + max_model_len=max_model_len, + max_num_tokens=max_num_tokens, pin_memory=pin_memory, device=device, + kv_cache_config=kv_cache_config, ) # Sampling-related. 
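A minimal runnable illustration of the broadcast pattern MultiLayerBlockTable uses above, written against a toy table class. The real class forwards to per-group BlockTable instances and passes each group its own slice of the block IDs.

    from typing import Any, Callable

    class ToyTable:
        def __init__(self) -> None:
            self.rows: list[list[int]] = []
        def add_row(self, block_ids: list[int]) -> None:
            self.rows.append(list(block_ids))
        def clear(self) -> None:
            self.rows.clear()

    class ToyMultiTable:
        def __init__(self, num_groups: int) -> None:
            self.tables = [ToyTable() for _ in range(num_groups)]
            # Methods without block IDs are forwarded unchanged to every table.
            for name in ("clear",):
                setattr(self, name, self._broadcast(name))
            # Methods taking block IDs get the per-group slice as first argument.
            for name in ("add_row",):
                setattr(self, name, self._broadcast_with_ids(name))

        def _broadcast(self, name: str) -> Callable[..., None]:
            def fn(*args: Any, **kwargs: Any) -> None:
                for t in self.tables:
                    getattr(t, name)(*args, **kwargs)
            return fn

        def _broadcast_with_ids(self, name: str) -> Callable[..., None]:
            def fn(block_ids: list[list[int]], *args: Any, **kwargs: Any) -> None:
                for i, t in enumerate(self.tables):
                    getattr(t, name)(block_ids[i], *args, **kwargs)
            return fn

    mt = ToyMultiTable(num_groups=2)
    mt.add_row([[1, 2], [10]])          # group 0 gets [1, 2], group 1 gets [10]
    print([t.rows for t in mt.tables])  # [[[1, 2]], [[10]]]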
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 637367a70d2a..2945f5f0ee11 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3,7 +3,7 @@ import gc import time import weakref -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional, Type, Union, cast import numpy as np import torch @@ -11,6 +11,7 @@ import torch.nn as nn from vllm.attention import AttentionType, get_attn_backend +from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadataBuilder from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig from vllm.distributed.parallel_state import get_pp_group, graph_capture @@ -39,6 +40,7 @@ from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.utils import is_spec_decode_supported from vllm.v1.utils import bind_kv_cache +from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin @@ -89,50 +91,14 @@ def __init__( self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ cache_config.cache_dtype] - # NOTE(woosuk): sliding_window is None for models with interleaved - # attention. Use interleaved_sliding_window instead. - self.sliding_window = model_config.get_sliding_window() - self.interleaved_sliding_window = getattr( - model_config.hf_text_config, "interleaved_sliding_window", None) - self.window_size = (self.sliding_window - or self.interleaved_sliding_window) - self.is_multimodal_model = model_config.is_multimodal_model - self.block_size = cache_config.block_size self.max_model_len = model_config.max_model_len - self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs # Model-related. 
- self.num_attn_layers = model_config.get_num_layers_by_block_type( - parallel_config, LayerBlockType.attention) - self.num_query_heads = model_config.get_num_attention_heads( - parallel_config) - self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) - self.head_size = model_config.get_head_size() self.hidden_size = model_config.get_hidden_size() - self.attn_backend = get_attn_backend( - self.head_size, - self.dtype, - self.kv_cache_dtype, - self.block_size, - self.model_config.is_attention_free, - use_mla=self.model_config.use_mla, - ) - if self.attn_backend is None: - error_msg = ( - f"Error with get_att_backend: {self.head_size=}, " - f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, " - f"{self.model_config.is_attention_free=}, " - f"{self.model_config.use_mla=}") - logger.error(error_msg) - raise NotImplementedError( - "Non-Attention backend is not supported by V1 GPUModelRunner.") - - self.attn_metadata_builder = self.attn_backend.get_builder_cls()( - weakref.proxy(self)) self.cascade_attn_enabled = not self.model_config.disable_cascade_attn # Multi-modal data support @@ -149,7 +115,14 @@ def __init__( # Lazy initialization # self.model: nn.Module # Set after load_model + # init in initialize_kv_cache self.kv_caches: list[torch.Tensor] = [] + self.kv_cache_config = cast(KVCacheConfig, None) + self.attn_backends: list[Type[AttentionBackend]] = [] + self.attn_metadata_builders: list[Type[AttentionMetadataBuilder]] = [] + # Persistent batch + self.input_batch = cast(InputBatch, None) + # req_id -> (input_id -> encoder_output) self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} @@ -173,15 +146,6 @@ def __init__( # Request states. self.requests: dict[str, CachedRequestState] = {} - # Persistent batch. - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_blocks_per_req=self.max_num_blocks_per_req, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=model_config.get_vocab_size(), - ) self.use_cuda_graph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE @@ -535,23 +499,35 @@ def _prepare_inputs( torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) - # Calculate the slot mapping. - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] - # where K is the max_num_blocks_per_req and the block size is 2. - # NOTE(woosuk): We can't simply use `token_indices // block_size` here - # because M (max_model_len) is not necessarily divisible by block_size. - block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions_np // self.block_size) - # NOTE(woosuk): We use torch.index_select instead of np.take here - # because torch.index_select is much faster than np.take for large - # tensors. 
- block_table_cpu = self.input_batch.block_table.get_cpu_tensor() - block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() - block_offsets = positions_np % self.block_size - np.add(block_numbers * self.block_size, - block_offsets, - out=self.slot_mapping_np[:total_num_scheduled_tokens]) + if len(self.kv_cache_config.kv_cache_groups) == 1: + may_multi_layer_unwrapper = lambda x, _group_id: x + else: + may_multi_layer_unwrapper = lambda x, group_id: x[group_id] + + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + block_size = kv_cache_group_spec.kv_cache_spec.block_size + block_table: BlockTable = may_multi_layer_unwrapper( + self.input_batch.block_table, kv_cache_group_id) + # Calculate the slot mapping. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` here + # because M (max_model_len) is not necessarily divisible by block_size. + block_table_indices = ( + req_indices * block_table.max_num_blocks_per_req + + positions_np // block_size) + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. + block_table_cpu = self.input_batch.block_table.get_cpu_tensor() + block_numbers = ( + block_table_cpu.flatten()[block_table_indices].numpy()) + block_offsets = positions_np % block_size + np.add(block_numbers * block_size, + block_offsets, + out=self.slot_mapping_np[:total_num_scheduled_tokens]) # Prepare the attention metadata. self.query_start_loc_np[0] = 0 @@ -575,20 +551,25 @@ def _prepare_inputs( self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) - # Prepare for cascade attention if enabled & beneficial. - common_prefix_len = 0 - if self.cascade_attn_enabled: - common_prefix_len = self._compute_cascade_attn_prefix_len( - num_scheduled_tokens, - scheduler_output.num_common_prefix_blocks, - ) + attn_metadata: dict[str, FlashAttentionMetadata] = {} + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + + # Prepare for cascade attention if enabled & beneficial. + common_prefix_len = 0 + if self.cascade_attn_enabled: + common_prefix_len = self._compute_cascade_attn_prefix_len( + num_scheduled_tokens, + scheduler_output.num_common_prefix_blocks, + kv_cache_group_spec.kv_cache_spec, + ) - attn_metadata = self.attn_metadata_builder.build( - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - common_prefix_len=common_prefix_len, - ) + attn_metadata = self.attn_metadata_builder.build( + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + common_prefix_len=common_prefix_len, + ) use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 @@ -624,6 +605,7 @@ def _compute_cascade_attn_prefix_len( self, num_scheduled_tokens: np.ndarray, num_common_prefix_blocks: int, + kv_cache_spec: KVCacheSpec, ) -> int: """Compute the length of the common prefix for cascade attention. @@ -642,7 +624,7 @@ def _compute_cascade_attn_prefix_len( Returns: int: Length of common prefix in tokens. """ - common_prefix_len = num_common_prefix_blocks * self.block_size + common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size if common_prefix_len == 0: # Common case. 
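The per-group slot-mapping loop above repeats the old arithmetic once per KV cache group, since each group may use its own block size and block table. A NumPy sketch of the computation for one group, with toy sizes:

    import numpy as np

    block_size = 2
    max_num_blocks_per_req = 4
    # Flattened block table for 2 requests (rows), up to 4 blocks per request.
    block_table = np.array([[10, 11, 12, 0],
                            [20, 21,  0, 0]], dtype=np.int32)

    req_indices = np.array([0, 0, 0, 1, 1], dtype=np.int64)   # token -> request
    positions   = np.array([0, 1, 2, 0, 1], dtype=np.int64)   # token -> position

    block_table_indices = (req_indices * max_num_blocks_per_req
                           + positions // block_size)
    block_numbers = block_table.flatten()[block_table_indices]
    slot_mapping = block_numbers * block_size + positions % block_size
    print(slot_mapping)   # [20 21 22 40 41]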
return 0 @@ -691,15 +673,18 @@ def _compute_cascade_attn_prefix_len( common_prefix_len, self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) # common_prefix_len should be a multiple of the block size. - common_prefix_len = (common_prefix_len // self.block_size * - self.block_size) + common_prefix_len = (common_prefix_len // kv_cache_spec.block_size * + kv_cache_spec.block_size) + use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) + or (isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.compute_as_sliding_window)) use_cascade = self.attn_backend.use_cascade_attention( common_prefix_len=common_prefix_len, query_lens=num_scheduled_tokens, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, + num_query_heads=kv_cache_spec.num_query_heads, + num_kv_heads=kv_cache_spec.num_kv_heads, use_alibi=self.use_alibi, - use_sliding_window=self.window_size is not None, + use_sliding_window=use_sliding_window, num_sms=self.num_sms, ) return common_prefix_len if use_cascade else 0 @@ -1552,14 +1537,24 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ - if len(kv_cache_config.kv_cache_groups) > 1: - raise NotImplementedError( - "Hybrid models with more than one KV cache type are not " - "supported yet.") + self.kv_cache_config = kv_cache_config + self.initialize_attn_backend(kv_cache_config) + self.initialize_kv_cache_tensors(kv_cache_config) + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + kv_cache_config=kv_cache_config, + ) + def initialize_kv_cache_tensors(self, + kv_cache_config: KVCacheConfig) -> None: kv_caches: dict[str, torch.Tensor] = {} - for kv_cache_group in kv_cache_config.kv_cache_groups: + for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): kv_cache_spec = kv_cache_group.kv_cache_spec for layer_name in kv_cache_group.layer_names: tensor_config = kv_cache_config.tensors[layer_name] @@ -1574,7 +1569,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: # the min of all `num_blocks`. Verify it here. 
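Toy numbers for the block-count check referenced in the comment above; all values are made up, and the page size is roughly 2 (K and V) x block_size x num_kv_heads x head_size x bytes per element for a non-MLA layer.

    page_size_bytes = 2 * 16 * 8 * 128 * 2    # 65536 bytes per block
    tensor_size_bytes = 1_048_576             # hypothetical per-layer allocation
    assert tensor_size_bytes % page_size_bytes == 0
    num_blocks = tensor_size_bytes // page_size_bytes   # 16 blocks on this worker
    kv_cache_config_num_blocks = 12                     # min across all workers
    assert num_blocks >= kv_cache_config_num_blocks

If the final assertion failed, this worker's tensor could not hold every block the KV cache manager might hand out.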
assert num_blocks >= kv_cache_config.num_blocks if isinstance(kv_cache_spec, AttentionSpec): - kv_cache_shape = self.attn_backend.get_kv_cache_shape( + kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype @@ -1591,6 +1586,40 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.vllm_config.compilation_config.static_forward_context, self.kv_caches) + def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: + # TODO: docstring + assert (len(self.attn_backends) == 0 + and len(self.attn_metadata_builders) == 0, + "already initialized") + for kv_cache_group_spec in kv_cache_config.kv_cache_groups: + kv_cache_spec = kv_cache_group_spec.kv_cache_spec + if not isinstance(kv_cache_spec, AttentionSpec): + raise NotImplementedError( + "Only AttentionSpec is supported for now.") + attn_backend_i = get_attn_backend( + kv_cache_spec.head_size, + self.dtype, + kv_cache_spec.dtype, + kv_cache_spec.block_size, + self.model_config.is_attention_free, + use_mla=(isinstance(kv_cache_spec, AttentionSpec) + and kv_cache_spec.use_mla), + ) + if attn_backend_i is None: + error_msg = ( + f"Error with get_attn_backend: {kv_cache_spec.head_size=}, " + f"{self.dtype=}, {kv_cache_spec.kv_cache_dtype=}, {kv_cache_spec.block_size=}, " + f"{self.model_config.is_attention_free=}, " + f"{kv_cache_spec.use_mla=}") + logger.error(error_msg) + raise NotImplementedError( + "Non-Attention backend is not supported by V1 GPUModelRunner." + ) + attn_metadata_builder_i = attn_backend_i.get_builder_cls()( + weakref.proxy(self)) + self.attn_backends.append(attn_backend_i) + self.attn_metadata_builders.append(attn_metadata_builder_i) + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each @@ -1615,16 +1644,18 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, + num_query_heads=attn_module.num_heads, + num_kv_heads=attn_module.num_kv_heads, dtype=self.kv_cache_dtype, sliding_window=attn_module.sliding_window, use_mla=use_mla) else: kv_cache_spec[layer_name] = FullAttentionSpec( block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, + num_query_heads=attn_module.num_heads, + num_kv_heads=attn_module.num_kv_heads, dtype=self.kv_cache_dtype, use_mla=use_mla) elif attn_module.attn_type in (AttentionType.ENCODER, From 4dea38d9303b100272cbf4f4dece0c09bd38c349 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 5 Apr 2025 10:21:37 +0000 Subject: [PATCH 03/34] can run Signed-off-by: Chen Zhang --- vllm/v1/core/hybrid_kv_cache_manager.py | 6 +- vllm/v1/core/kv_cache_manager.py | 8 +- vllm/v1/core/sched/scheduler.py | 4 +- vllm/v1/worker/block_table.py | 5 +- vllm/v1/worker/gpu_model_runner.py | 160 ++++++++++++++++-------- 5 files changed, 125 insertions(+), 58 deletions(-) diff --git a/vllm/v1/core/hybrid_kv_cache_manager.py b/vllm/v1/core/hybrid_kv_cache_manager.py index 9dbf1bfd3606..05e1d478677a 100644 --- a/vllm/v1/core/hybrid_kv_cache_manager.py +++ b/vllm/v1/core/hybrid_kv_cache_manager.py @@ -315,9 +315,9 @@ def allocate_slots( self.block_pool.cache_full_blocks( request=request, - blocks=req_blocks, - block_hashes=self.req_to_block_hashes[request.request_id], - 
num_cached_blocks=num_cached_blocks, + blocks=req_blocks[i], + block_hashes=self.req_to_block_hashes[request.request_id][i], + num_cached_blocks=num_cached_blocks[i], num_full_blocks=num_full_blocks_after_append, block_size=self.specialized_managers[i].block_size, hash_fn=self.caching_hash_fn, diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index bea01d399a6d..af06f9919267 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -385,7 +385,9 @@ def init_kv_cache_manager( kv_cache_config: KVCacheConfig, max_model_len: int, enable_caching: bool = True, - num_preallocate_tokens: int = 64 + caching_hash_algo: str = "builtin", + num_preallocate_tokens: int = 64, + log_stats: bool = False, ) -> Union[KVCacheManager, "HybridKVCacheManager"]: from vllm.v1.core.hybrid_kv_cache_manager import HybridKVCacheManager if len(kv_cache_config.kv_cache_groups) > 1: @@ -393,12 +395,16 @@ def init_kv_cache_manager( kv_cache_config=kv_cache_config, max_model_len=max_model_len, enable_caching=enable_caching, + caching_hash_algo=caching_hash_algo, num_preallocate_tokens=num_preallocate_tokens, + log_stats=log_stats, ) else: return KVCacheManager( kv_cache_config=kv_cache_config, max_model_len=max_model_len, enable_caching=enable_caching, + caching_hash_algo=caching_hash_algo, num_preallocate_tokens=num_preallocate_tokens, + log_stats=log_stats, ) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c670fa3a618b..856cc93f4df7 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -350,8 +350,8 @@ def schedule(self) -> SchedulerOutput: if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) req_to_new_block_ids[ - request.request_id] = BlockIDGenerator.generate( - computed_blocks) + BlockIDGenerator.generate( + request.request_id] = BlockIDGenerator.from_kv_cache_blocks( + computed_blocks) + BlockIDGenerator.from_kv_cache_blocks( new_blocks) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 89daddfd3169..18a4332e5630 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -140,8 +140,9 @@ def _make_broadcast_func_with_block_ids( def broadcast_func(block_ids: "MayMultiLayerBlockIDs", *args: P.args, **kwargs: P.kwargs) -> None: for i, block_table in enumerate(self.block_tables): - getattr(block_table, f_name)(block_ids.get_virtual_layer(i), - *args, **kwargs) + getattr(block_table, + f_name)(block_ids.get_block_id_of_group(i), *args, + **kwargs) return broadcast_func diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2945f5f0ee11..51f4764b74d8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -31,7 +31,8 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheSpec, - SlidingWindowSpec) + SlidingWindowSpec, KVCacheNewTensor, + KVCacheReuseTensor) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) from vllm.v1.sample.metadata import SamplingMetadata @@ -437,7 +438,8 @@ def _prepare_inputs( # Some attention backends (namely MLA) may want to separate requests # based on if the attention computation will be compute-bound or # memory-bound. This gives them a hook to do that. 
- modified_batch = self.attn_metadata_builder.reorder_batch( + # NOTE: only same builder is supported now + modified_batch = self.attn_metadata_builders[0].reorder_batch( self.input_batch, scheduler_output) if modified_batch: self.input_batch.refresh_sampling_metadata() @@ -521,7 +523,7 @@ def _prepare_inputs( # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. - block_table_cpu = self.input_batch.block_table.get_cpu_tensor() + block_table_cpu = block_table.get_cpu_tensor() block_numbers = ( block_table_cpu.flatten()[block_table_indices].numpy()) block_offsets = positions_np % block_size @@ -560,15 +562,20 @@ def _prepare_inputs( if self.cascade_attn_enabled: common_prefix_len = self._compute_cascade_attn_prefix_len( num_scheduled_tokens, - scheduler_output.num_common_prefix_blocks, + scheduler_output. + num_common_prefix_blocks[kv_cache_group_id], kv_cache_group_spec.kv_cache_spec, + self.attn_backends[kv_cache_group_id], ) - attn_metadata = self.attn_metadata_builder.build( + block_table = may_multi_layer_unwrapper( + self.input_batch.block_table, kv_cache_group_id) + attn_metadata = self.attn_metadata_builders[0].build( num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, common_prefix_len=common_prefix_len, + block_table=block_table, ) use_spec_decode = len( @@ -606,6 +613,7 @@ def _compute_cascade_attn_prefix_len( num_scheduled_tokens: np.ndarray, num_common_prefix_blocks: int, kv_cache_spec: KVCacheSpec, + attn_backend: AttentionBackend, ) -> int: """Compute the length of the common prefix for cascade attention. @@ -678,7 +686,7 @@ def _compute_cascade_attn_prefix_len( use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or (isinstance(kv_cache_spec, FullAttentionSpec) and kv_cache_spec.compute_as_sliding_window)) - use_cascade = self.attn_backend.use_cascade_attention( + use_cascade = attn_backend.use_cascade_attention( common_prefix_len=common_prefix_len, query_lens=num_scheduled_tokens, num_query_heads=kv_cache_spec.num_query_heads, @@ -1530,57 +1538,77 @@ def capture_model(self) -> None: logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / (1 << 30)) - def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: + def _initialize_kv_cache_buffer( + self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: """ - Initialize KV cache based on `kv_cache_config`. + Initializes the KV cache buffer with the correct size. The buffer needs + to be reshaped to the desired shape before being used by the models. Args: - kv_cache_config: Configuration for the KV cache, including the KV - cache size of each layer + kv_cache_config: The KV cache config + Returns: + dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. 
+ """ + kv_cache_raw_tensors: dict[str, torch.Tensor] = {} + for layer_name, tensor_config in kv_cache_config.tensors.items(): + if isinstance(tensor_config, KVCacheNewTensor): + # A new tensor with `tensor_config.size` bytes + kv_cache_raw_tensors[layer_name] = torch.zeros( + tensor_config.size, dtype=torch.int8, device=self.device) + for layer_name, tensor_config in kv_cache_config.tensors.items(): + if isinstance(tensor_config, KVCacheReuseTensor): + # Reuse a tensor from `kv_cache_raw_tensors` + kv_cache_raw_tensors[layer_name] = kv_cache_raw_tensors[ + tensor_config.reused_layer_name] + assert len(kv_cache_raw_tensors) == len( + kv_cache_config.tensors), "Some layers are not initialized" + return kv_cache_raw_tensors + + def _setup_kv_cache_shapes( + self, + kv_cache_config: KVCacheConfig, + kv_cache_raw_tensors: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + """ + Reshape the KV cache tensors to the desired shape. + Args: + kv_cache_config: The KV cache config + kv_cache_raw_tensors: The KV cache buffer of each layer, with + correct size but uninitialized shape. + Returns: + Dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. """ - self.kv_cache_config = kv_cache_config - self.initialize_attn_backend(kv_cache_config) - self.initialize_kv_cache_tensors(kv_cache_config) - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_tokens=self.max_num_tokens, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=self.model_config.get_vocab_size(), - kv_cache_config=kv_cache_config, - ) - - def initialize_kv_cache_tensors(self, - kv_cache_config: KVCacheConfig) -> None: kv_caches: dict[str, torch.Tensor] = {} - - for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): - kv_cache_spec = kv_cache_group.kv_cache_spec - for layer_name in kv_cache_group.layer_names: - tensor_config = kv_cache_config.tensors[layer_name] - assert tensor_config.size % kv_cache_spec.page_size_bytes == 0 - num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes - # `num_blocks` is the number of blocks the model runner can use. - # `kv_cache_config.num_blocks` is the number of blocks that - # KVCacheManager may allocate. - # Since different GPUs may have different number of layers and - # different memory capacities, `num_blocks` can be different on - # different GPUs, and `kv_cache_config.num_blocks` is set to - # the min of all `num_blocks`. Verify it here. - assert num_blocks >= kv_cache_config.num_blocks - if isinstance(kv_cache_spec, AttentionSpec): + for i, kv_cache_group_spec in enumerate( + kv_cache_config.kv_cache_groups): + kv_cache_spec = kv_cache_group_spec.kv_cache_spec + for layer_name in kv_cache_group_spec.layer_names: + raw_tensor = kv_cache_raw_tensors[layer_name] + assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 + num_blocks = (raw_tensor.numel() // + kv_cache_spec.page_size_bytes) + if isinstance(kv_cache_spec, + (FullAttentionSpec, SlidingWindowSpec)): kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype - kv_caches[layer_name] = torch.zeros(kv_cache_shape, - dtype=dtype, - device=self.device) + kv_caches[layer_name] = kv_cache_raw_tensors[ + layer_name].view(dtype).view(kv_cache_shape) else: - # TODO: add new branches when introducing more types of - # KV cache specs. 
- raise ValueError("Unknown KV cache spec type.") + raise NotImplementedError + return kv_caches + def initialize_kv_cache_tensors(self, + kv_cache_config: KVCacheConfig) -> None: + # TODO: docstring + # Initialize the memory buffer for KV cache + kv_cache_raw_tensors = self._initialize_kv_cache_buffer( + kv_cache_config) + # Change the memory buffer to the desired shape + kv_caches = self._setup_kv_cache_shapes(kv_cache_config, + kv_cache_raw_tensors) bind_kv_cache( kv_caches, self.vllm_config.compilation_config.static_forward_context, @@ -1588,9 +1616,8 @@ def initialize_kv_cache_tensors(self, def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: # TODO: docstring - assert (len(self.attn_backends) == 0 - and len(self.attn_metadata_builders) == 0, - "already initialized") + assert len(self.attn_backends) == 0 and len( + self.attn_metadata_builders) == 0, "already initialized" for kv_cache_group_spec in kv_cache_config.kv_cache_groups: kv_cache_spec = kv_cache_group_spec.kv_cache_spec if not isinstance(kv_cache_spec, AttentionSpec): @@ -1620,6 +1647,39 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: self.attn_backends.append(attn_backend_i) self.attn_metadata_builders.append(attn_metadata_builder_i) + print("reorder_batch", [ + id(builder.reorder_batch.__func__) + for builder in self.attn_metadata_builders + ]) + print("is check", [ + builder.reorder_batch.__func__ + is self.attn_metadata_builders[0].reorder_batch.__func__ + for builder in self.attn_metadata_builders + ]) + assert all(builder.reorder_batch.__func__ is + self.attn_metadata_builders[0].reorder_batch.__func__ + for builder in self.attn_metadata_builders), "TODO" + + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: + """ + Initialize KV cache based on `kv_cache_config`. 
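The assertion at the end of initialize_attn_backend checks, via __func__ identity, that every per-group metadata builder shares the same reorder_batch implementation, because the runner only calls it on builder 0. A tiny standalone illustration with made-up classes:

    class BaseBuilder:
        def reorder_batch(self, batch):        # shared default implementation
            return False

    class GroupABuilder(BaseBuilder):
        pass

    class GroupBBuilder(BaseBuilder):
        pass

    builders = [GroupABuilder(), GroupBBuilder()]
    assert all(b.reorder_batch.__func__ is builders[0].reorder_batch.__func__
               for b in builders)   # passes: both inherit BaseBuilder's method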
+ Args: + kv_cache_config: Configuration for the KV cache, including the KV + cache size of each layer + """ + self.kv_cache_config = kv_cache_config + self.initialize_attn_backend(kv_cache_config) + self.initialize_kv_cache_tensors(kv_cache_config) + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + kv_cache_config=kv_cache_config, + ) + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each From 55720e0449ea374d63b1df93242ff5554794af01 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 5 Apr 2025 06:37:11 -0700 Subject: [PATCH 04/34] can pass e2e tests Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/flash_attn.py | 4 +- vllm/v1/core/hybrid_kv_cache_manager.py | 15 ++---- vllm/v1/core/sched/scheduler.py | 7 ++- vllm/v1/worker/block_table.py | 14 +++-- vllm/v1/worker/gpu_model_runner.py | 67 +++++++++++------------- 5 files changed, 48 insertions(+), 59 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 47af053a6837..85a5dd77624e 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -11,10 +11,10 @@ AttentionMetadata, AttentionType, is_quantized_kv_cache) from vllm.attention.ops.triton_merge_attn_states import merge_attn_states -from vllm.core.block.block_table import BlockTable from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv +from vllm.v1.worker.block_table import BlockTable from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8, get_flash_attn_version) @@ -115,7 +115,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device, non_blocking=True) block_table_tensor = block_table.get_device_tensor()[:num_reqs] - slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( + slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].to( self.runner.device, non_blocking=True).long() use_cascade = common_prefix_len > 0 diff --git a/vllm/v1/core/hybrid_kv_cache_manager.py b/vllm/v1/core/hybrid_kv_cache_manager.py index 05e1d478677a..4ff76b986988 100644 --- a/vllm/v1/core/hybrid_kv_cache_manager.py +++ b/vllm/v1/core/hybrid_kv_cache_manager.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from collections import defaultdict -from collections.abc import Iterable from typing import Optional from vllm.logger import init_logger @@ -208,9 +207,9 @@ def allocate_slots( # Should call this function before allocating new blocks to reduce # the number of evicted blocks. 
removed_blocks = [ - manager.remove_skipped_blocks(req_blocks, + manager.remove_skipped_blocks(req_blocks[i], request.num_computed_tokens) - for manager in self.specialized_managers + for i, manager in enumerate(self.specialized_managers) ] self._free_blocks(removed_blocks) @@ -225,8 +224,8 @@ def allocate_slots( num_required_blocks_i = cdiv( num_computed_tokens + num_tokens, self.specialized_managers[i].block_size) - num_new_blocks.append((num_required_blocks_i - len(req_blocks[i]) - - len(new_computed_blocks[i]))) + num_new_blocks.append(num_required_blocks_i - len(req_blocks[i]) - + len(new_computed_blocks[i])) total_num_new_blocks = sum(max(x, 0) for x in num_new_blocks) # If a computed block of a request is an eviction candidate (in the @@ -244,12 +243,8 @@ def allocate_slots( if self.enable_caching: for blocks in new_computed_blocks: self.block_pool.touch(blocks) - else: - assert all(len(blks) == 0 for blks in new_computed_blocks), ( - "Computed blocks should be empty when " - "prefix caching is disabled") else: - assert not new_computed_blocks, ( + assert all(len(blks) == 0 for blks in new_computed_blocks), ( "Computed blocks should be empty when " "prefix caching is disabled") diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 856cc93f4df7..794dfd01ba66 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -349,10 +349,9 @@ def schedule(self) -> SchedulerOutput: if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) - req_to_new_block_ids[ - request.request_id] = BlockIDGenerator.from_kv_cache_blocks( - computed_blocks) + BlockIDGenerator.from_kv_cache_blocks( - new_blocks) + req_to_new_block_ids[request.request_id] = ( + BlockIDGenerator.from_kv_cache_blocks(computed_blocks) + + BlockIDGenerator.from_kv_cache_blocks(new_blocks)) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 18a4332e5630..fd8583972b59 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Callable, Concatenate, ParamSpec, Union + import numpy as np import torch @@ -106,8 +107,8 @@ class MultiLayerBlockTable: commit: Callable[P, None] clear: Callable[P, None] - append_row: Callable[Concatenate["MayMultiLayerBlockIDs", P], None] - add_row: Callable[Concatenate["MayMultiLayerBlockIDs", P], None] + append_row: Callable[Concatenate[list[int], P], None] + add_row: Callable[Concatenate[list[int], P], None] def __init__(self, max_num_reqs: int, max_num_blocks_per_req: list[int], max_num_tokens: int, pin_memory: bool, device: torch.device, @@ -134,15 +135,12 @@ def broadcast_func(*args: P.args, **kwargs: P.kwargs) -> None: return broadcast_func def _make_broadcast_func_with_block_ids( - self, f_name: str - ) -> Callable[Concatenate["MayMultiLayerBlockIDs", P], None]: + self, f_name: str) -> Callable[Concatenate[list[int], P], None]: - def broadcast_func(block_ids: "MayMultiLayerBlockIDs", *args: P.args, + def broadcast_func(block_ids: list[int], *args: P.args, **kwargs: P.kwargs) -> None: for i, block_table in enumerate(self.block_tables): - getattr(block_table, - f_name)(block_ids.get_block_id_of_group(i), *args, - **kwargs) + getattr(block_table, f_name)(block_ids[i], *args, **kwargs) return broadcast_func diff --git a/vllm/v1/worker/gpu_model_runner.py 
b/vllm/v1/worker/gpu_model_runner.py index 51f4764b74d8..d741eed1d0b7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -11,7 +11,8 @@ import torch.nn as nn from vllm.attention import AttentionType, get_attn_backend -from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadataBuilder +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadataBuilder) from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig from vllm.distributed.parallel_state import get_pp_group, graph_capture @@ -25,14 +26,14 @@ from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, LayerBlockType, LazyLoader, cdiv, - check_use_alibi, is_pin_memory_available) + GiB_bytes, LazyLoader, cdiv, check_use_alibi, + is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, - KVCacheConfig, KVCacheSpec, - SlidingWindowSpec, KVCacheNewTensor, - KVCacheReuseTensor) + KVCacheConfig, KVCacheNewTensor, + KVCacheReuseTensor, KVCacheSpec, + SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) from vllm.v1.sample.metadata import SamplingMetadata @@ -220,11 +221,6 @@ def __init__( device="cpu", pin_memory=self.pin_memory) self.positions_np = self.positions_cpu.numpy() - self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.slot_mapping_np = self.slot_mapping_cpu.numpy() self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, dtype=torch.int32, device="cpu", @@ -428,7 +424,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: def _prepare_inputs( self, scheduler_output: "SchedulerOutput", - ) -> tuple[FlashAttentionMetadata, torch.Tensor, + ) -> tuple[dict[str, FlashAttentionMetadata], torch.Tensor, Optional[SpecDecodeMetadata]]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 @@ -527,9 +523,10 @@ def _prepare_inputs( block_numbers = ( block_table_cpu.flatten()[block_table_indices].numpy()) block_offsets = positions_np % block_size - np.add(block_numbers * block_size, - block_offsets, - out=self.slot_mapping_np[:total_num_scheduled_tokens]) + np.add( + block_numbers * block_size, + block_offsets, + out=block_table.slot_mapping_np[:total_num_scheduled_tokens]) # Prepare the attention metadata. self.query_start_loc_np[0] = 0 @@ -562,21 +559,25 @@ def _prepare_inputs( if self.cascade_attn_enabled: common_prefix_len = self._compute_cascade_attn_prefix_len( num_scheduled_tokens, - scheduler_output. 
- num_common_prefix_blocks[kv_cache_group_id], + may_multi_layer_unwrapper( + scheduler_output.num_common_prefix_blocks, + kv_cache_group_id), kv_cache_group_spec.kv_cache_spec, self.attn_backends[kv_cache_group_id], ) block_table = may_multi_layer_unwrapper( self.input_batch.block_table, kv_cache_group_id) - attn_metadata = self.attn_metadata_builders[0].build( - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - common_prefix_len=common_prefix_len, - block_table=block_table, - ) + attn_metadata_i = ( + self.attn_metadata_builders[kv_cache_group_id].build( + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + common_prefix_len=common_prefix_len, + block_table=block_table, + )) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 @@ -586,7 +587,8 @@ def _prepare_inputs( # from these partial requests, we do so for simplicity. # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. - logits_indices = attn_metadata.query_start_loc[1:] - 1 + # TODO: add note for attn_metadata_i + logits_indices = attn_metadata_i.query_start_loc[1:] - 1 spec_decode_metadata = None else: # Get the number of draft tokens for each request. @@ -996,7 +998,11 @@ def execute_model( else: # Eager mode. num_input_tokens = num_scheduled_tokens - attn_metadata.num_input_tokens = num_input_tokens + + for kv_cache_group_spec in self.kv_cache_config.kv_cache_groups: + # TODO: notes for use layer_names[0] + layer_name = kv_cache_group_spec.layer_names[0] + attn_metadata[layer_name].num_input_tokens = num_input_tokens if self.is_multimodal_model: # NOTE(woosuk): To unify token ids and soft tokens (vision @@ -1647,15 +1653,6 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: self.attn_backends.append(attn_backend_i) self.attn_metadata_builders.append(attn_metadata_builder_i) - print("reorder_batch", [ - id(builder.reorder_batch.__func__) - for builder in self.attn_metadata_builders - ]) - print("is check", [ - builder.reorder_batch.__func__ - is self.attn_metadata_builders[0].reorder_batch.__func__ - for builder in self.attn_metadata_builders - ]) assert all(builder.reorder_batch.__func__ is self.attn_metadata_builders[0].reorder_batch.__func__ for builder in self.attn_metadata_builders), "TODO" From 273dd443c485d5c8e563c822d7078543f5eb81cf Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 5 Apr 2025 09:16:41 -0700 Subject: [PATCH 05/34] run precommit Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/flash_attn.py | 3 +- vllm/v1/attention/backends/mla/common.py | 16 ++- vllm/v1/attention/backends/mla/flashmla.py | 7 +- vllm/v1/core/hybrid_kv_cache_manager.py | 119 +++++++++++++++------ vllm/v1/core/kv_cache_manager.py | 71 ++++++++---- vllm/v1/core/sched/output.py | 9 +- vllm/v1/core/sched/scheduler.py | 18 ++-- vllm/v1/core/specialized_manager.py | 36 ++++--- vllm/v1/kv_cache_interface.py | 65 +---------- vllm/v1/worker/block_table.py | 1 + vllm/v1/worker/gpu_input_batch.py | 7 +- vllm/v1/worker/gpu_model_runner.py | 47 +++++--- vllm/v1/worker/tpu_model_runner.py | 1 + vllm/v1/worker/tpu_worker.py | 2 +- 14 files changed, 232 insertions(+), 170 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 85a5dd77624e..028f71db722e 100755 --- 
a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -14,6 +14,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv +from vllm.v1.kv_cache_interface import KVCacheSpec from vllm.v1.worker.block_table import BlockTable from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8, get_flash_attn_version) @@ -100,7 +101,7 @@ class FlashAttentionMetadata: class FlashAttentionMetadataBuilder: - def __init__(self, runner: "GPUModelRunner"): + def __init__(self, runner: "GPUModelRunner", kv_cache_spec: KVCacheSpec): self.runner = runner def reorder_batch(self, input_batch: "InputBatch", diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 1437db7e9d48..72bb9e76a1a7 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -203,6 +203,8 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.platforms import current_platform from vllm.utils import cdiv, round_down +from vllm.v1.kv_cache_interface import KVCacheSpec +from vllm.v1.worker.block_table import BlockTable from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version try: @@ -342,6 +344,8 @@ class MLACommonMetadataBuilder(Generic[M]): def __init__(self, runner: "GPUModelRunner", + kv_cache_spec: KVCacheSpec, + persistent_block_table: BlockTable, metadata_cls: Optional[type[M]] = None): self.metadata_cls = metadata_cls \ if metadata_cls is not None else MLACommonMetadata @@ -375,7 +379,8 @@ def __init__(self, dtype=model_config.dtype, device=runner.device, ) - self.page_size = self.runner.block_size + self.page_size = kv_cache_spec.block_size + self.persistent_block_table = persistent_block_table def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -455,12 +460,13 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, # function. We should avoid GPU -> CPU sync as much as possible because # it blocks on all previous kernels. device = self.runner.device - block_table = ( - self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) + block_table = (self.persistent_block_table.block_table. + get_device_tensor()[:num_reqs]) query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( device, non_blocking=True) - slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( - device, non_blocking=True).long() + slot_mapping = (self.persistent_block_table. 
+ slot_mapping_cpu[:num_actual_tokens].to( + device, non_blocking=True).long()) input_positions = self.runner.positions_cpu[:num_actual_tokens].to( device, non_blocking=True).long() diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 143bfe35bb5e..8ee588323296 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -16,6 +16,8 @@ MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) +from vllm.v1.kv_cache_interface import KVCacheSpec +from vllm.v1.worker.block_table import BlockTable logger = init_logger(__name__) @@ -52,8 +54,9 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): - def __init__(self, runner): - super().__init__(runner) + def __init__(self, runner, kv_cache_spec: KVCacheSpec, + persistent_block_table: BlockTable): + super().__init__(runner, kv_cache_spec, persistent_block_table) self.num_q_heads = self.runner.model_config.get_num_attention_heads( self.runner.parallel_config) diff --git a/vllm/v1/core/hybrid_kv_cache_manager.py b/vllm/v1/core/hybrid_kv_cache_manager.py index 4ff76b986988..ebfffdddc796 100644 --- a/vllm/v1/core/hybrid_kv_cache_manager.py +++ b/vllm/v1/core/hybrid_kv_cache_manager.py @@ -1,12 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 from collections import defaultdict +from dataclasses import dataclass from typing import Optional from vllm.logger import init_logger from vllm.utils import cdiv, sha256 from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.core.kv_cache_manager import KVCacheBlocksInterface from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, hash_request_tokens) from vllm.v1.core.specialized_manager import get_specialized_manager @@ -17,6 +18,20 @@ logger = init_logger(__name__) +@dataclass +class HybridKVCacheBlocks(KVCacheBlocksInterface): + blocks: list[list[KVCacheBlock]] + + def to_block_ids(self) -> list[list[int]]: + return [[blk.block_id for blk in blk_one_layer] + for blk_one_layer in self.blocks] + + def __add__(self, + other: "KVCacheBlocksInterface") -> "KVCacheBlocksInterface": + assert isinstance(other, HybridKVCacheBlocks) + return HybridKVCacheBlocks(self.blocks + other.blocks) + + class HybridKVCacheManager: """ The HybridKVCacheManager for models with multiple KV cache types @@ -97,11 +112,27 @@ def __init__( self.num_cached_block: dict[str, list[int]] = {} self.prefix_cache_stats = PrefixCacheStats() - usage = KVCacheManager.usage - make_prefix_cache_stats = KVCacheManager.make_prefix_cache_stats + @property + def usage(self) -> float: + """Get the KV cache usage. + + Returns: + The KV cache usage (between 0.0 and 1.0). + """ + return self.block_pool.get_usage() + + def make_prefix_cache_stats(self) -> PrefixCacheStats: + """Get (and reset) the prefix cache stats. + + Returns: + The current prefix caching stats. + """ + stats = self.prefix_cache_stats + self.prefix_cache_stats = PrefixCacheStats() + return stats def get_computed_blocks( - self, request: Request) -> tuple[list[list[KVCacheBlock]], int]: + self, request: Request) -> tuple[HybridKVCacheBlocks, int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. @@ -115,7 +146,10 @@ def get_computed_blocks( """ if not self.enable_caching: # Prefix caching is disabled. 
- return [[] for _ in range(self.num_kv_cache_groups)], 0 + computed_blocks: list[list[KVCacheBlock]] = [ + [] for _ in range(self.num_kv_cache_groups) + ] + return HybridKVCacheBlocks(computed_blocks), 0 # The block hashes for the request may already be computed # if the scheduler has tried to schedule the request before. @@ -153,18 +187,18 @@ def get_computed_blocks( self.prefix_cache_stats.queries += len(block_hashes) self.prefix_cache_stats.hits += len(computed_blocks) - return computed_blocks, num_computed_tokens + return HybridKVCacheBlocks(computed_blocks), num_computed_tokens else: # Skip cache hits for prompt logprobs - return [], 0 + return HybridKVCacheBlocks([]), 0 def allocate_slots( self, request: Request, num_tokens: int, - new_computed_blocks: Optional[list[KVCacheBlock]] = None, + new_computed_blocks: Optional[KVCacheBlocksInterface] = None, num_new_computed_tokens: int = 0, - ) -> Optional[list[KVCacheBlock]]: + ) -> Optional[HybridKVCacheBlocks]: """Add slots for a request with new tokens to append. Args: @@ -194,9 +228,13 @@ def allocate_slots( if num_tokens == 0: raise ValueError("num_tokens must be greater than 0") - new_computed_blocks = new_computed_blocks or [ - [] for _ in range(self.num_kv_cache_groups) - ] + if new_computed_blocks is not None: + assert isinstance(new_computed_blocks, HybridKVCacheBlocks) + new_computed_block_list = new_computed_blocks.blocks + else: + new_computed_block_list = ([ + [] for _ in range(self.num_kv_cache_groups) + ]) req_blocks = self.req_to_blocks[request.request_id] @@ -225,15 +263,15 @@ def allocate_slots( num_computed_tokens + num_tokens, self.specialized_managers[i].block_size) num_new_blocks.append(num_required_blocks_i - len(req_blocks[i]) - - len(new_computed_blocks[i])) + len(new_computed_block_list[i])) total_num_new_blocks = sum(max(x, 0) for x in num_new_blocks) # If a computed block of a request is an eviction candidate (in the # free queue and ref_cnt == 0), it cannot be counted as a free block # when allocating this request. num_evictable_computed_blocks = sum( - 1 for blk_one_layer in new_computed_blocks for blk in blk_one_layer - if blk.ref_cnt == 0) + 1 for blk_one_layer in new_computed_block_list + for blk in blk_one_layer if blk.ref_cnt == 0) if (total_num_new_blocks > self.block_pool.get_num_free_blocks() - num_evictable_computed_blocks): # Cannot allocate new blocks @@ -241,17 +279,17 @@ def allocate_slots( # Touch the computed blocks to make sure they won't be evicted. if self.enable_caching: - for blocks in new_computed_blocks: + for blocks in new_computed_block_list: self.block_pool.touch(blocks) else: - assert all(len(blks) == 0 for blks in new_computed_blocks), ( + assert all(len(blks) == 0 for blks in new_computed_block_list), ( "Computed blocks should be empty when " "prefix caching is disabled") # Append the new computed blocks to the request blocks until now to # avoid the case where the new blocks cannot be allocated. for i in range(self.num_kv_cache_groups): - req_blocks[i].extend(new_computed_blocks[i]) + req_blocks[i].extend(new_computed_block_list[i]) # Start to handle new blocks new_blocks: list[list[KVCacheBlock]] = [] @@ -293,13 +331,13 @@ def allocate_slots( req_blocks[i].extend(new_blocks_this_layer) if not self.enable_caching: - return new_blocks + return HybridKVCacheBlocks(new_blocks) - # Use `new_computed_blocks` for a new request, and `num_cached_block` - # for a running request. + # Use `new_computed_block_list` for a new request, and + # `num_cached_block` for a running request. 
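# A minimal, self-contained sketch (not part of the patch itself) of the
# per-group allocation feasibility check shown above. `Block`, `cdiv` and the
# toy numbers below are illustrative stand-ins, not vLLM APIs.
from dataclasses import dataclass


def cdiv(a: int, b: int) -> int:
    return -(-a // b)


@dataclass
class Block:
    ref_cnt: int = 0


def can_allocate(num_computed_tokens: int, num_new_tokens: int,
                 group_block_sizes: list[int],
                 req_blocks: list[list[Block]],
                 new_computed_blocks: list[list[Block]],
                 num_free_blocks: int) -> bool:
    # Blocks needed per group; each group may use a different block size.
    num_new_blocks = [
        cdiv(num_computed_tokens + num_new_tokens, block_size)
        - len(req_blocks[i]) - len(new_computed_blocks[i])
        for i, block_size in enumerate(group_block_sizes)
    ]
    total = sum(max(n, 0) for n in num_new_blocks)
    # Computed blocks with ref_cnt == 0 are still eviction candidates in the
    # free queue, so they must not be double-counted as free blocks.
    evictable = sum(1 for group in new_computed_blocks for blk in group
                    if blk.ref_cnt == 0)
    return total <= num_free_blocks - evictable


# Two groups (block sizes 16 and 4), 10 computed + 30 new tokens, 20 free blocks.
print(can_allocate(10, 30, [16, 4], [[], []], [[Block(ref_cnt=1)], []], 20))  # True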
num_cached_blocks = self.num_cached_block.get( request.request_id, - [len(blocks) for blocks in new_computed_blocks]) + [len(blocks) for blocks in new_computed_block_list]) # Speculated tokens might be rejected in the future, so we does # not cache any speculated tokens. We only cache blocks with # generated (accepted) tokens. @@ -321,7 +359,7 @@ def allocate_slots( num_cached_blocks[i] = num_full_blocks_after_append self.num_cached_block[request.request_id] = num_cached_blocks - return new_blocks + return HybridKVCacheBlocks(new_blocks) def free(self, request: Request) -> None: """Free the blocks allocated for the request. @@ -340,7 +378,19 @@ def free(self, request: Request) -> None: self.num_cached_block.pop(request.request_id, None) - reset_prefix_cache = KVCacheManager.reset_prefix_cache + def reset_prefix_cache(self) -> bool: + """Reset prefix cache. This function may be used in RLHF + flows to invalid prefix caching after the weights are updated, + or used for resetting prefix caching status for benchmarking. + + Returns: + bool: True if the prefix cache is successfully reset, + False otherwise. + """ + if self.block_pool.reset_prefix_cache(): + self.prefix_cache_stats.reset = True + return True + return False def get_num_common_prefix_blocks( self, @@ -394,10 +444,16 @@ def get_num_common_prefix_blocks( num_common_blocks.append(num_common_blocks_i) return num_common_blocks - free_block_hashes = KVCacheManager.free_block_hashes + def free_block_hashes(self, request: Request) -> None: + """Discard the block hashes for the request. + + NOTE: Unlike `free`, this method should be called only when the request + is finished, not when it is preempted. + """ + self.req_to_block_hashes.pop(request.request_id, None) def find_longest_cache_hit( - self, request_id: int, block_hashes: list[list[BlockHashType]] + self, request_id: str, block_hashes: list[list[BlockHashType]] ) -> tuple[list[list[KVCacheBlock]], int]: """Find the longest cache hit for each kv cache group. 
TODO: add more notes @@ -411,13 +467,14 @@ def find_longest_cache_hit( # Use copy to avoid modifying the original block_hashes block_hashes = [block_hash.copy() for block_hash in block_hashes] - while not max(num_computed_tokens) == min_computed_tokens: + while max(num_computed_tokens) != min_computed_tokens: for i, manager in enumerate(self.specialized_managers): if num_computed_tokens[i] > min_computed_tokens: del block_hashes[i][:min_computed_tokens // manager.block_size] - computed_blocks_group_i = manager.find_longest_cache_hit( - request_id, block_hashes[i], return_const_list=True) + computed_blocks_group_i = ( + manager.find_longest_cache_hit_multiple_calls( + request_id, block_hashes[i])) num_computed_tokens[i] = len(computed_blocks_group_i) * \ manager.block_size @@ -426,9 +483,7 @@ def find_longest_cache_hit( # Get the non-constlist computed blocks computed_blocks = [ - manager.find_longest_cache_hit(request_id, - block_hashes[i], - return_const_list=False) + manager.find_longest_cache_hit(request_id, block_hashes[i]) for i, manager in enumerate(self.specialized_managers) ] diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index af06f9919267..e1ba97ff8c95 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 +from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Iterable +from dataclasses import dataclass from typing import TYPE_CHECKING, Optional, Union from vllm.logger import init_logger @@ -9,16 +11,44 @@ from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, hash_request_tokens) +from vllm.v1.core.sched.output import MayMultiGroupBlockIDs from vllm.v1.core.specialized_manager import get_specialized_manager from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request, RequestStatus + if TYPE_CHECKING: from vllm.v1.core.hybrid_kv_cache_manager import HybridKVCacheManager logger = init_logger(__name__) +class KVCacheBlocksInterface(ABC): + + @abstractmethod + def to_block_ids(self) -> MayMultiGroupBlockIDs: + raise NotImplementedError + + @abstractmethod + def __add__(self, + other: "KVCacheBlocksInterface") -> "KVCacheBlocksInterface": + raise NotImplementedError + + +@dataclass +class UniformKVCacheBlocks(KVCacheBlocksInterface): + blocks: list[KVCacheBlock] + + def to_block_ids(self) -> list[int]: + return [blk.block_id for blk in self.blocks] + + def __add__(self, + other: "KVCacheBlocksInterface") -> "KVCacheBlocksInterface": + print(f"other: {other} type: {type(other)}") + assert isinstance(other, UniformKVCacheBlocks) + return UniformKVCacheBlocks(self.blocks + other.blocks) + + class KVCacheManager: def __init__( @@ -103,7 +133,7 @@ def make_prefix_cache_stats(self) -> PrefixCacheStats: return stats def get_computed_blocks( - self, request: Request) -> tuple[list[KVCacheBlock], int]: + self, request: Request) -> tuple[UniformKVCacheBlocks, int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. @@ -117,7 +147,7 @@ def get_computed_blocks( """ if not self.enable_caching: # Prefix caching is disabled. - return [], 0 + return UniformKVCacheBlocks([]), 0 # The block hashes for the request may already be computed # if the scheduler has tried to schedule the request before. 
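# A self-contained sketch (not part of the patch itself) of the block-container
# pattern introduced above. `Block`, `UniformBlocks` and `HybridBlocks` are toy
# stand-ins for KVCacheBlock, UniformKVCacheBlocks and HybridKVCacheBlocks; only
# the to_block_ids()/__add__ shape is shown.
from dataclasses import dataclass


@dataclass
class Block:
    block_id: int


@dataclass
class UniformBlocks:
    """Single kv cache group: a flat list of blocks -> a flat list of ids."""
    blocks: list[Block]

    def to_block_ids(self) -> list[int]:
        return [blk.block_id for blk in self.blocks]

    def __add__(self, other: "UniformBlocks") -> "UniformBlocks":
        return UniformBlocks(self.blocks + other.blocks)


@dataclass
class HybridBlocks:
    """One block list per kv cache group -> a nested list of ids."""
    blocks: list[list[Block]]

    def to_block_ids(self) -> list[list[int]]:
        return [[blk.block_id for blk in group] for group in self.blocks]

    def __add__(self, other: "HybridBlocks") -> "HybridBlocks":
        return HybridBlocks(
            [a + b for a, b in zip(self.blocks, other.blocks)])


# The scheduler can then treat both cases uniformly, e.g.
# (computed_blocks + new_blocks).to_block_ids().
print((UniformBlocks([Block(0)]) + UniformBlocks([Block(3)])).to_block_ids())
# [0, 3]
print((HybridBlocks([[Block(0)], []]) +
       HybridBlocks([[Block(3)], [Block(7)]])).to_block_ids())
# [[0, 3], [7]]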
@@ -156,18 +186,18 @@ def get_computed_blocks( # sharing, `num_computed_tokens` is always a multiple of # `block_size`. num_computed_tokens = len(computed_blocks) * self.block_size - return computed_blocks, num_computed_tokens + return UniformKVCacheBlocks(computed_blocks), num_computed_tokens else: # Skip cache hits for prompt logprobs - return [], 0 + return UniformKVCacheBlocks([]), 0 def allocate_slots( self, request: Request, num_tokens: int, - new_computed_blocks: Optional[list[KVCacheBlock]] = None, + new_computed_blocks: Optional[KVCacheBlocksInterface] = None, num_new_computed_tokens: int = 0, - ) -> Optional[list[KVCacheBlock]]: + ) -> Optional[UniformKVCacheBlocks]: """Add slots for a request with new tokens to append. Args: @@ -197,7 +227,11 @@ def allocate_slots( if num_tokens == 0: raise ValueError("num_tokens must be greater than 0") - new_computed_blocks = new_computed_blocks or [] + if new_computed_blocks is not None: + assert isinstance(new_computed_blocks, UniformKVCacheBlocks) + new_computed_block_list = new_computed_blocks.blocks + else: + new_computed_block_list = [] req_blocks = self.req_to_blocks[request.request_id] @@ -218,12 +252,13 @@ def allocate_slots( num_required_blocks = cdiv(num_computed_tokens + num_tokens, self.block_size) num_new_blocks = (num_required_blocks - len(req_blocks) - - len(new_computed_blocks)) + len(new_computed_block_list)) # If a computed block of a request is an eviction candidate (in the # free queue and ref_cnt == 0), it cannot be counted as a free block # when allocating this request. - num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks + num_evictable_computed_blocks = sum(1 + for blk in new_computed_block_list if blk.ref_cnt == 0) if (num_new_blocks > self.block_pool.get_num_free_blocks() - num_evictable_computed_blocks): @@ -232,15 +267,15 @@ def allocate_slots( # Touch the computed blocks to make sure they won't be evicted. if self.enable_caching: - self.block_pool.touch(new_computed_blocks) + self.block_pool.touch(new_computed_block_list) else: - assert not new_computed_blocks, ( + assert not new_computed_block_list, ( "Computed blocks should be empty when " "prefix caching is disabled") # Append the new computed blocks to the request blocks until now to # avoid the case where the new blocks cannot be allocated. - req_blocks.extend(new_computed_blocks) + req_blocks.extend(new_computed_block_list) # Start to handle new blocks @@ -265,12 +300,12 @@ def allocate_slots( req_blocks.extend(new_blocks) if not self.enable_caching: - return new_blocks + return UniformKVCacheBlocks(new_blocks) - # Use `new_computed_blocks` for a new request, and `num_cached_block` - # for a running request. - num_cached_blocks = self.num_cached_block.get(request.request_id, - len(new_computed_blocks)) + # Use `new_computed_block_list` for a new request, and + # `num_cached_block` for a running request. + num_cached_blocks = (self.num_cached_block.get( + request.request_id, len(new_computed_block_list))) # Speculated tokens might be rejected in the future, so we does # not cache any speculated tokens. We only cache blocks with # generated (accepted) tokens. @@ -289,7 +324,7 @@ def allocate_slots( self.num_cached_block[ request.request_id] = num_full_blocks_after_append - return new_blocks + return UniformKVCacheBlocks(new_blocks) def free(self, request: Request) -> None: """Free the blocks allocated for the request. 
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index c4aa691a7138..fc42cd283769 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -3,9 +3,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional - -from vllm.v1.kv_cache_interface import MayMultiGroupBlockIDs +from typing import TYPE_CHECKING, Optional, Union if TYPE_CHECKING: import numpy as np @@ -16,6 +14,9 @@ from vllm.sampling_params import SamplingParams from vllm.v1.request import Request +MayMultiGroupBlockIDs = Union[list[int], list[list[int]]] +MayMultiGroupInt = Union[int, list[int]] + @dataclass class NewRequestData: @@ -108,7 +109,7 @@ class SchedulerOutput: scheduled_encoder_inputs: dict[str, list[int]] # Number of common prefix blocks for all requests. # This can be used for cascade attention. - num_common_prefix_blocks: int + num_common_prefix_blocks: MayMultiGroupInt # Request IDs that are finished in between the previous and the current # steps. This is used to notify the workers about the finished requests diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 794dfd01ba66..69247c99035b 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -14,13 +14,13 @@ compute_encoder_budget) from vllm.v1.core.kv_cache_manager import init_kv_cache_manager from vllm.v1.core.sched.interface import SchedulerInterface -from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, +from vllm.v1.core.sched.output import (CachedRequestData, + MayMultiGroupBlockIDs, NewRequestData, SchedulerOutput) from vllm.v1.core.sched.utils import check_stop from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs) -from vllm.v1.kv_cache_interface import (BlockIDGenerator, KVCacheConfig, - MayMultiGroupBlockIDs) +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -70,8 +70,6 @@ def __init__( caching_hash_algo=self.cache_config.prefix_caching_hash_algo, log_stats=self.log_stats) self.block_size = self.cache_config.block_size - BlockIDGenerator.num_kv_cache_groups = len( - self.kv_cache_config.kv_cache_groups) # req_id -> Request self.requests: dict[str, Request] = {} @@ -225,9 +223,8 @@ def schedule(self) -> SchedulerOutput: # Therefore, we might introduce some additional # cycle to fill in the bitmask, which could be a big no-op. structured_output_request_ids[request.request_id] = req_index - req_to_new_block_ids[ - request.request_id] = BlockIDGenerator.from_kv_cache_blocks( - new_blocks) + req_to_new_block_ids[request.request_id] = ( + new_blocks.to_block_ids()) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens req_index += 1 @@ -350,8 +347,7 @@ def schedule(self) -> SchedulerOutput: if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) req_to_new_block_ids[request.request_id] = ( - BlockIDGenerator.from_kv_cache_blocks(computed_blocks) + - BlockIDGenerator.from_kv_cache_blocks(new_blocks)) + computed_blocks + new_blocks).to_block_ids() num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING @@ -384,7 +380,7 @@ def schedule(self) -> SchedulerOutput: # Get the longest common prefix among all requests in the running queue. 
# This can be potentially used for cascade attention. if len(self.kv_cache_config.kv_cache_groups) > 1: - num_common_prefix_blocks = [0] * len( + num_common_prefix_blocks: Union[int, list[int]] = [0] * len( self.kv_cache_config.kv_cache_groups) else: num_common_prefix_blocks = 0 diff --git a/vllm/v1/core/specialized_manager.py b/vllm/v1/core/specialized_manager.py index b7e36f776cf8..97a5d2fa88f0 100644 --- a/vllm/v1/core/specialized_manager.py +++ b/vllm/v1/core/specialized_manager.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from collections import namedtuple -from typing import Optional, Union +from typing import Optional from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool @@ -34,14 +33,13 @@ def __init__( self.block_pool = block_pool # for caching the intermediate states between multiple calls of # find_longest_cache_hit - self.req_cached_blocks: dict[int, list[KVCacheBlock]] = {} + self.req_cached_blocks: dict[str, list[KVCacheBlock]] = {} def find_longest_cache_hit( self, - request_id: int, + request_id: str, block_hashes: list[BlockHashType], - return_const_list: bool = False, - ) -> Union[list[KVCacheBlock], ConstantList[KVCacheBlock]]: + ) -> list[KVCacheBlock]: """ Find the longest cache hit prefix of the blocks. If no cache hit is found, return an empty list. @@ -58,13 +56,17 @@ def find_longest_cache_hit( req_cached_blocks = self._find_longest_cache_hit( block_hashes, req_cached_blocks) - if return_const_list: - # TODO: add comment - self.req_cached_blocks[request_id] = req_cached_blocks - return ConstantList(req_cached_blocks) - else: - # TODO: add comment - return req_cached_blocks + return req_cached_blocks + + def find_longest_cache_hit_multiple_calls( + self, + request_id: str, + block_hashes: list[BlockHashType], + ) -> ConstantList[KVCacheBlock]: + req_cached_blocks = self.find_longest_cache_hit( + request_id, block_hashes) + self.req_cached_blocks[request_id] = req_cached_blocks + return ConstantList(req_cached_blocks) @abstractmethod def _find_longest_cache_hit( @@ -116,11 +118,11 @@ def _find_longest_cache_hit( computed_blocks: Optional[list[KVCacheBlock]] ) -> list[KVCacheBlock]: if computed_blocks is None: - computed_blocks: list[KVCacheBlock] = [] + computed_blocks = [] for block_hash in block_hashes: - # block_hashes is a chain of block hashes. If a block hash is not - # in the cached_block_hash_to_id, the following block hashes are - # not computed yet for sure. + # block_hashes is a chain of block hashes. If a block hash is + # not in the cached_block_hash_to_id, the following block hashes + # are not computed yet for sure. if cached_block := self.block_pool.get_cached_block( block_hash): computed_blocks.append(cached_block) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 8af1130aa804..610e1b85ec05 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -1,15 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import TYPE_CHECKING, Union, cast, overload, Type +from typing import TYPE_CHECKING import torch from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import cdiv, get_dtype_size + if TYPE_CHECKING: - from vllm.v1.core.kv_cache_utils import KVCacheBlock + pass logger = init_logger(__name__) @@ -185,63 +186,3 @@ class KVCacheConfig: there are 3 groups, each of which represents 10 layers in the model. 
""" kv_cache_groups: list[KVCacheGroupSpec] - - -@dataclass -class MultiGroupBlockIDs: - # A list of block IDs for each virtual layer - _block_ids: list[list[int]] - - def __init__(self, block_ids: list[list[int]]): - self._block_ids = block_ids - - @classmethod - def from_kv_cache_blocks(cls, kv_cache_blocks: list[list["KVCacheBlock"]]): - return cls( - block_ids=[[blk.block_id for blk in kv_cache_blocks_one_layer] - for kv_cache_blocks_one_layer in kv_cache_blocks]) - - def extend(self, new_block_ids: "MultiGroupBlockIDs") -> None: - for i, block_ids in enumerate(new_block_ids._block_ids): - self._block_ids[i].extend(block_ids) - - def __add__(self, other: "MultiGroupBlockIDs") -> "MultiGroupBlockIDs": - return MultiGroupBlockIDs(block_ids=[ - a + b for a, b in zip(self._block_ids, other._block_ids) - ]) - - def get_block_id_of_group(self, group_id: int) -> list[int]: - return self._block_ids[group_id] - - -MayMultiGroupBlockIDs = Union[MultiGroupBlockIDs, list[int]] -MayMultiGroupInt = Union[int, list[int]] - - -class BlockIDGenerator: - num_kv_cache_groups: int - - @overload - @classmethod - def from_kv_cache_blocks( - cls, kv_cache_blocks: list["KVCacheBlock"]) -> list[int]: - ... - - @overload - @classmethod - def from_kv_cache_blocks( - cls, - kv_cache_blocks: list[list["KVCacheBlock"]]) -> MultiGroupBlockIDs: - ... - - @classmethod - def from_kv_cache_blocks( - cls, kv_cache_blocks: Union[list["KVCacheBlock"], - list[list["KVCacheBlock"]]] - ) -> MayMultiGroupBlockIDs: - if cls.num_kv_cache_groups == 1: - kv_cache_blocks = cast(list["KVCacheBlock"], kv_cache_blocks) - return [blk.block_id for blk in kv_cache_blocks] - else: - kv_cache_blocks = cast(list[list["KVCacheBlock"]], kv_cache_blocks) - return MultiGroupBlockIDs.from_kv_cache_blocks(kv_cache_blocks) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index fd8583972b59..0921ea35d487 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -104,6 +104,7 @@ def get_numpy_array(self) -> np.ndarray: class MultiLayerBlockTable: move_row: Callable[P, None] + swap_row: Callable[P, None] commit: Callable[P, None] clear: Callable[P, None] diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 6b3e356fa591..4564cd917f8c 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -11,11 +11,12 @@ from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import swap_dict_values +from vllm.v1.core.sched.output import MayMultiGroupBlockIDs from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.outputs import LogprobsTensors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.utils import copy_slice -from vllm.v1.worker.block_table import BlockTable, initialize_block_table +from vllm.v1.worker.block_table import initialize_block_table _SAMPLING_EPS = 1e-5 @@ -31,7 +32,7 @@ class CachedRequestState: sampling_params: SamplingParams generator: Optional[torch.Generator] - block_ids: list[int] + block_ids: MayMultiGroupBlockIDs num_computed_tokens: int output_token_ids: list[int] @@ -258,7 +259,7 @@ def add_request( self.num_tokens_no_spec[req_index] = request.num_tokens self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens - self.block_table.add_row(request.block_ids, req_index) + self.block_table.add_row(request.block_ids, req_index) # type: ignore sampling_params = request.sampling_params if 
sampling_params.sampling_type == SamplingType.GREEDY: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d741eed1d0b7..1d8eba142eb2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3,7 +3,7 @@ import gc import time import weakref -from typing import TYPE_CHECKING, Optional, Type, Union, cast +from typing import TYPE_CHECKING, Optional, Union, cast import numpy as np import torch @@ -83,7 +83,7 @@ def __init__( model_config = self.model_config cache_config = self.cache_config scheduler_config = self.scheduler_config - parallel_config = self.parallel_config + self.device = device self.pin_memory = is_pin_memory_available() self.dtype = self.model_config.dtype @@ -120,8 +120,8 @@ def __init__( # init in initialize_kv_cache self.kv_caches: list[torch.Tensor] = [] self.kv_cache_config = cast(KVCacheConfig, None) - self.attn_backends: list[Type[AttentionBackend]] = [] - self.attn_metadata_builders: list[Type[AttentionMetadataBuilder]] = [] + self.attn_backends: list[type[AttentionBackend]] = [] + self.attn_metadata_builders: list[type[AttentionMetadataBuilder]] = [] # Persistent batch self.input_batch = cast(InputBatch, None) @@ -360,7 +360,17 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the block IDs. if not req_data.resumed_from_preemption: # Append the new blocks to the existing block IDs. - req_state.block_ids.extend(req_data.new_block_ids) + if len(self.kv_cache_config.kv_cache_groups) == 1: + block_ids = cast(list[int], req_state.block_ids) + new_block_ids = cast(list[int], req_data.new_block_ids) + block_ids.extend(new_block_ids) + else: + hybrid_block_ids = cast(list[list[int]], + req_state.block_ids) + new_hybrid_block_ids = cast(list[list[int]], + req_data.new_block_ids) + for i in range(len(self.kv_cache_config.kv_cache_groups)): + hybrid_block_ids[i].extend(new_hybrid_block_ids[i]) else: # The request is resumed from preemption. # Replace the existing block IDs with the new ones. @@ -377,8 +387,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) - self.input_batch.block_table.append_row(req_data.new_block_ids, - req_index) + self.input_batch.block_table.append_row( + req_data.new_block_ids, # type: ignore + req_index) # Add new_token_ids to token_ids_cpu. start_token_index = num_computed_tokens end_token_index = num_computed_tokens + len(req_data.new_token_ids) @@ -511,8 +522,9 @@ def _prepare_inputs( # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] # where K is the max_num_blocks_per_req and the block size is 2. - # NOTE(woosuk): We can't simply use `token_indices // block_size` here - # because M (max_model_len) is not necessarily divisible by block_size. + # NOTE(woosuk): We can't simply use `token_indices // block_size` + # here because M (max_model_len) is not necessarily divisible by + # block_size. 
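# A small illustrative example (not part of the patch itself) of the flattened
# block-table indexing described in the NOTE above, using toy values.
import numpy as np

block_size = 2
max_num_blocks_per_req = 4
# Two requests: request 0 contributes 2 tokens, request 1 contributes 3 tokens.
req_indices = np.array([0, 0, 1, 1, 1])
positions_np = np.array([0, 1, 0, 1, 2])
toy_block_table_indices = (req_indices * max_num_blocks_per_req
                           + positions_np // block_size)
print(toy_block_table_indices)  # [0 0 4 4 5]: rows of a [num_reqs, 4] table, flattened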
block_table_indices = ( req_indices * block_table.max_num_blocks_per_req + positions_np // block_size) @@ -688,6 +700,7 @@ def _compute_cascade_attn_prefix_len( use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or (isinstance(kv_cache_spec, FullAttentionSpec) and kv_cache_spec.compute_as_sliding_window)) + assert isinstance(kv_cache_spec, AttentionSpec) use_cascade = attn_backend.use_cascade_attention( common_prefix_len=common_prefix_len, query_lens=num_scheduled_tokens, @@ -1624,7 +1637,8 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: # TODO: docstring assert len(self.attn_backends) == 0 and len( self.attn_metadata_builders) == 0, "already initialized" - for kv_cache_group_spec in kv_cache_config.kv_cache_groups: + for i, kv_cache_group_spec in enumerate( + kv_cache_config.kv_cache_groups): kv_cache_spec = kv_cache_group_spec.kv_cache_spec if not isinstance(kv_cache_spec, AttentionSpec): raise NotImplementedError( @@ -1641,15 +1655,20 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: if attn_backend_i is None: error_msg = ( f"Error with get_attn_backend: {kv_cache_spec.head_size=}, " - f"{self.dtype=}, {kv_cache_spec.kv_cache_dtype=}, {kv_cache_spec.block_size=}, " + f"{self.dtype=}, {kv_cache_spec.dtype=}, " + f"{kv_cache_spec.block_size=}, " f"{self.model_config.is_attention_free=}, " f"{kv_cache_spec.use_mla=}") logger.error(error_msg) raise NotImplementedError( - "Non-Attention backend is not supported by V1 GPUModelRunner." - ) + "Non-Attention backend is not supported by V1 " + "GPUModelRunner.") + if isinstance(self.input_batch.block_table, BlockTable): + block_table = self.input_batch.block_table + else: + block_table = self.input_batch.block_table[i] attn_metadata_builder_i = attn_backend_i.get_builder_cls()( - weakref.proxy(self)) + weakref.proxy(self), kv_cache_spec, block_table) self.attn_backends.append(attn_backend_i) self.attn_metadata_builders.append(attn_metadata_builder_i) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index c2edbaf351d0..38ffd3b1dc7e 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# type: ignore import bisect import time from typing import TYPE_CHECKING, Optional, cast diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 9add8cee02e5..3e080592b20c 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -22,7 +22,7 @@ KVCacheSpec) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.utils import bind_kv_cache -from vllm.v1.worker.tpu_model_runner import TPUModelRunner +from vllm.v1.worker.tpu_model_runner import TPUModelRunner # type: ignore logger = init_logger(__name__) From 0bfec8d2bae5a7f03bf8d35c28be0fd77e3a4b0c Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 5 Apr 2025 09:29:19 -0700 Subject: [PATCH 06/34] can run again Signed-off-by: Chen Zhang --- vllm/attention/layer.py | 4 ++-- vllm/v1/attention/backends/flash_attn.py | 3 ++- vllm/v1/core/hybrid_kv_cache_manager.py | 5 ++++- vllm/v1/core/kv_cache_manager.py | 1 - vllm/v1/worker/gpu_model_runner.py | 4 ++-- 5 files changed, 10 insertions(+), 7 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 1043654ef978..9026e3345d6b 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -206,8 +206,8 @@ def forward( if self.use_direct_call: forward_context: ForwardContext = get_forward_context() 
attn_metadata = forward_context.attn_metadata - # if isinstance(attn_metadata, dict): - # attn_metadata = attn_metadata[self.layer_name] + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] self.impl.forward(self, query, diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 028f71db722e..e8c6426ea64d 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -101,7 +101,8 @@ class FlashAttentionMetadata: class FlashAttentionMetadataBuilder: - def __init__(self, runner: "GPUModelRunner", kv_cache_spec: KVCacheSpec): + def __init__(self, runner: "GPUModelRunner", kv_cache_spec: KVCacheSpec, + persistent_block_table: BlockTable): self.runner = runner def reorder_batch(self, input_batch: "InputBatch", diff --git a/vllm/v1/core/hybrid_kv_cache_manager.py b/vllm/v1/core/hybrid_kv_cache_manager.py index ebfffdddc796..6e894539b558 100644 --- a/vllm/v1/core/hybrid_kv_cache_manager.py +++ b/vllm/v1/core/hybrid_kv_cache_manager.py @@ -29,7 +29,10 @@ def to_block_ids(self) -> list[list[int]]: def __add__(self, other: "KVCacheBlocksInterface") -> "KVCacheBlocksInterface": assert isinstance(other, HybridKVCacheBlocks) - return HybridKVCacheBlocks(self.blocks + other.blocks) + return HybridKVCacheBlocks([ + self_blocks_i + other_blocks_i + for self_blocks_i, other_blocks_i in zip(self.blocks, other.blocks) + ]) class HybridKVCacheManager: diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index e1ba97ff8c95..635ecb5e2fd4 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -44,7 +44,6 @@ def to_block_ids(self) -> list[int]: def __add__(self, other: "KVCacheBlocksInterface") -> "KVCacheBlocksInterface": - print(f"other: {other} type: {type(other)}") assert isinstance(other, UniformKVCacheBlocks) return UniformKVCacheBlocks(self.blocks + other.blocks) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1d8eba142eb2..d930ad5ea628 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1684,8 +1684,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: cache size of each layer """ self.kv_cache_config = kv_cache_config - self.initialize_attn_backend(kv_cache_config) - self.initialize_kv_cache_tensors(kv_cache_config) self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, max_model_len=self.max_model_len, @@ -1695,6 +1693,8 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: vocab_size=self.model_config.get_vocab_size(), kv_cache_config=kv_cache_config, ) + self.initialize_attn_backend(kv_cache_config) + self.initialize_kv_cache_tensors(kv_cache_config) def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ From df31d7a2c8add7315c5a9274611e423d1c8dc5a5 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 5 Apr 2025 23:37:47 -0700 Subject: [PATCH 07/34] quick copy Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_manager.py | 435 ++++++++++++++++++------------- 1 file changed, 261 insertions(+), 174 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 635ecb5e2fd4..6e894539b558 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,54 +1,46 @@ # SPDX-License-Identifier: Apache-2.0 -from abc import ABC, abstractmethod from collections import 
defaultdict -from collections.abc import Iterable from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Union +from typing import Optional from vllm.logger import init_logger from vllm.utils import cdiv, sha256 from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.kv_cache_manager import KVCacheBlocksInterface from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, hash_request_tokens) -from vllm.v1.core.sched.output import MayMultiGroupBlockIDs from vllm.v1.core.specialized_manager import get_specialized_manager from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request, RequestStatus -if TYPE_CHECKING: - from vllm.v1.core.hybrid_kv_cache_manager import HybridKVCacheManager - logger = init_logger(__name__) -class KVCacheBlocksInterface(ABC): - - @abstractmethod - def to_block_ids(self) -> MayMultiGroupBlockIDs: - raise NotImplementedError - - @abstractmethod - def __add__(self, - other: "KVCacheBlocksInterface") -> "KVCacheBlocksInterface": - raise NotImplementedError - - @dataclass -class UniformKVCacheBlocks(KVCacheBlocksInterface): - blocks: list[KVCacheBlock] +class HybridKVCacheBlocks(KVCacheBlocksInterface): + blocks: list[list[KVCacheBlock]] - def to_block_ids(self) -> list[int]: - return [blk.block_id for blk in self.blocks] + def to_block_ids(self) -> list[list[int]]: + return [[blk.block_id for blk in blk_one_layer] + for blk_one_layer in self.blocks] def __add__(self, other: "KVCacheBlocksInterface") -> "KVCacheBlocksInterface": - assert isinstance(other, UniformKVCacheBlocks) - return UniformKVCacheBlocks(self.blocks + other.blocks) + assert isinstance(other, HybridKVCacheBlocks) + return HybridKVCacheBlocks([ + self_blocks_i + other_blocks_i + for self_blocks_i, other_blocks_i in zip(self.blocks, other.blocks) + ]) -class KVCacheManager: +class HybridKVCacheManager: + """ + The HybridKVCacheManager for models with multiple KV cache types + (e.g., Gemma-2) and thus multiple kv cache groups (Refer to class + `KVCacheConfig` for the meaning of kv cache groups). + """ def __init__( self, @@ -59,15 +51,15 @@ def __init__( num_preallocate_tokens: int = 64, log_stats: bool = False, ) -> None: - assert len(kv_cache_config.kv_cache_groups) == 1, ( - "KVCacheManager does not support hybrid models with more than 1 " - "kv cache group") - kv_cache_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec - self.block_size = kv_cache_spec.block_size + # TODO: adjust the name for item in one group, list of items in all + # groups, and reduced item for all groups. + self.kv_cache_config = kv_cache_config self.num_gpu_blocks = kv_cache_config.num_blocks self.max_model_len = max_model_len - self.max_num_blocks_per_req = cdiv(max_model_len, self.block_size) - + self.max_num_blocks_per_req = [ + cdiv(max_model_len, g.kv_cache_spec.block_size) + for g in kv_cache_config.kv_cache_groups + ] self.enable_caching = enable_caching self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash # FIXME: make prefix cache stats conditional on log_stats @@ -82,34 +74,45 @@ def __init__( # the request gets N empty blocks, it starts to use the blocks without # further allocation. When it uses up all the N empty blocks, it gets # N new empty blocks. 
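# A tiny sketch (not part of the patch itself) of the preallocation sizing used
# below: the hybrid manager derives one preallocated block count from
# num_preallocate_tokens and the largest block size among the kv cache groups,
# so every group preallocates the same number of blocks. Values are illustrative.
def cdiv(a: int, b: int) -> int:
    return -(-a // b)


num_preallocate_tokens = 64
group_block_sizes = [16, 4]  # e.g. a full-attention group and a sliding-window group
num_preallocate_blocks = cdiv(num_preallocate_tokens, max(group_block_sizes))
print(num_preallocate_blocks)  # 4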
- self.num_preallocate_tokens = num_preallocate_tokens - self.num_preallocate_blocks = cdiv(num_preallocate_tokens, - self.block_size) + # NOTE(Chen): For simplicity, we keep the number of preallocated blocks + # the same for all layers, which will result in different + # preallocated tokens for different layers if their block sizes are + # different. + self.num_preallocate_blocks = cdiv( + num_preallocate_tokens, + max(g.kv_cache_spec.block_size + for g in kv_cache_config.kv_cache_groups)) self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching) - self.specialized_manager = get_specialized_manager( - kv_cache_spec=kv_cache_spec, - block_pool=self.block_pool, - ) + self.specialized_managers = [ + get_specialized_manager( + kv_cache_spec=g.kv_cache_spec, + block_pool=self.block_pool, + ) for g in kv_cache_config.kv_cache_groups + ] + + self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request # is finished. - self.req_to_blocks: defaultdict[str, - list[KVCacheBlock]] = defaultdict(list) + self.req_to_blocks: defaultdict[ + str, list[list[KVCacheBlock]]] = defaultdict( + lambda: [[] for _ in range(self.num_kv_cache_groups)]) # Mapping from request ID to kv block hashes. # This is to avoid recomputing the block hashes for each call of # `get_computed_blocks` or `allocate_slots`. self.req_to_block_hashes: defaultdict[ - str, list[BlockHashType]] = defaultdict(list) + str, list[list[BlockHashType]]] = defaultdict( + lambda: [[] for _ in range(self.num_kv_cache_groups)]) - # {req_id: The number of cached blocks for this given request} + # {req_id: The number of cached blocks for each kv cache group} # This is used to track the number of cached blocks for each request. # This is only used to track the RUNNING requests, we do not track the # data for reempted ones. - self.num_cached_block: dict[str, int] = {} + self.num_cached_block: dict[str, list[int]] = {} self.prefix_cache_stats = PrefixCacheStats() @property @@ -132,7 +135,7 @@ def make_prefix_cache_stats(self) -> PrefixCacheStats: return stats def get_computed_blocks( - self, request: Request) -> tuple[UniformKVCacheBlocks, int]: + self, request: Request) -> tuple[HybridKVCacheBlocks, int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. @@ -141,54 +144,56 @@ def get_computed_blocks( Returns: A tuple containing: - - A list of blocks that are computed for the request. + - A list of blocks that are computed for each kv cache group. - The number of computed tokens. """ if not self.enable_caching: # Prefix caching is disabled. - return UniformKVCacheBlocks([]), 0 + computed_blocks: list[list[KVCacheBlock]] = [ + [] for _ in range(self.num_kv_cache_groups) + ] + return HybridKVCacheBlocks(computed_blocks), 0 # The block hashes for the request may already be computed # if the scheduler has tried to schedule the request before. 
block_hashes = self.req_to_block_hashes[request.request_id] if not block_hashes: - block_hashes = hash_request_tokens(self.caching_hash_fn, - self.block_size, request) + block_hashes = [ + hash_request_tokens(self.caching_hash_fn, + g.kv_cache_spec.block_size, request, i) + for i, g in enumerate(self.kv_cache_config.kv_cache_groups) + ] self.req_to_block_hashes[request.request_id] = block_hashes self.prefix_cache_stats.requests += 1 if request.sampling_params.prompt_logprobs is None: - if len(block_hashes) * self.block_size == request.num_tokens: - # When prompt length is divisible by the block size and all - # blocks are cached, we need to recompute the last token. This - # have to be achieved by re-computing an entire block because - # allocate_slots() assumes num_computed_tokens is always a - # multiple of the block size. To achieve this, remove the last - # block hash from the block_hashes for find_longest_cache_hit - # This limitation can potentially be removed in the future to - # slightly improve the performance. - last_block_hash = block_hashes.pop() - else: - last_block_hash = None - - computed_blocks = (self.specialized_manager.find_longest_cache_hit( - request.request_id, block_hashes)) - - if last_block_hash is not None: - # Add back the last block hash if it was removed. - block_hashes.append(last_block_hash) + # TODO: Fix last block problem + # if len(block_hashes) * self.block_size == request.num_tokens: + # # When prompt length is divisible by the block size and all + # # blocks are cached, we need to recompute the last token. This + # # have to be achieved by re-computing an entire block because + # # allocate_slots() assumes num_computed_tokens is always a + # # multiple of the block size. To achieve this, remove the last + # # block hash from the block_hashes for find_longest_cache_hit + # # This limitation can potentially be removed in the future to + # # slightly improve the performance. + # last_block_hash = block_hashes.pop() + # else: + # last_block_hash = None + + computed_blocks, num_computed_tokens = self.find_longest_cache_hit( + request.request_id, block_hashes) + + # if last_block_hash is not None: + # # Add back the last block hash if it was removed. + # block_hashes.append(last_block_hash) self.prefix_cache_stats.queries += len(block_hashes) self.prefix_cache_stats.hits += len(computed_blocks) - - # NOTE(woosuk): Since incomplete blocks are not eligible for - # sharing, `num_computed_tokens` is always a multiple of - # `block_size`. - num_computed_tokens = len(computed_blocks) * self.block_size - return UniformKVCacheBlocks(computed_blocks), num_computed_tokens + return HybridKVCacheBlocks(computed_blocks), num_computed_tokens else: # Skip cache hits for prompt logprobs - return UniformKVCacheBlocks([]), 0 + return HybridKVCacheBlocks([]), 0 def allocate_slots( self, @@ -196,7 +201,7 @@ def allocate_slots( num_tokens: int, new_computed_blocks: Optional[KVCacheBlocksInterface] = None, num_new_computed_tokens: int = 0, - ) -> Optional[UniformKVCacheBlocks]: + ) -> Optional[HybridKVCacheBlocks]: """Add slots for a request with new tokens to append. 
Args: @@ -227,10 +232,12 @@ def allocate_slots( raise ValueError("num_tokens must be greater than 0") if new_computed_blocks is not None: - assert isinstance(new_computed_blocks, UniformKVCacheBlocks) + assert isinstance(new_computed_blocks, HybridKVCacheBlocks) new_computed_block_list = new_computed_blocks.blocks else: - new_computed_block_list = [] + new_computed_block_list = ([ + [] for _ in range(self.num_kv_cache_groups) + ]) req_blocks = self.req_to_blocks[request.request_id] @@ -240,90 +247,122 @@ def allocate_slots( # insufficient free blocks. # Should call this function before allocating new blocks to reduce # the number of evicted blocks. - removed_blocks = self.specialized_manager.remove_skipped_blocks( - req_blocks, request.num_computed_tokens) - self.block_pool.free_blocks(removed_blocks) + removed_blocks = [ + manager.remove_skipped_blocks(req_blocks[i], + request.num_computed_tokens) + for i, manager in enumerate(self.specialized_managers) + ] + self._free_blocks(removed_blocks) # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits + num_computed_tokens = (request.num_computed_tokens + num_new_computed_tokens) - num_required_blocks = cdiv(num_computed_tokens + num_tokens, - self.block_size) - num_new_blocks = (num_required_blocks - len(req_blocks) - - len(new_computed_block_list)) + + num_new_blocks: list[int] = [] + for i in range(self.num_kv_cache_groups): + num_required_blocks_i = cdiv( + num_computed_tokens + num_tokens, + self.specialized_managers[i].block_size) + num_new_blocks.append(num_required_blocks_i - len(req_blocks[i]) - + len(new_computed_block_list[i])) + total_num_new_blocks = sum(max(x, 0) for x in num_new_blocks) # If a computed block of a request is an eviction candidate (in the # free queue and ref_cnt == 0), it cannot be counted as a free block # when allocating this request. - num_evictable_computed_blocks = sum(1 - for blk in new_computed_block_list - if blk.ref_cnt == 0) - if (num_new_blocks > self.block_pool.get_num_free_blocks() - + num_evictable_computed_blocks = sum( + 1 for blk_one_layer in new_computed_block_list + for blk in blk_one_layer if blk.ref_cnt == 0) + if (total_num_new_blocks > self.block_pool.get_num_free_blocks() - num_evictable_computed_blocks): # Cannot allocate new blocks return None # Touch the computed blocks to make sure they won't be evicted. if self.enable_caching: - self.block_pool.touch(new_computed_block_list) + for blocks in new_computed_block_list: + self.block_pool.touch(blocks) else: - assert not new_computed_block_list, ( + assert all(len(blks) == 0 for blks in new_computed_block_list), ( "Computed blocks should be empty when " "prefix caching is disabled") # Append the new computed blocks to the request blocks until now to # avoid the case where the new blocks cannot be allocated. - req_blocks.extend(new_computed_block_list) + for i in range(self.num_kv_cache_groups): + req_blocks[i].extend(new_computed_block_list[i]) # Start to handle new blocks - - if num_new_blocks <= 0: - # No new block is needed. - new_blocks = [] - else: - # Get new blocks from the free block pool considering - # preallocated blocks. - num_new_blocks = min( - num_new_blocks + self.num_preallocate_blocks, - self.block_pool.get_num_free_blocks(), - # Should not exceed the maximum number of blocks per request. - # This is especially because the block table has the shape - # [..., max_num_blocks_per_req]. 
- self.max_num_blocks_per_req - len(req_blocks), - ) - assert num_new_blocks > 0 - - # Concatenate the computed block IDs and the new block IDs. - new_blocks = self.block_pool.get_new_blocks(num_new_blocks) - req_blocks.extend(new_blocks) + new_blocks: list[list[KVCacheBlock]] = [] + # Truncate the number of pre-allocated blocks to ensure that we can + # have at least `num_new_blocks` free blocks for each layer. + num_preallocate_blocks = min( + self.num_preallocate_blocks, + (self.block_pool.get_num_free_blocks() - total_num_new_blocks) // + len(self.specialized_managers)) + + for i in range(self.num_kv_cache_groups): + if num_new_blocks[i] <= 0: + # No new block is needed. + new_blocks.append([]) + else: + # Get new blocks from the free block pool considering + # preallocated blocks. + num_block_to_allocate = min( + num_new_blocks[i] + num_preallocate_blocks, + # Should not exceed the maximum number of blocks per request + # This is especially because the block table has the shape + # [..., max_num_blocks_per_req]. + # TODO(woosuk): Check and reject requests if + # num_prompt_tokens + max_tokens > max_model_len. + # Don't need self.block_pool.get_num_free_blocks() as in + # KVCacheManager because we already considered it when + # calculating num_preallocate_blocks + self.max_num_blocks_per_req[i] - len(req_blocks[i]), + ) + + assert num_block_to_allocate > 0 + assert num_block_to_allocate <= \ + self.block_pool.get_num_free_blocks() + + # Concatenate the computed block IDs and the new block IDs. + new_blocks_this_layer = self.block_pool.get_new_blocks( + num_block_to_allocate) + new_blocks.append(new_blocks_this_layer) + req_blocks[i].extend(new_blocks_this_layer) if not self.enable_caching: - return UniformKVCacheBlocks(new_blocks) + return HybridKVCacheBlocks(new_blocks) # Use `new_computed_block_list` for a new request, and # `num_cached_block` for a running request. - num_cached_blocks = (self.num_cached_block.get( - request.request_id, len(new_computed_block_list))) + num_cached_blocks = self.num_cached_block.get( + request.request_id, + [len(blocks) for blocks in new_computed_block_list]) # Speculated tokens might be rejected in the future, so we does # not cache any speculated tokens. We only cache blocks with # generated (accepted) tokens. 
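# An illustrative calculation (not part of the patch itself) of how many full
# blocks become cacheable per group after the append below; speculative (draft)
# tokens are excluded because they may later be rejected. Values are toy numbers.
num_computed_tokens = 30
num_tokens = 12
num_spec_tokens = 3          # speculated tokens are never cached
group_block_sizes = [16, 4]  # per-group block sizes (illustrative)

for block_size in group_block_sizes:
    num_full_blocks_after_append = (
        num_computed_tokens + num_tokens - num_spec_tokens) // block_size
    print(block_size, num_full_blocks_after_append)  # 16 -> 2, then 4 -> 9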
- num_full_blocks_after_append = (num_computed_tokens + num_tokens - len( - request.spec_token_ids)) // self.block_size - - self.block_pool.cache_full_blocks( - request=request, - blocks=req_blocks, - block_hashes=self.req_to_block_hashes[request.request_id], - num_cached_blocks=num_cached_blocks, - num_full_blocks=num_full_blocks_after_append, - block_size=self.block_size, - hash_fn=self.caching_hash_fn, - ) - - self.num_cached_block[ - request.request_id] = num_full_blocks_after_append - return UniformKVCacheBlocks(new_blocks) + for i in range(self.num_kv_cache_groups): + num_full_blocks_after_append = ( + num_computed_tokens + num_tokens - len(request.spec_token_ids) + ) // self.specialized_managers[i].block_size + + self.block_pool.cache_full_blocks( + request=request, + blocks=req_blocks[i], + block_hashes=self.req_to_block_hashes[request.request_id][i], + num_cached_blocks=num_cached_blocks[i], + num_full_blocks=num_full_blocks_after_append, + block_size=self.specialized_managers[i].block_size, + hash_fn=self.caching_hash_fn, + kv_cache_group_id=i, + ) + num_cached_blocks[i] = num_full_blocks_after_append + + self.num_cached_block[request.request_id] = num_cached_blocks + return HybridKVCacheBlocks(new_blocks) def free(self, request: Request) -> None: """Free the blocks allocated for the request. @@ -334,14 +373,12 @@ def free(self, request: Request) -> None: request: The request to free the blocks. """ # Default to [] in case a request is freed (aborted) before alloc. - blocks = self.req_to_blocks.pop(request.request_id, []) - ordered_blocks: Iterable[KVCacheBlock] = blocks - if self.enable_caching: - # Free blocks in reverse order so that the tail blocks are - # freed first. - ordered_blocks = reversed(blocks) + blocks = self.req_to_blocks.pop(request.request_id, None) + if blocks is not None: + # Reverse the blocks so that the tail blocks can have higher + # eviction priority. + self._free_blocks([list(reversed(blks)) for blks in blocks]) - self.block_pool.free_blocks(ordered_blocks) self.num_cached_block.pop(request.request_id, None) def reset_prefix_cache(self) -> bool: @@ -362,7 +399,7 @@ def get_num_common_prefix_blocks( self, request: Request, num_running_requests: int, - ) -> int: + ) -> list[int]: """Calculate the number of common prefix blocks shared by all requests in the RUNNING state. @@ -394,16 +431,20 @@ def get_num_common_prefix_blocks( requests in the current step. Returns: - int: The number of common prefix blocks. + list[int]: The number of common prefix blocks for each kv cache + group. """ assert request.status == RequestStatus.RUNNING blocks = self.req_to_blocks[request.request_id] - num_common_blocks = 0 - for block in blocks: - if block.ref_cnt == num_running_requests: - num_common_blocks += 1 - else: - break + num_common_blocks = [] + for i in range(self.num_kv_cache_groups): + num_common_blocks_i = 0 + for block in blocks[i]: + if block.ref_cnt == num_running_requests: + num_common_blocks_i += 1 + else: + break + num_common_blocks.append(num_common_blocks_i) return num_common_blocks def free_block_hashes(self, request: Request) -> None: @@ -414,31 +455,77 @@ def free_block_hashes(self, request: Request) -> None: """ self.req_to_block_hashes.pop(request.request_id, None) + def find_longest_cache_hit( + self, request_id: str, block_hashes: list[list[BlockHashType]] + ) -> tuple[list[list[KVCacheBlock]], int]: + """Find the longest cache hit for each kv cache group. 
+ TODO: add more notes + """ + # TODO: accelerate by make full attention the first layer + # TODO: add note for the two magic number + num_computed_tokens = [self.max_model_len + 100] * len( + self.specialized_managers) + min_computed_tokens = self.max_model_len + + # Use copy to avoid modifying the original block_hashes + block_hashes = [block_hash.copy() for block_hash in block_hashes] + + while max(num_computed_tokens) != min_computed_tokens: + for i, manager in enumerate(self.specialized_managers): + if num_computed_tokens[i] > min_computed_tokens: + del block_hashes[i][:min_computed_tokens // + manager.block_size] + computed_blocks_group_i = ( + manager.find_longest_cache_hit_multiple_calls( + request_id, block_hashes[i])) + + num_computed_tokens[i] = len(computed_blocks_group_i) * \ + manager.block_size + min_computed_tokens = min(min_computed_tokens, + num_computed_tokens[i]) + + # Get the non-constlist computed blocks + computed_blocks = [ + manager.find_longest_cache_hit(request_id, block_hashes[i]) + for i, manager in enumerate(self.specialized_managers) + ] + + assert all( + len(block) * manager.block_size == min_computed_tokens for block, + manager in zip(computed_blocks, self.specialized_managers)) + + return computed_blocks, min_computed_tokens + + def _merge_blocks_by_eviction_order( + self, blocks: list[list[KVCacheBlock]]) -> list[KVCacheBlock]: + """ + Merge the blocks of different layers to one list. The returned blocks + are sorted by eviction order, with the first block having the highest + eviction priority. + Args: + blocks: the blocks of each virtual layer, ordered by eviction + priority. + Returns: + A list of KVCacheBlocks sorted by eviction order. + """ + + if self.enable_caching: + # NOTE (Chen): A simple strategy that interleaves the blocks of + # each layer. We can investigate more advanced strategies + # in the future. 
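# A standalone sketch (not part of the patch itself) of the interleaving
# strategy described in the NOTE above: blocks from the different groups are
# merged round-robin so that the tail blocks of every group keep comparable
# eviction priority.
def interleave(blocks_per_group: list[list[str]]) -> list[str]:
    ordered: list[str] = []
    max_len = max(len(group) for group in blocks_per_group)
    for i in range(max_len):
        for group in blocks_per_group:
            if i < len(group):
                ordered.append(group[i])
    return ordered


print(interleave([["a0", "a1", "a2"], ["b0", "b1"]]))
# ['a0', 'b0', 'a1', 'b1', 'a2']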
+ ordered_blocks = [] + max_len = max(len(blocks_one_layer) for blocks_one_layer in blocks) + for i in range(max_len): + for blocks_one_layer in blocks: + if i < len(blocks_one_layer): + ordered_blocks.append(blocks_one_layer[i]) + else: + ordered_blocks = [] + for blocks_one_layer in blocks: + ordered_blocks.extend(blocks_one_layer) -def init_kv_cache_manager( - kv_cache_config: KVCacheConfig, - max_model_len: int, - enable_caching: bool = True, - caching_hash_algo: str = "builtin", - num_preallocate_tokens: int = 64, - log_stats: bool = False, -) -> Union[KVCacheManager, "HybridKVCacheManager"]: - from vllm.v1.core.hybrid_kv_cache_manager import HybridKVCacheManager - if len(kv_cache_config.kv_cache_groups) > 1: - return HybridKVCacheManager( - kv_cache_config=kv_cache_config, - max_model_len=max_model_len, - enable_caching=enable_caching, - caching_hash_algo=caching_hash_algo, - num_preallocate_tokens=num_preallocate_tokens, - log_stats=log_stats, - ) - else: - return KVCacheManager( - kv_cache_config=kv_cache_config, - max_model_len=max_model_len, - enable_caching=enable_caching, - caching_hash_algo=caching_hash_algo, - num_preallocate_tokens=num_preallocate_tokens, - log_stats=log_stats, - ) + return ordered_blocks + + def _free_blocks(self, blocks: list[list[KVCacheBlock]]) -> None: + ordered_blocks = self._merge_blocks_by_eviction_order(blocks) + self.block_pool.free_blocks(ordered_blocks) From 6aee98d59725e4215211a01319279904791e6983 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 23 Apr 2025 06:50:57 -0700 Subject: [PATCH 08/34] a runable version Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/flash_attn.py | 9 +- vllm/v1/attention/backends/flashinfer.py | 1 + vllm/v1/attention/backends/mla/common.py | 1 + vllm/v1/attention/backends/mla/flashmla.py | 2 +- vllm/v1/attention/backends/mla/triton_mla.py | 2 +- vllm/v1/core/hybrid_kv_cache_manager.py | 531 ------------------- vllm/v1/core/kv_cache_manager.py | 107 ++-- vllm/v1/core/sched/output.py | 15 +- vllm/v1/core/sched/scheduler.py | 19 +- vllm/v1/worker/block_table.py | 14 +- vllm/v1/worker/gpu_input_batch.py | 3 +- vllm/v1/worker/gpu_model_runner.py | 61 +-- 12 files changed, 87 insertions(+), 678 deletions(-) delete mode 100644 vllm/v1/core/hybrid_kv_cache_manager.py diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 6857b7ce22e1..5a266a133f26 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -291,20 +291,23 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: KVCacheSpec, self.num_heads_kv = model_config.get_num_kv_heads( runner.parallel_config) self.headdim = model_config.get_head_size() - self.page_size = self.runner.block_size + self.page_size = kv_cache_spec.block_size + self.kv_cache_spec = kv_cache_spec + self.persistent_block_table = persistent_block_table def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: return False def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int, block_table: BlockTable): + common_prefix_len: int): max_seq_len = self.runner.seq_lens_np[:num_reqs].max() query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1] query_start_loc = query_start_loc_cpu.to(self.runner.device, non_blocking=True) seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] seq_lens = seq_lens_cpu.to(self.runner.device, non_blocking=True) + block_table = self.persistent_block_table block_table_tensor 
= block_table.get_device_tensor()[:num_reqs] slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].to( self.runner.device, non_blocking=True).long() @@ -335,7 +338,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, self.runner.query_start_loc_np[:num_reqs + 1], self.runner.seq_lens_np[:num_reqs], block_table, - self.runner.block_size, + self.kv_cache_spec.block_size, ) local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to( self.runner.device, non_blocking=True) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 17341ecfa4fe..9e2440c5d477 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# type: ignore """Attention layer with FlashInfer.""" from __future__ import annotations diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index bc47384831b4..71b2558faf0f 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# type: ignore """ This file implements common components for MLA implementations. diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 8ee588323296..2260761abed8 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 - +# type: ignore from dataclasses import dataclass from typing import Any, Optional diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 8e7e4f10b81b..83d2116aa81d 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 - +# type: ignore from typing import Any, Optional import torch diff --git a/vllm/v1/core/hybrid_kv_cache_manager.py b/vllm/v1/core/hybrid_kv_cache_manager.py deleted file mode 100644 index 6e894539b558..000000000000 --- a/vllm/v1/core/hybrid_kv_cache_manager.py +++ /dev/null @@ -1,531 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from collections import defaultdict -from dataclasses import dataclass -from typing import Optional - -from vllm.logger import init_logger -from vllm.utils import cdiv, sha256 -from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_manager import KVCacheBlocksInterface -from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, - hash_request_tokens) -from vllm.v1.core.specialized_manager import get_specialized_manager -from vllm.v1.kv_cache_interface import KVCacheConfig -from vllm.v1.metrics.stats import PrefixCacheStats -from vllm.v1.request import Request, RequestStatus - -logger = init_logger(__name__) - - -@dataclass -class HybridKVCacheBlocks(KVCacheBlocksInterface): - blocks: list[list[KVCacheBlock]] - - def to_block_ids(self) -> list[list[int]]: - return [[blk.block_id for blk in blk_one_layer] - for blk_one_layer in self.blocks] - - def __add__(self, - other: "KVCacheBlocksInterface") -> "KVCacheBlocksInterface": - assert isinstance(other, HybridKVCacheBlocks) - return HybridKVCacheBlocks([ - self_blocks_i + other_blocks_i - for self_blocks_i, other_blocks_i in zip(self.blocks, other.blocks) - ]) - - -class HybridKVCacheManager: - """ - The HybridKVCacheManager for models with multiple KV cache types - (e.g., Gemma-2) and thus 
multiple kv cache groups (Refer to class - `KVCacheConfig` for the meaning of kv cache groups). - """ - - def __init__( - self, - kv_cache_config: KVCacheConfig, - max_model_len: int, - enable_caching: bool = True, - caching_hash_algo: str = "builtin", - num_preallocate_tokens: int = 64, - log_stats: bool = False, - ) -> None: - # TODO: adjust the name for item in one group, list of items in all - # groups, and reduced item for all groups. - self.kv_cache_config = kv_cache_config - self.num_gpu_blocks = kv_cache_config.num_blocks - self.max_model_len = max_model_len - self.max_num_blocks_per_req = [ - cdiv(max_model_len, g.kv_cache_spec.block_size) - for g in kv_cache_config.kv_cache_groups - ] - self.enable_caching = enable_caching - self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash - # FIXME: make prefix cache stats conditional on log_stats - self.log_stats = log_stats - # NOTE(woosuk): To avoid frequent block allocation, we preallocate some - # blocks for each request. For example, when a request reaches the end - # of its block table, we preallocate N blocks in advance. This way, we - # reduce the overhead of updating free_block_ids and ref_cnts for each - # request every step (at the cost of some memory waste). - # NOTE(woosuk): This is different from the "lookahead" slots since this - # does not guarantee that the request always has N empty blocks. After - # the request gets N empty blocks, it starts to use the blocks without - # further allocation. When it uses up all the N empty blocks, it gets - # N new empty blocks. - # NOTE(Chen): For simplicity, we keep the number of preallocated blocks - # the same for all layers, which will result in different - # preallocated tokens for different layers if their block sizes are - # different. - self.num_preallocate_blocks = cdiv( - num_preallocate_tokens, - max(g.kv_cache_spec.block_size - for g in kv_cache_config.kv_cache_groups)) - - self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching) - - self.specialized_managers = [ - get_specialized_manager( - kv_cache_spec=g.kv_cache_spec, - block_pool=self.block_pool, - ) for g in kv_cache_config.kv_cache_groups - ] - - self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) - - # Mapping from request ID to blocks to track the blocks allocated - # for each request, so that we can free the blocks when the request - # is finished. - self.req_to_blocks: defaultdict[ - str, list[list[KVCacheBlock]]] = defaultdict( - lambda: [[] for _ in range(self.num_kv_cache_groups)]) - - # Mapping from request ID to kv block hashes. - # This is to avoid recomputing the block hashes for each call of - # `get_computed_blocks` or `allocate_slots`. - self.req_to_block_hashes: defaultdict[ - str, list[list[BlockHashType]]] = defaultdict( - lambda: [[] for _ in range(self.num_kv_cache_groups)]) - - # {req_id: The number of cached blocks for each kv cache group} - # This is used to track the number of cached blocks for each request. - # This is only used to track the RUNNING requests, we do not track the - # data for reempted ones. - self.num_cached_block: dict[str, list[int]] = {} - self.prefix_cache_stats = PrefixCacheStats() - - @property - def usage(self) -> float: - """Get the KV cache usage. - - Returns: - The KV cache usage (between 0.0 and 1.0). - """ - return self.block_pool.get_usage() - - def make_prefix_cache_stats(self) -> PrefixCacheStats: - """Get (and reset) the prefix cache stats. - - Returns: - The current prefix caching stats. 
- """ - stats = self.prefix_cache_stats - self.prefix_cache_stats = PrefixCacheStats() - return stats - - def get_computed_blocks( - self, request: Request) -> tuple[HybridKVCacheBlocks, int]: - """Get the computed (cached) blocks for the request. - Note that the computed blocks must be full. - - Args: - request: The request to get the computed blocks. - - Returns: - A tuple containing: - - A list of blocks that are computed for each kv cache group. - - The number of computed tokens. - """ - if not self.enable_caching: - # Prefix caching is disabled. - computed_blocks: list[list[KVCacheBlock]] = [ - [] for _ in range(self.num_kv_cache_groups) - ] - return HybridKVCacheBlocks(computed_blocks), 0 - - # The block hashes for the request may already be computed - # if the scheduler has tried to schedule the request before. - block_hashes = self.req_to_block_hashes[request.request_id] - if not block_hashes: - block_hashes = [ - hash_request_tokens(self.caching_hash_fn, - g.kv_cache_spec.block_size, request, i) - for i, g in enumerate(self.kv_cache_config.kv_cache_groups) - ] - self.req_to_block_hashes[request.request_id] = block_hashes - - self.prefix_cache_stats.requests += 1 - if request.sampling_params.prompt_logprobs is None: - # TODO: Fix last block problem - # if len(block_hashes) * self.block_size == request.num_tokens: - # # When prompt length is divisible by the block size and all - # # blocks are cached, we need to recompute the last token. This - # # have to be achieved by re-computing an entire block because - # # allocate_slots() assumes num_computed_tokens is always a - # # multiple of the block size. To achieve this, remove the last - # # block hash from the block_hashes for find_longest_cache_hit - # # This limitation can potentially be removed in the future to - # # slightly improve the performance. - # last_block_hash = block_hashes.pop() - # else: - # last_block_hash = None - - computed_blocks, num_computed_tokens = self.find_longest_cache_hit( - request.request_id, block_hashes) - - # if last_block_hash is not None: - # # Add back the last block hash if it was removed. - # block_hashes.append(last_block_hash) - - self.prefix_cache_stats.queries += len(block_hashes) - self.prefix_cache_stats.hits += len(computed_blocks) - return HybridKVCacheBlocks(computed_blocks), num_computed_tokens - else: - # Skip cache hits for prompt logprobs - return HybridKVCacheBlocks([]), 0 - - def allocate_slots( - self, - request: Request, - num_tokens: int, - new_computed_blocks: Optional[KVCacheBlocksInterface] = None, - num_new_computed_tokens: int = 0, - ) -> Optional[HybridKVCacheBlocks]: - """Add slots for a request with new tokens to append. - - Args: - request: The request to allocate slots. - num_tokens: The number of tokens to allocate. Note that this does - not include the tokens that have already been computed. - new_computed_blocks: A list of new computed blocks just hitting the - prefix caching. - num_new_computed_tokens: The number of new computed tokens in the - new_computed_blocks. - - Blocks layout: - ----------------------------------------------------------------------- - | < computed > | < new computed > | < new > | < pre-allocated > | - ----------------------------------------------------------------------- - | < required > | - -------------------------------------------------- - | < full > | - ------------------------------------------------ - | | - -------------- - The following *_blocks are illustrated in this layout. - - Returns: - A list of new allocated blocks. 
- """ - if num_tokens == 0: - raise ValueError("num_tokens must be greater than 0") - - if new_computed_blocks is not None: - assert isinstance(new_computed_blocks, HybridKVCacheBlocks) - new_computed_block_list = new_computed_blocks.blocks - else: - new_computed_block_list = ([ - [] for _ in range(self.num_kv_cache_groups) - ]) - - req_blocks = self.req_to_blocks[request.request_id] - - # Free the blocks that are skipped during the attention computation - # (e.g., tokens outside the sliding window). - # We can do this even if we cannot schedule this request due to - # insufficient free blocks. - # Should call this function before allocating new blocks to reduce - # the number of evicted blocks. - removed_blocks = [ - manager.remove_skipped_blocks(req_blocks[i], - request.num_computed_tokens) - for i, manager in enumerate(self.specialized_managers) - ] - self._free_blocks(removed_blocks) - - # The number of computed tokens is the number of computed tokens plus - # the new prefix caching hits - - num_computed_tokens = (request.num_computed_tokens + - num_new_computed_tokens) - - num_new_blocks: list[int] = [] - for i in range(self.num_kv_cache_groups): - num_required_blocks_i = cdiv( - num_computed_tokens + num_tokens, - self.specialized_managers[i].block_size) - num_new_blocks.append(num_required_blocks_i - len(req_blocks[i]) - - len(new_computed_block_list[i])) - total_num_new_blocks = sum(max(x, 0) for x in num_new_blocks) - - # If a computed block of a request is an eviction candidate (in the - # free queue and ref_cnt == 0), it cannot be counted as a free block - # when allocating this request. - num_evictable_computed_blocks = sum( - 1 for blk_one_layer in new_computed_block_list - for blk in blk_one_layer if blk.ref_cnt == 0) - if (total_num_new_blocks > self.block_pool.get_num_free_blocks() - - num_evictable_computed_blocks): - # Cannot allocate new blocks - return None - - # Touch the computed blocks to make sure they won't be evicted. - if self.enable_caching: - for blocks in new_computed_block_list: - self.block_pool.touch(blocks) - else: - assert all(len(blks) == 0 for blks in new_computed_block_list), ( - "Computed blocks should be empty when " - "prefix caching is disabled") - - # Append the new computed blocks to the request blocks until now to - # avoid the case where the new blocks cannot be allocated. - for i in range(self.num_kv_cache_groups): - req_blocks[i].extend(new_computed_block_list[i]) - - # Start to handle new blocks - new_blocks: list[list[KVCacheBlock]] = [] - # Truncate the number of pre-allocated blocks to ensure that we can - # have at least `num_new_blocks` free blocks for each layer. - num_preallocate_blocks = min( - self.num_preallocate_blocks, - (self.block_pool.get_num_free_blocks() - total_num_new_blocks) // - len(self.specialized_managers)) - - for i in range(self.num_kv_cache_groups): - if num_new_blocks[i] <= 0: - # No new block is needed. - new_blocks.append([]) - else: - # Get new blocks from the free block pool considering - # preallocated blocks. - num_block_to_allocate = min( - num_new_blocks[i] + num_preallocate_blocks, - # Should not exceed the maximum number of blocks per request - # This is especially because the block table has the shape - # [..., max_num_blocks_per_req]. - # TODO(woosuk): Check and reject requests if - # num_prompt_tokens + max_tokens > max_model_len. 
- # Don't need self.block_pool.get_num_free_blocks() as in - # KVCacheManager because we already considered it when - # calculating num_preallocate_blocks - self.max_num_blocks_per_req[i] - len(req_blocks[i]), - ) - - assert num_block_to_allocate > 0 - assert num_block_to_allocate <= \ - self.block_pool.get_num_free_blocks() - - # Concatenate the computed block IDs and the new block IDs. - new_blocks_this_layer = self.block_pool.get_new_blocks( - num_block_to_allocate) - new_blocks.append(new_blocks_this_layer) - req_blocks[i].extend(new_blocks_this_layer) - - if not self.enable_caching: - return HybridKVCacheBlocks(new_blocks) - - # Use `new_computed_block_list` for a new request, and - # `num_cached_block` for a running request. - num_cached_blocks = self.num_cached_block.get( - request.request_id, - [len(blocks) for blocks in new_computed_block_list]) - # Speculated tokens might be rejected in the future, so we does - # not cache any speculated tokens. We only cache blocks with - # generated (accepted) tokens. - for i in range(self.num_kv_cache_groups): - num_full_blocks_after_append = ( - num_computed_tokens + num_tokens - len(request.spec_token_ids) - ) // self.specialized_managers[i].block_size - - self.block_pool.cache_full_blocks( - request=request, - blocks=req_blocks[i], - block_hashes=self.req_to_block_hashes[request.request_id][i], - num_cached_blocks=num_cached_blocks[i], - num_full_blocks=num_full_blocks_after_append, - block_size=self.specialized_managers[i].block_size, - hash_fn=self.caching_hash_fn, - kv_cache_group_id=i, - ) - num_cached_blocks[i] = num_full_blocks_after_append - - self.num_cached_block[request.request_id] = num_cached_blocks - return HybridKVCacheBlocks(new_blocks) - - def free(self, request: Request) -> None: - """Free the blocks allocated for the request. - When caching is enabled, we free the blocks in reverse order so that - the tail blocks are evicted first. - - Args: - request: The request to free the blocks. - """ - # Default to [] in case a request is freed (aborted) before alloc. - blocks = self.req_to_blocks.pop(request.request_id, None) - if blocks is not None: - # Reverse the blocks so that the tail blocks can have higher - # eviction priority. - self._free_blocks([list(reversed(blks)) for blks in blocks]) - - self.num_cached_block.pop(request.request_id, None) - - def reset_prefix_cache(self) -> bool: - """Reset prefix cache. This function may be used in RLHF - flows to invalid prefix caching after the weights are updated, - or used for resetting prefix caching status for benchmarking. - - Returns: - bool: True if the prefix cache is successfully reset, - False otherwise. - """ - if self.block_pool.reset_prefix_cache(): - self.prefix_cache_stats.reset = True - return True - return False - - def get_num_common_prefix_blocks( - self, - request: Request, - num_running_requests: int, - ) -> list[int]: - """Calculate the number of common prefix blocks shared by all requests - in the RUNNING state. - - The function determines this by selecting any request and iterating - through its blocks. A block is considered a common prefix block if its - `ref_cnt` equals the total number of requests in the RUNNING state. - - NOTE(woosuk): The number of requests in the RUNNING state is **greater - than or equal to** the number of requests scheduled in the current step. - This is because the RUNNING state only indicates that: - 1. The request has not yet finished, and - 2. The request holds its blocks unfreed. 
- - While all scheduled requests must be in the RUNNING state, the inverse - is not necessarily true. There may be RUNNING requests that are not - scheduled in the current step. - - This can result in an edge case where the number of common prefix blocks - is 0, even though all scheduled requests share a common prefix. This - occurs because there may be unscheduled RUNNING requests that do not - share the common prefix. Currently, this case cannot be easily detected, - so the function returns 0 in such cases. - - Args: - request: Any request in the RUNNING state, used to identify the - common prefix blocks. - num_running_requests: The total number of requests in the RUNNING - state. This can be different from the number of scheduled - requests in the current step. - - Returns: - list[int]: The number of common prefix blocks for each kv cache - group. - """ - assert request.status == RequestStatus.RUNNING - blocks = self.req_to_blocks[request.request_id] - num_common_blocks = [] - for i in range(self.num_kv_cache_groups): - num_common_blocks_i = 0 - for block in blocks[i]: - if block.ref_cnt == num_running_requests: - num_common_blocks_i += 1 - else: - break - num_common_blocks.append(num_common_blocks_i) - return num_common_blocks - - def free_block_hashes(self, request: Request) -> None: - """Discard the block hashes for the request. - - NOTE: Unlike `free`, this method should be called only when the request - is finished, not when it is preempted. - """ - self.req_to_block_hashes.pop(request.request_id, None) - - def find_longest_cache_hit( - self, request_id: str, block_hashes: list[list[BlockHashType]] - ) -> tuple[list[list[KVCacheBlock]], int]: - """Find the longest cache hit for each kv cache group. - TODO: add more notes - """ - # TODO: accelerate by make full attention the first layer - # TODO: add note for the two magic number - num_computed_tokens = [self.max_model_len + 100] * len( - self.specialized_managers) - min_computed_tokens = self.max_model_len - - # Use copy to avoid modifying the original block_hashes - block_hashes = [block_hash.copy() for block_hash in block_hashes] - - while max(num_computed_tokens) != min_computed_tokens: - for i, manager in enumerate(self.specialized_managers): - if num_computed_tokens[i] > min_computed_tokens: - del block_hashes[i][:min_computed_tokens // - manager.block_size] - computed_blocks_group_i = ( - manager.find_longest_cache_hit_multiple_calls( - request_id, block_hashes[i])) - - num_computed_tokens[i] = len(computed_blocks_group_i) * \ - manager.block_size - min_computed_tokens = min(min_computed_tokens, - num_computed_tokens[i]) - - # Get the non-constlist computed blocks - computed_blocks = [ - manager.find_longest_cache_hit(request_id, block_hashes[i]) - for i, manager in enumerate(self.specialized_managers) - ] - - assert all( - len(block) * manager.block_size == min_computed_tokens for block, - manager in zip(computed_blocks, self.specialized_managers)) - - return computed_blocks, min_computed_tokens - - def _merge_blocks_by_eviction_order( - self, blocks: list[list[KVCacheBlock]]) -> list[KVCacheBlock]: - """ - Merge the blocks of different layers to one list. The returned blocks - are sorted by eviction order, with the first block having the highest - eviction priority. - Args: - blocks: the blocks of each virtual layer, ordered by eviction - priority. - Returns: - A list of KVCacheBlocks sorted by eviction order. - """ - - if self.enable_caching: - # NOTE (Chen): A simple strategy that interleaves the blocks of - # each layer. 
We can investigate more advanced strategies - # in the future. - ordered_blocks = [] - max_len = max(len(blocks_one_layer) for blocks_one_layer in blocks) - for i in range(max_len): - for blocks_one_layer in blocks: - if i < len(blocks_one_layer): - ordered_blocks.append(blocks_one_layer[i]) - else: - ordered_blocks = [] - for blocks_one_layer in blocks: - ordered_blocks.extend(blocks_one_layer) - - return ordered_blocks - - def _free_blocks(self, blocks: list[list[KVCacheBlock]]) -> None: - ordered_blocks = self._merge_blocks_by_eviction_order(blocks) - self.block_pool.free_blocks(ordered_blocks) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 114aed0f7fbf..c501cec92d82 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -7,7 +7,6 @@ from vllm.logger import init_logger from vllm.utils import cdiv, sha256 from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_manager import KVCacheBlocksInterface from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, hash_request_tokens) from vllm.v1.core.specialized_manager import get_specialized_manager @@ -19,25 +18,23 @@ @dataclass -class HybridKVCacheBlocks(KVCacheBlocksInterface): +class KVCacheBlocks: blocks: list[list[KVCacheBlock]] def to_block_ids(self) -> list[list[int]]: return [[blk.block_id for blk in blk_one_layer] for blk_one_layer in self.blocks] - def __add__(self, - other: "KVCacheBlocksInterface") -> "KVCacheBlocksInterface": - assert isinstance(other, HybridKVCacheBlocks) - return HybridKVCacheBlocks([ + def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks": + return KVCacheBlocks([ self_blocks_i + other_blocks_i for self_blocks_i, other_blocks_i in zip(self.blocks, other.blocks) ]) -class HybridKVCacheManager: +class KVCacheManager: """ - The HybridKVCacheManager for models with multiple KV cache types + The KVCacheManager for models with multiple KV cache types (e.g., Gemma-2) and thus multiple kv cache groups (Refer to class `KVCacheConfig` for the meaning of kv cache groups). """ @@ -93,7 +90,7 @@ def __init__( # This is used to track the number of cached blocks for each request. # This is only used to track the RUNNING requests, we do not track the # data for reempted ones. - self.num_cached_block: dict[str, int] = {} + self.num_cached_block: dict[str, list[int]] = {} @property def usage(self) -> float: @@ -116,8 +113,8 @@ def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]: self.prefix_cache_stats = PrefixCacheStats() return stats - def get_computed_blocks( - self, request: Request) -> tuple[HybridKVCacheBlocks, int]: + def get_computed_blocks(self, + request: Request) -> tuple[KVCacheBlocks, int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. @@ -134,7 +131,7 @@ def get_computed_blocks( computed_blocks: list[list[KVCacheBlock]] = [ [] for _ in range(self.num_kv_cache_groups) ] - return HybridKVCacheBlocks(computed_blocks), 0 + return KVCacheBlocks(computed_blocks), 0 # The block hashes for the request may already be computed # if the scheduler has tried to schedule the request before. @@ -152,10 +149,11 @@ def get_computed_blocks( self.prefix_cache_stats.requests += 1 # When the request requires prompt logprobs, we skip prefix caching. 
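Aside: the `KVCacheBlocks` container introduced at the top of this file always holds one block list per kv cache group; the empty per-group return just below and the group-wise `__add__` both rely on that shape. A small usage sketch with stand-in blocks (not the real `KVCacheBlock`):

from dataclasses import dataclass

@dataclass
class FakeBlock:
    block_id: int

# Two kv cache groups: prefix-cache hits and newly allocated blocks are
# combined group-wise, which is what KVCacheBlocks.__add__ does.
computed = [[FakeBlock(3), FakeBlock(4)], [FakeBlock(10)]]
new = [[FakeBlock(7)], [FakeBlock(11), FakeBlock(12)]]

merged = [c + n for c, n in zip(computed, new)]
block_ids = [[blk.block_id for blk in group] for group in merged]
assert block_ids == [[3, 4, 7], [10, 11, 12]]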
if request.sampling_params.prompt_logprobs is not None: - return [], 0 + return KVCacheBlocks([[] + for _ in range(self.num_kv_cache_groups)]), 0 # TODO: Fix last block problem - # if len(block_hashes) * self.block_size == request.num_tokens: + # if len(block_hashes) * self.block_size == request.num_tokens: # # When prompt length is divisible by the block size and all # # blocks are cached, we need to recompute the last token. This # # have to be achieved by re-computing an entire block because @@ -179,16 +177,16 @@ def get_computed_blocks( self.prefix_cache_stats.queries += len(block_hashes) self.prefix_cache_stats.hits += len(computed_blocks) - return HybridKVCacheBlocks(computed_blocks), num_computed_tokens + return KVCacheBlocks(computed_blocks), num_computed_tokens def allocate_slots( self, request: Request, num_tokens: int, - new_computed_blocks: Optional[KVCacheBlocksInterface] = None, - num_lookahead_tokens: int = 0,, + new_computed_blocks: Optional[KVCacheBlocks] = None, num_new_computed_tokens: int = 0, - ) -> Optional[HybridKVCacheBlocks]: + num_lookahead_tokens: int = 0, + ) -> Optional[KVCacheBlocks]: """Add slots for a request with new tokens to append. Args: @@ -223,7 +221,6 @@ def allocate_slots( raise ValueError("num_tokens must be greater than 0") if new_computed_blocks is not None: - assert isinstance(new_computed_blocks, HybridKVCacheBlocks) new_computed_block_list = new_computed_blocks.blocks else: new_computed_block_list = ([ @@ -285,32 +282,32 @@ def allocate_slots( for i in range(self.num_kv_cache_groups): req_blocks[i].extend(new_computed_block_list[i]) + new_blocks: list[list[KVCacheBlock]] = [] # Start to handle new blocks - - if num_new_blocks <= 0: - # No new block is needed. - new_blocks = [] - else: - # Get new blocks from the free block pool. - # TODO: to group - num_new_blocks = min( - num_new_blocks, - self.block_pool.get_num_free_blocks(), - # Should not exceed the maximum number of blocks per request. - # This is especially because the block table has the shape - # [..., max_num_blocks_per_req]. - self.max_num_blocks_per_req - len(req_blocks), - ) - assert num_new_blocks > 0 + for i in range(self.num_kv_cache_groups): + if num_new_blocks[i] <= 0: + # No new block is needed. + new_blocks.append([]) + else: + # Get new blocks from the free block pool. + num_new_blocks_i = min( + num_new_blocks[i], + # Should not exceed the maximum number of blocks per + # request. + # This is especially because the block table has the shape + # [..., max_num_blocks_per_req]. + self.max_num_blocks_per_req[i] - len(req_blocks[i]), + ) + assert num_new_blocks_i > 0 # Concatenate the computed block IDs and the new block IDs. new_blocks_this_layer = self.block_pool.get_new_blocks( - num_block_to_allocate) + num_new_blocks_i) new_blocks.append(new_blocks_this_layer) req_blocks[i].extend(new_blocks_this_layer) if not self.enable_caching: - return HybridKVCacheBlocks(new_blocks) + return KVCacheBlocks(new_blocks) # Use `new_computed_block_list` for a new request, and # `num_cached_block` for a running request. @@ -338,7 +335,7 @@ def allocate_slots( num_cached_blocks[i] = num_full_blocks_after_append self.num_cached_block[request.request_id] = num_cached_blocks - return HybridKVCacheBlocks(new_blocks) + return KVCacheBlocks(new_blocks) def free(self, request: Request) -> None: """Free the blocks allocated for the request. 
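Since each kv cache group may use a different block size, the allocation loop in the hunk above recomputes how many blocks it still needs per group from the same token count. A hedged sketch of that arithmetic, with made-up group sizes (not the vLLM API):

def cdiv(a: int, b: int) -> int:
    # Ceiling division, as vllm.utils.cdiv does.
    return -(a // -b)

def new_blocks_needed(num_computed_tokens: int, num_new_tokens: int,
                      block_sizes: list[int],
                      blocks_already_held: list[int]) -> list[int]:
    """How many extra blocks each kv cache group needs for this step."""
    total_tokens = num_computed_tokens + num_new_tokens
    return [
        cdiv(total_tokens, block_sizes[i]) - blocks_already_held[i]
        for i in range(len(block_sizes))
    ]

# Example: 100 computed + 30 new tokens; group 0 uses 16-token blocks and
# already holds 7 blocks, group 1 uses 32-token blocks and holds 4.
assert new_blocks_needed(100, 30, [16, 32], [7, 4]) == [2, 1]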
@@ -474,36 +471,8 @@ def find_longest_cache_hit( return computed_blocks, min_computed_tokens - def _merge_blocks_by_eviction_order( - self, blocks: list[list[KVCacheBlock]]) -> list[KVCacheBlock]: - """ - Merge the blocks of different layers to one list. The returned blocks - are sorted by eviction order, with the first block having the highest - eviction priority. - Args: - blocks: the blocks of each virtual layer, ordered by eviction - priority. - Returns: - A list of KVCacheBlocks sorted by eviction order. - """ - - if self.enable_caching: - # NOTE (Chen): A simple strategy that interleaves the blocks of - # each layer. We can investigate more advanced strategies - # in the future. - ordered_blocks = [] - max_len = max(len(blocks_one_layer) for blocks_one_layer in blocks) - for i in range(max_len): - for blocks_one_layer in blocks: - if i < len(blocks_one_layer): - ordered_blocks.append(blocks_one_layer[i]) - else: - ordered_blocks = [] - for blocks_one_layer in blocks: - ordered_blocks.extend(blocks_one_layer) - - return ordered_blocks - def _free_blocks(self, blocks: list[list[KVCacheBlock]]) -> None: - ordered_blocks = self._merge_blocks_by_eviction_order(blocks) + ordered_blocks = [] + for blocks_one_layer in blocks: + ordered_blocks.extend(blocks_one_layer) self.block_pool.free_blocks(ordered_blocks) diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 17c44ec279c5..6d7a7f0e0b0a 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -3,7 +3,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: import numpy as np @@ -16,9 +16,6 @@ from vllm.sampling_params import SamplingParams from vllm.v1.request import Request -MayMultiGroupBlockIDs = Union[list[int], list[list[int]]] -MayMultiGroupInt = Union[int, list[int]] - @dataclass class NewRequestData: @@ -30,7 +27,7 @@ class NewRequestData: mm_hashes: list[str] mm_positions: list[PlaceholderRange] sampling_params: SamplingParams - block_ids: MayMultiGroupBlockIDs + block_ids: list[list[int]] num_computed_tokens: int lora_request: Optional[LoRARequest] @@ -38,7 +35,7 @@ class NewRequestData: def from_request( cls, request: Request, - block_ids: MayMultiGroupBlockIDs, + block_ids: list[list[int]], ) -> NewRequestData: return cls( req_id=request.request_id, @@ -63,7 +60,7 @@ class CachedRequestData: # request's block IDs instead of appending to the existing block IDs. resumed_from_preemption: bool new_token_ids: list[int] - new_block_ids: MayMultiGroupBlockIDs + new_block_ids: list[list[int]] num_computed_tokens: int @classmethod @@ -72,7 +69,7 @@ def from_request( request: Request, resumed_from_preemption: bool, new_token_ids: list[int], - new_block_ids: MayMultiGroupBlockIDs, + new_block_ids: list[list[int]], ) -> CachedRequestData: return cls( req_id=request.request_id, @@ -111,7 +108,7 @@ class SchedulerOutput: scheduled_encoder_inputs: dict[str, list[int]] # Number of common prefix blocks for all requests. # This can be used for cascade attention. - num_common_prefix_blocks: MayMultiGroupInt + num_common_prefix_blocks: list[int] # Request IDs that are finished in between the previous and the current # steps. 
This is used to notify the workers about the finished requests diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c5492ea9e6e6..54e53b5792d1 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -15,10 +15,9 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) -from vllm.v1.core.kv_cache_manager import init_kv_cache_manager +from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.sched.interface import SchedulerInterface -from vllm.v1.core.sched.output import (CachedRequestData, - MayMultiGroupBlockIDs, NewRequestData, +from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) from vllm.v1.core.sched.utils import check_stop from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, @@ -76,7 +75,7 @@ def __init__( assert num_gpu_blocks is not None and num_gpu_blocks > 0 # Create the KV cache manager. - self.kv_cache_manager = init_kv_cache_manager( + self.kv_cache_manager = KVCacheManager( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, enable_caching=self.cache_config.enable_prefix_caching, @@ -156,7 +155,7 @@ def schedule(self) -> SchedulerOutput: # uses structured decoding. structured_output_request_ids: dict[str, int] = {} - req_to_new_block_ids: dict[str, MayMultiGroupBlockIDs] = {} + req_to_new_block_ids: dict[str, list[list[int]]] = {} num_scheduled_tokens: dict[str, int] = {} token_budget = self.max_num_scheduled_tokens # Encoder-related. @@ -430,11 +429,9 @@ def schedule(self) -> SchedulerOutput: # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. 
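With the `sched/output.py` change above, `num_common_prefix_blocks` is now a `list[int]` with one entry per kv cache group. A sketch of the counting it reflects, using made-up per-block reference counts instead of real `KVCacheBlock` objects:

def common_prefix_blocks_per_group(ref_cnts_per_group: list[list[int]],
                                   num_running_requests: int) -> list[int]:
    # A leading block is a common prefix block when every running request
    # still references it; stop at the first block that is not shared.
    result = []
    for ref_cnts in ref_cnts_per_group:
        count = 0
        for ref_cnt in ref_cnts:
            if ref_cnt == num_running_requests:
                count += 1
            else:
                break
        result.append(count)
    return result

# Two groups, three running requests: the first two blocks of group 0 and the
# first block of group 1 are shared by all of them.
assert common_prefix_blocks_per_group([[3, 3, 1], [3, 2, 2]], 3) == [2, 1]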
- if len(self.kv_cache_config.kv_cache_groups) > 1: - num_common_prefix_blocks: Union[int, list[int]] = [0] * len( - self.kv_cache_config.kv_cache_groups) - else: - num_common_prefix_blocks = 0 + num_common_prefix_blocks: list[int] = [0] * len( + self.kv_cache_config.kv_cache_groups) + if self.running: any_request = self.running[0] num_common_prefix_blocks = ( @@ -516,7 +513,7 @@ def _make_cached_request_data( request: Request, num_scheduled_tokens: int, num_scheduled_spec_tokens: int, - new_block_ids: MayMultiGroupBlockIDs, + new_block_ids: list[list[int]], resumed_from_preemption: bool, ) -> CachedRequestData: # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 0921ea35d487..1e8e93ea3ee1 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Concatenate, ParamSpec, Union +from typing import Callable, Concatenate, ParamSpec import numpy as np import torch @@ -156,15 +156,11 @@ def initialize_block_table( pin_memory: bool, device: torch.device, kv_cache_config: KVCacheConfig, -) -> Union[BlockTable, MultiLayerBlockTable]: +) -> MultiLayerBlockTable: max_num_blocks_per_req = [ cdiv(max_model_len, g.kv_cache_spec.block_size) for g in kv_cache_config.kv_cache_groups ] - if len(kv_cache_config.kv_cache_groups) == 1: - return BlockTable(max_num_reqs, max_num_blocks_per_req[0], - max_num_tokens, pin_memory, device) - else: - return MultiLayerBlockTable(max_num_reqs, max_num_blocks_per_req, - max_num_tokens, pin_memory, device, - kv_cache_config) + return MultiLayerBlockTable(max_num_reqs, max_num_blocks_per_req, + max_num_tokens, pin_memory, device, + kv_cache_config) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 9443420afe33..5afc6e7757f1 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -11,7 +11,6 @@ from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import swap_dict_values -from vllm.v1.core.sched.output import MayMultiGroupBlockIDs from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.outputs import LogprobsTensors from vllm.v1.sample.metadata import SamplingMetadata @@ -32,7 +31,7 @@ class CachedRequestState: sampling_params: SamplingParams generator: Optional[torch.Generator] - block_ids: MayMultiGroupBlockIDs + block_ids: list[list[int]] num_computed_tokens: int output_token_ids: list[int] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index aeda9dec65e6..113aad3ff236 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -372,17 +372,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the block IDs. if not req_data.resumed_from_preemption: # Append the new blocks to the existing block IDs. 
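With `initialize_block_table` now always returning a `MultiLayerBlockTable` (see the `block_table.py` hunk above), every group sizes its own table from its block size. A quick sketch of that per-group sizing, with hypothetical block sizes:

max_model_len = 8192
group_block_sizes = [16, 32]  # hypothetical: two kv cache groups
# One block-table width per group: ceil(max_model_len / block_size), i.e. cdiv.
max_num_blocks_per_req = [-(-max_model_len // bs) for bs in group_block_sizes]
assert max_num_blocks_per_req == [512, 256]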
- if len(self.kv_cache_config.kv_cache_groups) == 1: - block_ids = cast(list[int], req_state.block_ids) - new_block_ids = cast(list[int], req_data.new_block_ids) - block_ids.extend(new_block_ids) - else: - hybrid_block_ids = cast(list[list[int]], - req_state.block_ids) - new_hybrid_block_ids = cast(list[list[int]], - req_data.new_block_ids) - for i in range(len(self.kv_cache_config.kv_cache_groups)): - hybrid_block_ids[i].extend(new_hybrid_block_ids[i]) + for i in range(len(self.kv_cache_config.kv_cache_groups)): + req_state.block_ids[i].extend(req_data.new_block_ids[i]) else: # The request is resumed from preemption. # Replace the existing block IDs with the new ones. @@ -444,9 +435,15 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Some attention backends (namely MLA) may want to separate requests # based on if the attention computation will be compute-bound or # memory-bound. This gives them a hook to do that. - batch_reordered = self.attn_metadata_builder.reorder_batch( + batch_reordered = self.attn_metadata_builders[0].reorder_batch( self.input_batch, scheduler_output) + for kv_cache_group_id in range( + 1, len(self.kv_cache_config.kv_cache_groups)): + assert not self.attn_metadata_builders[ + kv_cache_group_id].reorder_batch(self.input_batch, + scheduler_output) + if batch_changed or batch_reordered: self.input_batch.refresh_sampling_metadata() @@ -460,14 +457,6 @@ def _prepare_inputs( num_reqs = self.input_batch.num_reqs assert num_reqs > 0 - # Some attention backends (namely MLA) may want to separate requests - # based on if the attention computation will be compute-bound or - # memory-bound. This gives them a hook to do that. - modified_batch = self.attn_metadata_builder.reorder_batch( - self.input_batch, scheduler_output) - if modified_batch: - self.input_batch.refresh_sampling_metadata() - # OPTIMIZATION: Start copying the block table first. # This way, we can overlap the copy with the following CPU operations. self.input_batch.block_table.commit(num_reqs) @@ -521,16 +510,11 @@ def _prepare_inputs( torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) - if len(self.kv_cache_config.kv_cache_groups) == 1: - may_multi_layer_unwrapper = lambda x, _group_id: x - else: - may_multi_layer_unwrapper = lambda x, group_id: x[group_id] - for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): block_size = kv_cache_group_spec.kv_cache_spec.block_size - block_table: BlockTable = may_multi_layer_unwrapper( - self.input_batch.block_table, kv_cache_group_id) + block_table: BlockTable = self.input_batch.block_table[ + kv_cache_group_id] # Calculate the slot mapping. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] @@ -584,23 +568,19 @@ def _prepare_inputs( if self.cascade_attn_enabled: common_prefix_len = self._compute_cascade_attn_prefix_len( num_scheduled_tokens, - may_multi_layer_unwrapper( - scheduler_output.num_common_prefix_blocks, - kv_cache_group_id), + scheduler_output. 
+ num_common_prefix_blocks[kv_cache_group_id], kv_cache_group_spec.kv_cache_spec, - self.attn_backends[kv_cache_group_id], + self.attn_metadata_builders[kv_cache_group_id], ) - block_table = may_multi_layer_unwrapper( - self.input_batch.block_table, kv_cache_group_id) + block_table = self.input_batch.block_table[kv_cache_group_id] attn_metadata_i = ( self.attn_metadata_builders[kv_cache_group_id].build( num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, - common_prefix_len=common_prefix_len, - block_table=block_table, - )) + common_prefix_len=common_prefix_len)) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -640,7 +620,7 @@ def _compute_cascade_attn_prefix_len( num_scheduled_tokens: np.ndarray, num_common_prefix_blocks: int, kv_cache_spec: KVCacheSpec, - attn_backend: AttentionBackend, + attn_metadata_builder: AttentionMetadataBuilder, ) -> int: """Compute the length of the common prefix for cascade attention. @@ -1786,12 +1766,9 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: raise NotImplementedError( "Non-Attention backend is not supported by V1 " "GPUModelRunner.") - if isinstance(self.input_batch.block_table, BlockTable): - block_table = self.input_batch.block_table - else: - block_table = self.input_batch.block_table[i] + block_table_i = self.input_batch.block_table[i] attn_metadata_builder_i = attn_backend_i.get_builder_cls()( - weakref.proxy(self), kv_cache_spec, block_table) + weakref.proxy(self), kv_cache_spec, block_table_i) self.attn_backends.append(attn_backend_i) self.attn_metadata_builders.append(attn_metadata_builder_i) From 7f194669cc812e948e4ab76a267e032db9822747 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 23 Apr 2025 08:30:53 -0700 Subject: [PATCH 09/34] fix bug Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/flash_attn.py | 5 ++-- vllm/v1/core/kv_cache_manager.py | 29 +++++++++++++++++------- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 5a266a133f26..d836c6438f8a 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -337,7 +337,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, self.runner.attention_chunk_size, self.runner.query_start_loc_np[:num_reqs + 1], self.runner.seq_lens_np[:num_reqs], - block_table, + block_table_tensor, self.kv_cache_spec.block_size, ) local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to( @@ -363,7 +363,8 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, local_scheduler_metadata=local_scheduler_metadata, ) - use_cascade = common_prefix_len > 0 + # use_cascade = common_prefix_len > 0 + use_cascade = False if use_cascade: cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index c501cec92d82..ecf83b44eeed 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -83,8 +83,7 @@ def __init__( # This is to avoid recomputing the block hashes for each call of # `get_computed_blocks` or `allocate_slots`. 
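The hunk just below restores the "recompute the last block" workaround on a per-group basis: for every group whose prompt length is an exact multiple of its block size, the last block hash is popped before the cache lookup and appended back afterwards, so at least one token is left to recompute. A sketch with made-up hash strings:

def pop_full_prompt_tails(block_hashes_per_group: list[list[str]],
                          block_sizes: list[int],
                          num_prompt_tokens: int) -> dict[int, str]:
    # Pop the last hash only for groups where the prompt fills every block.
    popped: dict[int, str] = {}
    for i, hashes in enumerate(block_hashes_per_group):
        if len(hashes) * block_sizes[i] == num_prompt_tokens:
            popped[i] = hashes.pop()
    return popped

def restore_tails(block_hashes_per_group: list[list[str]],
                  popped: dict[int, str]) -> None:
    for i, tail in popped.items():
        block_hashes_per_group[i].append(tail)

hashes = [["h0", "h1", "h2", "h3"], ["g0", "g1"]]  # made-up hashes
popped = pop_full_prompt_tails(hashes, [16, 32], num_prompt_tokens=64)
assert popped == {0: "h3", 1: "g1"} and hashes == [["h0", "h1", "h2"], ["g0"]]
restore_tails(hashes, popped)
assert hashes == [["h0", "h1", "h2", "h3"], ["g0", "g1"]]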
self.req_to_block_hashes: defaultdict[ - str, list[list[BlockHashType]]] = defaultdict( - lambda: [[] for _ in range(self.num_kv_cache_groups)]) + str, list[list[BlockHashType]]] = defaultdict(lambda: []) # {req_id: The number of cached blocks for each kv cache group} # This is used to track the number of cached blocks for each request. @@ -136,7 +135,7 @@ def get_computed_blocks(self, # The block hashes for the request may already be computed # if the scheduler has tried to schedule the request before. block_hashes = self.req_to_block_hashes[request.request_id] - if not block_hashes: + if len(block_hashes) == 0: block_hashes = [ hash_request_tokens(self.caching_hash_fn, g.kv_cache_spec.block_size, request, i) @@ -165,13 +164,22 @@ def get_computed_blocks(self, # last_block_hash = block_hashes.pop() # else: # last_block_hash = None + last_block_hashs: list[Optional[BlockHashType]] = [] + for i in range(self.num_kv_cache_groups): + if len( + block_hashes[i] + ) * self.specialized_managers[i].block_size == request.num_tokens: + last_block_hashs.append(block_hashes[i].pop()) + else: + last_block_hashs.append(None) computed_blocks, num_computed_tokens = self.find_longest_cache_hit( request.request_id, block_hashes) - # if last_block_hash is not None: - # # Add back the last block hash if it was removed. - # block_hashes.append(last_block_hash) + for i in range(self.num_kv_cache_groups): + last_block_hash = last_block_hashs[i] + if last_block_hash is not None: + block_hashes[i].append(last_block_hash) if self.log_stats: assert self.prefix_cache_stats is not None @@ -445,11 +453,14 @@ def find_longest_cache_hit( # Use copy to avoid modifying the original block_hashes block_hashes = [block_hash.copy() for block_hash in block_hashes] + def shrink_length(block_hashes, length): + del block_hashes[length:] + while max(num_computed_tokens) != min_computed_tokens: for i, manager in enumerate(self.specialized_managers): if num_computed_tokens[i] > min_computed_tokens: - del block_hashes[i][:min_computed_tokens // - manager.block_size] + shrink_length(block_hashes[i], + min_computed_tokens // manager.block_size) computed_blocks_group_i = ( manager.find_longest_cache_hit_multiple_calls( request_id, block_hashes[i])) @@ -458,6 +469,8 @@ def find_longest_cache_hit( manager.block_size min_computed_tokens = min(min_computed_tokens, num_computed_tokens[i]) + shrink_length(block_hashes[i], + num_computed_tokens[i] // manager.block_size) # Get the non-constlist computed blocks computed_blocks = [ From 34ba571b8a6813235d735b36672e3bcd3aba5295 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 23 Apr 2025 10:16:46 -0700 Subject: [PATCH 10/34] 1 hash per block_size Signed-off-by: Chen Zhang --- vllm/v1/core/block_pool.py | 53 +++++++++++++-------- vllm/v1/core/kv_cache_manager.py | 71 ++++++++++++++++++----------- vllm/v1/core/kv_cache_utils.py | 30 ++++-------- vllm/v1/core/specialized_manager.py | 19 ++++---- 4 files changed, 100 insertions(+), 73 deletions(-) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 22613b9dd8d3..e7c1736a24fd 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -26,7 +26,8 @@ class BlockPool: enable_caching: Whether to enable prefix caching. 
""" - def __init__(self, num_gpu_blocks: int, enable_caching: bool): + def __init__(self, num_gpu_blocks: int, enable_caching: bool, + num_kv_cache_groups: int): assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 self.num_gpu_blocks = num_gpu_blocks self.enable_caching = enable_caching @@ -48,16 +49,23 @@ def __init__(self, num_gpu_blocks: int, enable_caching: bool): # if there is already an identical block in the cache. This is because # we want to make sure the allocated block IDs won't change so that # block tables are append-only. - self.cached_block_hash_to_block: dict[BlockHashType, dict[ - int, KVCacheBlock]] = defaultdict(dict) + self.cached_block_hash_to_block: list[dict[BlockHashType, dict[ + int, KVCacheBlock]]] = [ + defaultdict(dict) for _ in range(num_kv_cache_groups) + ] # To represent a placeholder block with block_id=0. # The ref_cnt of null_block is not maintained, needs special care to # avoid freeing it. self.null_block = self.free_block_queue.popleft() - def get_cached_block(self, - block_hash: BlockHashType) -> Optional[KVCacheBlock]: + self.num_kv_cache_groups = num_kv_cache_groups + + def get_cached_block( + self, + block_hash: BlockHashType, + kv_cache_group_id: int, + ) -> Optional[KVCacheBlock]: """Get a cached block by the block hash, or None if cache miss. If there are duplicated blocks, we return the first block in the cache. @@ -67,7 +75,8 @@ def get_cached_block(self, Returns: The cached block if it exists, or None. """ - cached_blocks = self.cached_block_hash_to_block.get(block_hash) + cached_blocks = self.cached_block_hash_to_block[kv_cache_group_id].get( + block_hash) if not cached_blocks: return None first_block_id = next(iter(cached_blocks)) @@ -82,7 +91,7 @@ def cache_full_blocks( num_full_blocks: int, block_size: int, hash_fn: Callable, - kv_cache_group_id: int = -1, + kv_cache_group_id: int, ) -> None: """Cache a list of full blocks for prefix caching. This function takes a list of blocks that will have their block hash @@ -102,8 +111,7 @@ def cache_full_blocks( be cached after this function. block_size: Number of tokens in each block. hash_fn: The hash function to use for block hashes. - kv_cache_group_id: The id of the kv cache group. -1 means no kv - cache group. + kv_cache_group_id: The id of the kv cache group. """ if num_cached_blocks == num_full_blocks: return @@ -124,7 +132,8 @@ def cache_full_blocks( if i < len(new_block_hashes): # The block hash may already be computed in - # "get_computed_blocks" if the tokens are not generated by + # "get_computed_blocks" or other groups with the same block_size + # if the tokens are not generated by # this request (either the prompt tokens or the previously # generated tokens with preemption). In this case we simply # reuse the block hash. @@ -146,8 +155,7 @@ def cache_full_blocks( # we reach to this branch only when the block is completed with # generated tokens, we only need to consider the last mm input. extra_keys, _ = generate_block_hash_extra_keys( - request, start_token_idx, end_token_idx, -1, - kv_cache_group_id) + request, start_token_idx, end_token_idx, -1) # Compute the hash of the current block. block_hash = hash_block_tokens(hash_fn, prev_block_hash_value, @@ -156,7 +164,9 @@ def cache_full_blocks( # Update and added the full block to the cache. 
blk.block_hash = block_hash - self.cached_block_hash_to_block[block_hash][blk.block_id] = blk + blk.kv_cache_group_id = kv_cache_group_id + self.cached_block_hash_to_block[kv_cache_group_id][block_hash][ + blk.block_id] = blk prev_block_hash_value = block_hash.hash_value def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: @@ -203,12 +213,17 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: True if the block is evicted, False otherwise. """ block_hash = block.block_hash - if block_hash and block_hash in self.cached_block_hash_to_block: + kv_cache_group_id = block.kv_cache_group_id + if block_hash and block_hash in self.cached_block_hash_to_block[ + kv_cache_group_id]: block.reset_hash() - del self.cached_block_hash_to_block[block_hash][block.block_id] + del self.cached_block_hash_to_block[kv_cache_group_id][block_hash][ + block.block_id] - if len(self.cached_block_hash_to_block[block_hash]) == 0: - del self.cached_block_hash_to_block[block_hash] + if len(self.cached_block_hash_to_block[kv_cache_group_id] + [block_hash]) == 0: + del self.cached_block_hash_to_block[kv_cache_group_id][ + block_hash] return True return False @@ -259,7 +274,9 @@ def reset_prefix_cache(self) -> bool: return False # Remove all hashes so that no new blocks will hit. - self.cached_block_hash_to_block = defaultdict(dict) + self.cached_block_hash_to_block = [ + defaultdict(dict) for _ in range(self.num_kv_cache_groups) + ] # Remove all hashes from all blocks. for block in self.blocks: diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index ecf83b44eeed..9f40a0fea8a2 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -62,16 +62,18 @@ def __init__( # FIXME: make prefix cache stats conditional on log_stats self.prefix_cache_stats = PrefixCacheStats() if log_stats else None - self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching) + self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) + + self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching, + self.num_kv_cache_groups) self.specialized_managers = [ get_specialized_manager( kv_cache_spec=g.kv_cache_spec, block_pool=self.block_pool, - ) for g in kv_cache_config.kv_cache_groups + kv_cache_group_id=i, + ) for i, g in enumerate(kv_cache_config.kv_cache_groups) ] - self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) - # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request # is finished. @@ -82,8 +84,13 @@ def __init__( # Mapping from request ID to kv block hashes. # This is to avoid recomputing the block hashes for each call of # `get_computed_blocks` or `allocate_slots`. - self.req_to_block_hashes: defaultdict[ - str, list[list[BlockHashType]]] = defaultdict(lambda: []) + # block_size -> list[BlockHashType]; TODO update comment + self.req_to_block_hashes: defaultdict[str, dict[ + int, list[BlockHashType]]] = defaultdict(dict) + + self.all_block_sizes = set( + g.kv_cache_spec.block_size + for g in self.kv_cache_config.kv_cache_groups) # {req_id: The number of cached blocks for each kv cache group} # This is used to track the number of cached blocks for each request. @@ -136,11 +143,11 @@ def get_computed_blocks(self, # if the scheduler has tried to schedule the request before. 
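This patch computes request block hashes once per distinct block size and shares the list among every group with that size; the lookup just below keys `req_to_block_hashes` the same way. A sketch of the sharing, with a stand-in for `hash_request_tokens`:

def fake_hash_request_tokens(block_size: int,
                             token_ids: list[int]) -> list[tuple]:
    # Stand-in: hash each full block as a tuple of its token IDs.
    num_full_blocks = len(token_ids) // block_size
    return [
        tuple(token_ids[i * block_size:(i + 1) * block_size])
        for i in range(num_full_blocks)
    ]

group_block_sizes = [16, 16, 32]     # e.g. two groups share block size 16
token_ids = list(range(130))         # 130-token prompt; partial last block ignored

all_block_sizes = set(group_block_sizes)  # only two hash passes, not three
block_hashes = {
    bs: fake_hash_request_tokens(bs, token_ids) for bs in all_block_sizes
}
hashes_per_group = [block_hashes[bs] for bs in group_block_sizes]
assert hashes_per_group[0] is hashes_per_group[1]  # computed once, shared
assert len(block_hashes[16]) == 8 and len(block_hashes[32]) == 4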
block_hashes = self.req_to_block_hashes[request.request_id] if len(block_hashes) == 0: - block_hashes = [ - hash_request_tokens(self.caching_hash_fn, - g.kv_cache_spec.block_size, request, i) - for i, g in enumerate(self.kv_cache_config.kv_cache_groups) - ] + block_hashes = { + block_size: + hash_request_tokens(self.caching_hash_fn, block_size, request) + for block_size in self.all_block_sizes + } self.req_to_block_hashes[request.request_id] = block_hashes if self.log_stats: @@ -164,22 +171,20 @@ def get_computed_blocks(self, # last_block_hash = block_hashes.pop() # else: # last_block_hash = None - last_block_hashs: list[Optional[BlockHashType]] = [] + last_block_hashs: dict[int, BlockHashType] = {} for i in range(self.num_kv_cache_groups): - if len( - block_hashes[i] - ) * self.specialized_managers[i].block_size == request.num_tokens: - last_block_hashs.append(block_hashes[i].pop()) - else: - last_block_hashs.append(None) + block_size = self.specialized_managers[i].block_size + if len(block_hashes[block_size] + ) * block_size == request.num_tokens: + last_block_hashs[block_size] = block_hashes[block_size].pop() computed_blocks, num_computed_tokens = self.find_longest_cache_hit( request.request_id, block_hashes) for i in range(self.num_kv_cache_groups): - last_block_hash = last_block_hashs[i] - if last_block_hash is not None: - block_hashes[i].append(last_block_hash) + block_size = self.specialized_managers[i].block_size + if block_size in last_block_hashs: + block_hashes[block_size].append(last_block_hashs[block_size]) if self.log_stats: assert self.prefix_cache_stats is not None @@ -326,17 +331,19 @@ def allocate_slots( # not cache any speculated tokens. We only cache blocks with # generated (accepted) tokens. for i in range(self.num_kv_cache_groups): + block_size = self.specialized_managers[i].block_size num_full_blocks_after_append = ( - num_computed_tokens + num_tokens - len(request.spec_token_ids) - ) // self.specialized_managers[i].block_size + num_computed_tokens + num_tokens - + len(request.spec_token_ids)) // block_size self.block_pool.cache_full_blocks( request=request, blocks=req_blocks[i], - block_hashes=self.req_to_block_hashes[request.request_id][i], + block_hashes=self.req_to_block_hashes[request.request_id] + [block_size], num_cached_blocks=num_cached_blocks[i], num_full_blocks=num_full_blocks_after_append, - block_size=self.specialized_managers[i].block_size, + block_size=block_size, hash_fn=self.caching_hash_fn, kv_cache_group_id=i, ) @@ -439,11 +446,18 @@ def free_block_hashes(self, request: Request) -> None: self.req_to_block_hashes.pop(request.request_id, None) def find_longest_cache_hit( - self, request_id: str, block_hashes: list[list[BlockHashType]] + self, request_id: str, block_hashes_dict: dict[int, + list[BlockHashType]] ) -> tuple[list[list[KVCacheBlock]], int]: """Find the longest cache hit for each kv cache group. 
TODO: add more notes """ + if self.num_kv_cache_groups == 1: + block_size = self.kv_cache_config.kv_cache_groups[ + 0].kv_cache_spec.block_size + hit_blocks = self.specialized_managers[0].find_longest_cache_hit( + request_id, block_hashes_dict[block_size]) + return [hit_blocks], len(hit_blocks) * block_size # TODO: accelerate by make full attention the first layer # TODO: add note for the two magic number num_computed_tokens = [self.max_model_len + 100] * len( @@ -451,7 +465,10 @@ def find_longest_cache_hit( min_computed_tokens = self.max_model_len # Use copy to avoid modifying the original block_hashes - block_hashes = [block_hash.copy() for block_hash in block_hashes] + block_hashes = [ + block_hashes_dict[g.kv_cache_spec.block_size].copy() + for g in self.kv_cache_config.kv_cache_groups + ] def shrink_length(block_hashes, length): del block_hashes[length:] diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 31818e96a51a..7f2bd0f6de75 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -126,6 +126,8 @@ class KVCacheBlock: prev_free_block: Optional["KVCacheBlock"] = None next_free_block: Optional["KVCacheBlock"] = None + kv_cache_group_id: int = -1 + def incr_ref(self): self.ref_cnt += 1 @@ -145,6 +147,7 @@ def block_hash(self, block_hash: BlockHashType): def reset_hash(self): """Reset the block hash when the block is evicted.""" self._block_hash = None + self.kv_cache_group_id = -1 def __repr__(self) -> str: # Use block_id instead of KVCacheBlock object to avoid calling __repr__ @@ -265,12 +268,11 @@ def get_all_free_blocks(self) -> list[KVCacheBlock]: return ret -def need_extra_keys(request: Request, kv_cache_group_id: int) -> bool: +def need_extra_keys(request: Request) -> bool: """Check whether the blocks allocated to this request need extra hash keys. Args: request (Request): The request. - kv_cache_group_id (int): The id of the kv cache group. -1 means no kv cache group. Returns: bool: Whether blocks allocated to this request need extra hash keys. @@ -278,9 +280,7 @@ def need_extra_keys(request: Request, kv_cache_group_id: int) -> bool: # Multimodal requests need to include the MM hash. # LoRA requests need to include the LoRA ID. - return bool(request.mm_positions) or (request.lora_request - is not None) or (kv_cache_group_id - != -1) + return bool(request.mm_positions) or (request.lora_request is not None) def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, @@ -368,8 +368,7 @@ def _gen_lora_extra_hash_keys(request: Request) -> list[int]: def generate_block_hash_extra_keys( request: Request, start_token_idx: int, end_token_idx: int, - start_mm_idx: int, - kv_cache_group_id: int) -> tuple[Optional[tuple[Any, ...]], int]: + start_mm_idx: int) -> tuple[Optional[tuple[Any, ...]], int]: """Generate extra keys for the block hash. The extra keys can come from the multi-modal inputs and request specific metadata (e.g., LoRA ID). @@ -378,8 +377,6 @@ def generate_block_hash_extra_keys( start_token_idx: The start token index of the block. end_token_idx: The end token index of the block. start_mm_idx: The start multi-modal index of the block. - kv_cache_group_id: The id of the kv cache group. -1 means no kv cache - group. Returns: A tuple of extra keys and the next multi-modal index. 
""" @@ -389,8 +386,6 @@ def generate_block_hash_extra_keys( lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request) extra_keys: list[Any] = lora_extra_keys + mm_extra_keys - if kv_cache_group_id != -1: - extra_keys.append(kv_cache_group_id) if not extra_keys: return None, new_start_mm_idx @@ -414,8 +409,6 @@ def hash_block_tokens( if this is the first block. curr_block_token_ids: A list of token ids in the current block. The current block is assumed to be full. - kv_cache_group_id: The id of the kv cache group. -1 means no kv cache group. - extra_keys: Extra keys for the block. Returns: The hash value of the block and the token ids in the block. @@ -431,24 +424,21 @@ def hash_block_tokens( curr_block_token_ids_tuple, extra_keys) -def hash_request_tokens(hash_function: Any, - block_size: int, - request: Request, - kv_cache_group_id: int = -1) -> list[BlockHashType]: +def hash_request_tokens(hash_function: Any, block_size: int, + request: Request) -> list[BlockHashType]: """Computes hash values of a chain of blocks given a sequence of token IDs. The hash value is used for prefix caching. Args: block_size: The size of each block. request: The request object. - kv_cache_group_id: The id of the kv cache group. -1 means no kv cache group. Returns: The list of computed hash values. """ token_ids = request.all_token_ids - req_need_extra_keys = need_extra_keys(request, kv_cache_group_id) + req_need_extra_keys = need_extra_keys(request) req_extra_keys = None curr_mm_idx = 0 @@ -464,7 +454,7 @@ def hash_request_tokens(hash_function: Any, if req_need_extra_keys: # MM and LoRA requests need extra keys for block-hash computation. req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys( - request, start, end, curr_mm_idx, kv_cache_group_id) + request, start, end, curr_mm_idx) block_hash = hash_block_tokens(hash_function, parent_block_hash_value, block_token_ids, req_extra_keys) diff --git a/vllm/v1/core/specialized_manager.py b/vllm/v1/core/specialized_manager.py index 97a5d2fa88f0..ababd79b7cbe 100644 --- a/vllm/v1/core/specialized_manager.py +++ b/vllm/v1/core/specialized_manager.py @@ -20,17 +20,20 @@ def __init__( self, kv_cache_spec: KVCacheSpec, block_pool: BlockPool, + kv_cache_group_id: int, ) -> None: """ Initializes the SpecializedManager. Args: kv_cache_spec: The kv_cache_spec for this manager. block_pool: The block pool. + kv_cache_group_id: The id of the kv cache group. """ self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec self.block_pool = block_pool + self.kv_cache_group_id = kv_cache_group_id # for caching the intermediate states between multiple calls of # find_longest_cache_hit self.req_cached_blocks: dict[str, list[KVCacheBlock]] = {} @@ -124,7 +127,7 @@ def _find_longest_cache_hit( # not in the cached_block_hash_to_id, the following block hashes # are not computed yet for sure. 
if cached_block := self.block_pool.get_cached_block( - block_hash): + block_hash, self.kv_cache_group_id): computed_blocks.append(cached_block) else: break @@ -141,9 +144,9 @@ def remove_skipped_blocks(self, blocks: list[KVCacheBlock], class SlidingWindowManager(SpecializedManager): - def __init__(self, kv_cache_spec: SlidingWindowSpec, - block_pool: BlockPool): - super().__init__(kv_cache_spec, block_pool) + def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, + kv_cache_group_id: int): + super().__init__(kv_cache_spec, block_pool, kv_cache_group_id) self.sliding_window = kv_cache_spec.sliding_window # The number of contiguous blocks needed for prefix cache hit. # -1 since the input token itself is also included in the window @@ -176,7 +179,7 @@ def _find_longest_cache_hit( # Search from right to left and early stop when a match is found. for i in range(len(block_hashes) - num_contiguous_blocks - 1, -1, -1): if cached_block := self.block_pool.get_cached_block( - block_hashes[i]): + block_hashes[i], self.kv_cache_group_id): computed_blocks[i] = cached_block num_contiguous_blocks += 1 if (num_contiguous_blocks @@ -218,8 +221,8 @@ def remove_skipped_blocks(self, blocks: list[KVCacheBlock], } -def get_specialized_manager(kv_cache_spec: KVCacheSpec, - block_pool: BlockPool) -> SpecializedManager: +def get_specialized_manager(kv_cache_spec: KVCacheSpec, block_pool: BlockPool, + kv_cache_group_id: int) -> SpecializedManager: manager_class = spec_manager_map[type(kv_cache_spec)] - manager = manager_class(kv_cache_spec, block_pool) + manager = manager_class(kv_cache_spec, block_pool, kv_cache_group_id) return manager From 18245e385c15fa146e4056118054ae7a2a5662c9 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 24 Apr 2025 09:43:28 -0700 Subject: [PATCH 11/34] one manager for each type Signed-off-by: Chen Zhang --- examples/offline_inference/basic/basic.py | 3 +- .../v1/e2e/test_correctness_sliding_window.py | 6 +- vllm/v1/core/block_pool.py | 87 ++-- vllm/v1/core/kv_cache_manager.py | 391 ++++++++---------- vllm/v1/core/kv_cache_utils.py | 35 +- vllm/v1/core/sched/scheduler.py | 6 +- vllm/v1/core/specialized_manager.py | 188 ++++++--- 7 files changed, 382 insertions(+), 334 deletions(-) diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index ae5ae7cb4834..5afe91783c55 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -15,7 +15,8 @@ def main(): # Create an LLM. - llm = LLM(model="facebook/opt-125m") + # llm = LLM(model="facebook/opt-125m") + llm = LLM(model="google/gemma-3-1b-it", enforce_eager=True) # Generate texts from the prompts. # The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. 
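The basic.py swap above effectively turns the example into a local smoke test for this series: google/gemma-3-1b-it interleaves sliding-window and full-attention layers, so a single generate call exercises more than one KV cache group. A minimal standalone version of such a check is sketched below; the prompt and sampling settings are illustrative assumptions, and only the model name and the enforce_eager flag come from the diff.

from vllm import LLM, SamplingParams

def smoke_test():
    # Interleaved sliding-window + full-attention model used in the diff above;
    # enforce_eager avoids CUDA graph capture while the manager is still WIP.
    llm = LLM(model="google/gemma-3-1b-it", enforce_eager=True)
    # Illustrative prompt and sampling values (not part of the patch).
    params = SamplingParams(temperature=0.0, max_tokens=32)
    outputs = llm.generate(["The capital of France is"], params)
    for out in outputs:
        print(out.prompt, "->", out.outputs[0].text)

if __name__ == "__main__":
    smoke_test()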
diff --git a/tests/v1/e2e/test_correctness_sliding_window.py b/tests/v1/e2e/test_correctness_sliding_window.py index a125d3fb7975..a2ec9793d2f8 100644 --- a/tests/v1/e2e/test_correctness_sliding_window.py +++ b/tests/v1/e2e/test_correctness_sliding_window.py @@ -17,15 +17,15 @@ class TestConfig: model_config = { "bigcode/starcoder2-3b": TestConfig(4096, (800, 1100)), - "google/gemma-2-2b-it": TestConfig(4096, (400, 800)), + "google/gemma-3-1b-it": TestConfig(4096, (400, 800)), # TODO: swa 1024 } @pytest.mark.parametrize( "model", [ - "bigcode/starcoder2-3b", # sliding window only - "google/gemma-2-2b-it", # sliding window + full attention + # "bigcode/starcoder2-3b", # sliding window only + "google/gemma-3-1b-it", # sliding window + full attention ]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index e7c1736a24fd..88897b84ad86 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -5,7 +5,7 @@ from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, - KVCacheBlock, + GroupedKVCacheBlock, KVCacheBlock, generate_block_hash_extra_keys, hash_block_tokens) from vllm.v1.request import Request @@ -27,10 +27,11 @@ class BlockPool: """ def __init__(self, num_gpu_blocks: int, enable_caching: bool, - num_kv_cache_groups: int): + num_specialized_managers: int, caching_hash_fn: Callable): assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 self.num_gpu_blocks = num_gpu_blocks self.enable_caching = enable_caching + self.caching_hash_fn = caching_hash_fn # All kv-cache blocks. self.blocks: list[KVCacheBlock] = [ KVCacheBlock(idx) for idx in range(num_gpu_blocks) @@ -40,7 +41,7 @@ def __init__(self, num_gpu_blocks: int, enable_caching: bool, # enabled). self.free_block_queue = FreeKVCacheBlockQueue(self.blocks) - # {block_hash: {block ID: block}}. A cached block is + # {manager_id: {block_hash: {block ID: list[block]}}}. A cached block is # a full block with a block hash that can be used for prefix caching. # The cached block may be used by running requests or in the # free_block_queue that could potentially be evicted. @@ -50,8 +51,8 @@ def __init__(self, num_gpu_blocks: int, enable_caching: bool, # we want to make sure the allocated block IDs won't change so that # block tables are append-only. self.cached_block_hash_to_block: list[dict[BlockHashType, dict[ - int, KVCacheBlock]]] = [ - defaultdict(dict) for _ in range(num_kv_cache_groups) + int, GroupedKVCacheBlock]]] = [ + defaultdict(dict) for _ in range(num_specialized_managers) ] # To represent a placeholder block with block_id=0. @@ -59,13 +60,13 @@ def __init__(self, num_gpu_blocks: int, enable_caching: bool, # avoid freeing it. self.null_block = self.free_block_queue.popleft() - self.num_kv_cache_groups = num_kv_cache_groups + self.num_specialized_managers = num_specialized_managers def get_cached_block( self, block_hash: BlockHashType, - kv_cache_group_id: int, - ) -> Optional[KVCacheBlock]: + manager_id: int, + ) -> Optional[GroupedKVCacheBlock]: """Get a cached block by the block hash, or None if cache miss. If there are duplicated blocks, we return the first block in the cache. @@ -75,7 +76,7 @@ def get_cached_block( Returns: The cached block if it exists, or None. 
""" - cached_blocks = self.cached_block_hash_to_block[kv_cache_group_id].get( + cached_blocks = self.cached_block_hash_to_block[manager_id].get( block_hash) if not cached_blocks: return None @@ -85,13 +86,12 @@ def get_cached_block( def cache_full_blocks( self, request: Request, - blocks: list[KVCacheBlock], + blocks: list[GroupedKVCacheBlock], block_hashes: list[BlockHashType], num_cached_blocks: int, num_full_blocks: int, block_size: int, - hash_fn: Callable, - kv_cache_group_id: int, + manager_id: int, ) -> None: """Cache a list of full blocks for prefix caching. This function takes a list of blocks that will have their block hash @@ -110,8 +110,7 @@ def cache_full_blocks( num_full_blocks: The number of blocks that are full and should be cached after this function. block_size: Number of tokens in each block. - hash_fn: The hash function to use for block hashes. - kv_cache_group_id: The id of the kv cache group. + manager_id: The id of the kv cache manager. """ if num_cached_blocks == num_full_blocks: return @@ -128,6 +127,7 @@ def cache_full_blocks( prev_block_hash_value = prev_block.block_hash.hash_value for i, blk in enumerate(new_full_blocks): + assert all(b.block_hash is None for b in blk.blocks) assert blk.block_hash is None if i < len(new_block_hashes): @@ -158,14 +158,17 @@ def cache_full_blocks( request, start_token_idx, end_token_idx, -1) # Compute the hash of the current block. - block_hash = hash_block_tokens(hash_fn, prev_block_hash_value, + block_hash = hash_block_tokens(self.caching_hash_fn, + prev_block_hash_value, block_tokens, extra_keys) block_hashes.append(block_hash) # Update and added the full block to the cache. + for b in blk.blocks: + b.block_hash = block_hash + b.manager_id = manager_id blk.block_hash = block_hash - blk.kv_cache_group_id = kv_cache_group_id - self.cached_block_hash_to_block[kv_cache_group_id][block_hash][ + self.cached_block_hash_to_block[manager_id][block_hash][ blk.block_id] = blk prev_block_hash_value = block_hash.hash_value @@ -213,22 +216,23 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: True if the block is evicted, False otherwise. """ block_hash = block.block_hash - kv_cache_group_id = block.kv_cache_group_id + manager_id = block.manager_id if block_hash and block_hash in self.cached_block_hash_to_block[ - kv_cache_group_id]: - block.reset_hash() - del self.cached_block_hash_to_block[kv_cache_group_id][block_hash][ + manager_id]: + for b in self.cached_block_hash_to_block[manager_id][block_hash][ + block.block_id].blocks: + b.reset_hash() + del self.cached_block_hash_to_block[manager_id][block_hash][ block.block_id] - if len(self.cached_block_hash_to_block[kv_cache_group_id] + if len(self.cached_block_hash_to_block[manager_id] [block_hash]) == 0: - del self.cached_block_hash_to_block[kv_cache_group_id][ - block_hash] + del self.cached_block_hash_to_block[manager_id][block_hash] return True return False - def touch(self, blocks: list[KVCacheBlock]) -> None: + def touch(self, blocks: list[list[GroupedKVCacheBlock]]) -> None: """Touch a block increases its reference count by 1, and may remove the block from the free queue. This is used when a block is hit by another request with the same prefix. @@ -236,14 +240,17 @@ def touch(self, blocks: list[KVCacheBlock]) -> None: Args: blocks: A list of blocks to touch. """ - for block in blocks: - # ref_cnt=0 means this block is in the free list (i.e. eviction - # candidate), so remove it. 
- if block.ref_cnt == 0 and block != self.null_block: - self.free_block_queue.remove(block) - block.incr_ref() - - def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: + for block_one_layer in blocks: + for block_two_layer in block_one_layer: + for block in block_two_layer.blocks: + # ref_cnt=0 means this block is in the free list (i.e. eviction + # candidate), so remove it. + if block.ref_cnt == 0 and block != self.null_block: + self.free_block_queue.remove(block) + block.incr_ref() + + def free_blocks(self, + ordered_blocks: Iterable[GroupedKVCacheBlock]) -> None: """Free a list of blocks. The blocks should be ordered by their eviction priority, where the first block will be evicted first. @@ -251,11 +258,13 @@ def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: ordered_blocks: A list of blocks to free ordered by their eviction priority. """ - for block in ordered_blocks: - block.decr_ref() - # null_block should not be added to the free list. - if block.ref_cnt == 0 and block != self.null_block: - self.free_block_queue.append(block) + # TODO: make sure blocks in the first group are evicted first + for blk in ordered_blocks: + for block in blk.blocks: + block.decr_ref() + # null_block should not be added to the free list. + if block.ref_cnt == 0 and block != self.null_block: + self.free_block_queue.append(block) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF @@ -275,7 +284,7 @@ def reset_prefix_cache(self) -> bool: # Remove all hashes so that no new blocks will hit. self.cached_block_hash_to_block = [ - defaultdict(dict) for _ in range(self.num_kv_cache_groups) + defaultdict(dict) for _ in range(self.num_specialized_managers) ] # Remove all hashes from all blocks. 
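The BlockPool hunks above replace the single {block_hash: {block_id: block}} index with one index per specialized manager and cache GroupedKVCacheBlock bundles instead of individual blocks. The sketch below shows only the lookup shape with plain dictionaries; GroupedBlock is a stand-in for GroupedKVCacheBlock, and the two-manager setup and hash strings are assumptions made for illustration.

from collections import defaultdict
from typing import Optional

class GroupedBlock:
    # Stand-in for GroupedKVCacheBlock: one physical block per KV cache group
    # served by the same specialized manager, keyed by the first block's id.
    def __init__(self, block_id: int, block_ids_per_group: tuple[int, ...]):
        self.block_id = block_id
        self.block_ids_per_group = block_ids_per_group

num_specialized_managers = 2  # e.g. one full-attention and one sliding-window manager
# manager_id -> block_hash -> block_id -> GroupedBlock
cached_block_hash_to_block = [
    defaultdict(dict) for _ in range(num_specialized_managers)
]

def get_cached_block(block_hash: str, manager_id: int) -> Optional[GroupedBlock]:
    # Mirrors BlockPool.get_cached_block above: return the first cached copy
    # for this manager, or None on a prefix-cache miss.
    cached = cached_block_hash_to_block[manager_id].get(block_hash)
    if not cached:
        return None
    return next(iter(cached.values()))

# Cache one grouped block under manager 1 and look it up again.
blk = GroupedBlock(block_id=7, block_ids_per_group=(7, 8))
cached_block_hash_to_block[1]["hash-a"][blk.block_id] = blk
assert get_cached_block("hash-a", 1) is blk
assert get_cached_block("hash-a", 0) is None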
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 9f40a0fea8a2..5962d953c04c 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -7,10 +7,11 @@ from vllm.logger import init_logger from vllm.utils import cdiv, sha256 from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, - hash_request_tokens) -from vllm.v1.core.specialized_manager import get_specialized_manager -from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.core.kv_cache_utils import ( + BlockHashType, GroupedKVCacheBlock, KVCacheBlock, hash_request_tokens, + remove_last_block_hash_for_divisible_prompt_length) +from vllm.v1.core.specialized_manager import SpecializedManager, get_specialized_manager +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request, RequestStatus @@ -19,11 +20,7 @@ @dataclass class KVCacheBlocks: - blocks: list[list[KVCacheBlock]] - - def to_block_ids(self) -> list[list[int]]: - return [[blk.block_id for blk in blk_one_layer] - for blk_one_layer in self.blocks] + blocks: list[list[GroupedKVCacheBlock]] def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks": return KVCacheBlocks([ @@ -64,23 +61,23 @@ def __init__( self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) - self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching, - self.num_kv_cache_groups) - self.specialized_managers = [ - get_specialized_manager( - kv_cache_spec=g.kv_cache_spec, - block_pool=self.block_pool, - kv_cache_group_id=i, - ) for i, g in enumerate(kv_cache_config.kv_cache_groups) - ] - - # Mapping from request ID to blocks to track the blocks allocated - # for each request, so that we can free the blocks when the request - # is finished. - self.req_to_blocks: defaultdict[ - str, list[list[KVCacheBlock]]] = defaultdict( - lambda: [[] for _ in range(self.num_kv_cache_groups)]) + self.manager_to_group = self.generate_group_manager_map() + self.num_specialized_managers = len(self.manager_to_group) + self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching, + self.num_specialized_managers, + self.caching_hash_fn) + + self.specialized_managers: list[SpecializedManager] = [] + for i in range(len(self.manager_to_group)): + group_ids = self.manager_to_group[i] + kv_cache_spec = kv_cache_config.kv_cache_groups[ + group_ids[0]].kv_cache_spec + self.specialized_managers.append( + get_specialized_manager(kv_cache_spec=kv_cache_spec, + block_pool=self.block_pool, + kv_cache_manager_id=i, + num_kv_cache_groups=len(group_ids))) # Mapping from request ID to kv block hashes. # This is to avoid recomputing the block hashes for each call of # `get_computed_blocks` or `allocate_slots`. @@ -92,12 +89,6 @@ def __init__( g.kv_cache_spec.block_size for g in self.kv_cache_config.kv_cache_groups) - # {req_id: The number of cached blocks for each kv cache group} - # This is used to track the number of cached blocks for each request. - # This is only used to track the RUNNING requests, we do not track the - # data for reempted ones. - self.num_cached_block: dict[str, list[int]] = {} - @property def usage(self) -> float: """Get the KV cache usage. 
@@ -119,6 +110,10 @@ def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]: self.prefix_cache_stats = PrefixCacheStats() return stats + def empty_kv_cache_blocks(self) -> KVCacheBlocks: + return KVCacheBlocks([[] + for _ in range(self.num_specialized_managers)]) + def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: """Get the computed (cached) blocks for the request. @@ -134,10 +129,7 @@ def get_computed_blocks(self, """ if not self.enable_caching: # Prefix caching is disabled. - computed_blocks: list[list[KVCacheBlock]] = [ - [] for _ in range(self.num_kv_cache_groups) - ] - return KVCacheBlocks(computed_blocks), 0 + return self.empty_kv_cache_blocks(), 0 # The block hashes for the request may already be computed # if the scheduler has tried to schedule the request before. @@ -155,36 +147,11 @@ def get_computed_blocks(self, self.prefix_cache_stats.requests += 1 # When the request requires prompt logprobs, we skip prefix caching. if request.sampling_params.prompt_logprobs is not None: - return KVCacheBlocks([[] - for _ in range(self.num_kv_cache_groups)]), 0 - - # TODO: Fix last block problem - # if len(block_hashes) * self.block_size == request.num_tokens: - # # When prompt length is divisible by the block size and all - # # blocks are cached, we need to recompute the last token. This - # # have to be achieved by re-computing an entire block because - # # allocate_slots() assumes num_computed_tokens is always a - # # multiple of the block size. To achieve this, remove the last - # # block hash from the block_hashes for find_longest_cache_hit - # # This limitation can potentially be removed in the future to - # # slightly improve the performance. - # last_block_hash = block_hashes.pop() - # else: - # last_block_hash = None - last_block_hashs: dict[int, BlockHashType] = {} - for i in range(self.num_kv_cache_groups): - block_size = self.specialized_managers[i].block_size - if len(block_hashes[block_size] - ) * block_size == request.num_tokens: - last_block_hashs[block_size] = block_hashes[block_size].pop() + return self.empty_kv_cache_blocks(), 0 computed_blocks, num_computed_tokens = self.find_longest_cache_hit( - request.request_id, block_hashes) + request, block_hashes) - for i in range(self.num_kv_cache_groups): - block_size = self.specialized_managers[i].block_size - if block_size in last_block_hashs: - block_hashes[block_size].append(last_block_hashs[block_size]) if self.log_stats: assert self.prefix_cache_stats is not None @@ -196,7 +163,7 @@ def allocate_slots( self, request: Request, num_tokens: int, - new_computed_blocks: Optional[KVCacheBlocks] = None, + wrapped_new_computed_blocks: Optional[KVCacheBlocks] = None, num_new_computed_tokens: int = 0, num_lookahead_tokens: int = 0, ) -> Optional[KVCacheBlocks]: @@ -233,123 +200,68 @@ def allocate_slots( if num_tokens == 0: raise ValueError("num_tokens must be greater than 0") - if new_computed_blocks is not None: - new_computed_block_list = new_computed_blocks.blocks + if wrapped_new_computed_blocks is not None: + new_computed_blocks = wrapped_new_computed_blocks.blocks else: - new_computed_block_list = ([ - [] for _ in range(self.num_kv_cache_groups) - ]) - - req_blocks = self.req_to_blocks[request.request_id] - + new_computed_blocks = [ + [] for _ in range(self.num_specialized_managers) + ] # Free the blocks that are skipped during the attention computation # (e.g., tokens outside the sliding window). # We can do this even if we cannot schedule this request due to # insufficient free blocks. 
# Should call this function before allocating new blocks to reduce # the number of evicted blocks. - removed_blocks = [ - manager.remove_skipped_blocks(req_blocks[i], + for i, manager in enumerate(self.specialized_managers): + manager.remove_skipped_blocks(request.request_id, request.num_computed_tokens) - for i, manager in enumerate(self.specialized_managers) - ] - self._free_blocks(removed_blocks) # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits - num_computed_tokens = (request.num_computed_tokens + num_new_computed_tokens) - - num_new_blocks: list[int] = [] - for i in range(self.num_kv_cache_groups): - num_required_blocks_i = cdiv( - num_computed_tokens + num_tokens + num_lookahead_tokens, - self.specialized_managers[i].block_size) - num_new_blocks.append(num_required_blocks_i - len(req_blocks[i]) - - len(new_computed_block_list[i])) - total_num_new_blocks = sum(max(x, 0) for x in num_new_blocks) - - # If a computed block of a request is an eviction candidate (in the - # free queue and ref_cnt == 0), it cannot be counted as a free block - # when allocating this request. - num_evictable_computed_blocks = sum( - 1 for blk_one_layer in new_computed_block_list - for blk in blk_one_layer if blk.ref_cnt == 0) - if (total_num_new_blocks > self.block_pool.get_num_free_blocks() - - num_evictable_computed_blocks): - # Cannot allocate new blocks + num_tokens_need_slot = (num_computed_tokens + num_tokens + + num_lookahead_tokens) + + num_needed_blocks: list[int] = [ + manager.get_num_needed_blocks(request.request_id, + num_tokens_need_slot, + new_computed_blocks[i]) + for manager in self.specialized_managers + ] + if (sum(num_needed_blocks) > self.block_pool.get_num_free_blocks()): return None # Touch the computed blocks to make sure they won't be evicted. if self.enable_caching: - for blocks in new_computed_block_list: - self.block_pool.touch(blocks) + self.block_pool.touch(new_computed_blocks) else: - assert all(len(blks) == 0 for blks in new_computed_block_list), ( + assert all(len(blks) == 0 for blks in new_computed_blocks), ( "Computed blocks should be empty when " "prefix caching is disabled") # Append the new computed blocks to the request blocks until now to # avoid the case where the new blocks cannot be allocated. - for i in range(self.num_kv_cache_groups): - req_blocks[i].extend(new_computed_block_list[i]) + for i in range(self.num_specialized_managers): + self.specialized_managers[i].req_to_blocks[ + request.request_id].extend(new_computed_blocks[i]) - new_blocks: list[list[KVCacheBlock]] = [] + new_blocks: list[list[GroupedKVCacheBlock]] = [] # Start to handle new blocks - for i in range(self.num_kv_cache_groups): - if num_new_blocks[i] <= 0: - # No new block is needed. - new_blocks.append([]) - else: - # Get new blocks from the free block pool. - num_new_blocks_i = min( - num_new_blocks[i], - # Should not exceed the maximum number of blocks per - # request. - # This is especially because the block table has the shape - # [..., max_num_blocks_per_req]. - self.max_num_blocks_per_req[i] - len(req_blocks[i]), - ) - assert num_new_blocks_i > 0 - - # Concatenate the computed block IDs and the new block IDs. 
- new_blocks_this_layer = self.block_pool.get_new_blocks( - num_new_blocks_i) - new_blocks.append(new_blocks_this_layer) - req_blocks[i].extend(new_blocks_this_layer) + for i in range(self.num_specialized_managers): + new_blocks_i = self.specialized_managers[i].allocate_new_blocks( + request.request_id, num_tokens_need_slot) + new_blocks.append(new_blocks_i) if not self.enable_caching: return KVCacheBlocks(new_blocks) - # Use `new_computed_block_list` for a new request, and - # `num_cached_block` for a running request. - num_cached_blocks = self.num_cached_block.get( - request.request_id, - [len(blocks) for blocks in new_computed_block_list]) - # Speculated tokens might be rejected in the future, so we does - # not cache any speculated tokens. We only cache blocks with - # generated (accepted) tokens. - for i in range(self.num_kv_cache_groups): - block_size = self.specialized_managers[i].block_size - num_full_blocks_after_append = ( - num_computed_tokens + num_tokens - - len(request.spec_token_ids)) // block_size - - self.block_pool.cache_full_blocks( - request=request, - blocks=req_blocks[i], - block_hashes=self.req_to_block_hashes[request.request_id] - [block_size], - num_cached_blocks=num_cached_blocks[i], - num_full_blocks=num_full_blocks_after_append, - block_size=block_size, - hash_fn=self.caching_hash_fn, - kv_cache_group_id=i, - ) - num_cached_blocks[i] = num_full_blocks_after_append - - self.num_cached_block[request.request_id] = num_cached_blocks + for i, manager in enumerate(self.specialized_managers): + manager.cache_blocks( + request, new_computed_blocks[i], self.req_to_block_hashes[ + request.request_id][manager.block_size], + num_computed_tokens, num_tokens) + return KVCacheBlocks(new_blocks) def free(self, request: Request) -> None: @@ -360,14 +272,8 @@ def free(self, request: Request) -> None: Args: request: The request to free the blocks. """ - # Default to [] in case a request is freed (aborted) before alloc. - blocks = self.req_to_blocks.pop(request.request_id, None) - if blocks is not None: - # Reverse the blocks so that the tail blocks can have higher - # eviction priority. - self._free_blocks([list(reversed(blks)) for blks in blocks]) - - self.num_cached_block.pop(request.request_id, None) + for manager in self.specialized_managers: + manager.free(request.request_id) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF @@ -424,18 +330,20 @@ def get_num_common_prefix_blocks( list[int]: The number of common prefix blocks for each kv cache group. """ - assert request.status == RequestStatus.RUNNING - blocks = self.req_to_blocks[request.request_id] - num_common_blocks = [] - for i in range(self.num_kv_cache_groups): - num_common_blocks_i = 0 - for block in blocks[i]: - if block.ref_cnt == num_running_requests: - num_common_blocks_i += 1 - else: - break - num_common_blocks.append(num_common_blocks_i) - return num_common_blocks + # TODO: implement this + return [0] * self.num_kv_cache_groups + # assert request.status == RequestStatus.RUNNING + # blocks = self.req_to_blocks[request.request_id] + # num_common_blocks = [] + # for i in range(self.num_kv_cache_groups): + # num_common_blocks_i = 0 + # for block in blocks[i]: + # if block.ref_cnt == num_running_requests: + # num_common_blocks_i += 1 + # else: + # break + # num_common_blocks.append(num_common_blocks_i) + # return num_common_blocks def free_block_hashes(self, request: Request) -> None: """Discard the block hashes for the request. 
@@ -446,63 +354,96 @@ def free_block_hashes(self, request: Request) -> None: self.req_to_block_hashes.pop(request.request_id, None) def find_longest_cache_hit( - self, request_id: str, block_hashes_dict: dict[int, - list[BlockHashType]] - ) -> tuple[list[list[KVCacheBlock]], int]: + self, request: Request, block_hashes_dict: dict[int, + list[BlockHashType]] + ) -> tuple[list[list[GroupedKVCacheBlock]], int]: """Find the longest cache hit for each kv cache group. TODO: add more notes """ - if self.num_kv_cache_groups == 1: - block_size = self.kv_cache_config.kv_cache_groups[ - 0].kv_cache_spec.block_size - hit_blocks = self.specialized_managers[0].find_longest_cache_hit( - request_id, block_hashes_dict[block_size]) - return [hit_blocks], len(hit_blocks) * block_size - # TODO: accelerate by make full attention the first layer - # TODO: add note for the two magic number - num_computed_tokens = [self.max_model_len + 100] * len( - self.specialized_managers) - min_computed_tokens = self.max_model_len - - # Use copy to avoid modifying the original block_hashes - block_hashes = [ - block_hashes_dict[g.kv_cache_spec.block_size].copy() - for g in self.kv_cache_config.kv_cache_groups - ] - - def shrink_length(block_hashes, length): - del block_hashes[length:] - - while max(num_computed_tokens) != min_computed_tokens: - for i, manager in enumerate(self.specialized_managers): - if num_computed_tokens[i] > min_computed_tokens: - shrink_length(block_hashes[i], - min_computed_tokens // manager.block_size) - computed_blocks_group_i = ( - manager.find_longest_cache_hit_multiple_calls( - request_id, block_hashes[i])) - - num_computed_tokens[i] = len(computed_blocks_group_i) * \ - manager.block_size - min_computed_tokens = min(min_computed_tokens, - num_computed_tokens[i]) - shrink_length(block_hashes[i], - num_computed_tokens[i] // manager.block_size) - - # Get the non-constlist computed blocks - computed_blocks = [ - manager.find_longest_cache_hit(request_id, block_hashes[i]) - for i, manager in enumerate(self.specialized_managers) - ] - assert all( - len(block) * manager.block_size == min_computed_tokens for block, - manager in zip(computed_blocks, self.specialized_managers)) + # When prompt length is divisible by the block size and all + # blocks are cached, we need to recompute the last token. This + # have to be achieved by re-computing an entire block because + # allocate_slots() assumes num_computed_tokens is always a + # multiple of the block size. To achieve this, remove the last + # block hash from the block_hashes for find_longest_cache_hit + # This limitation can potentially be removed in the future to + # slightly improve the performance. 
+ with remove_last_block_hash_for_divisible_prompt_length( + block_hashes_dict, request.num_tokens): + if self.num_specialized_managers == 1: + block_size = self.kv_cache_config.kv_cache_groups[ + 0].kv_cache_spec.block_size + hit_blocks = self.specialized_managers[ + 0].find_longest_cache_hit(block_hashes_dict[block_size]) + + return [hit_blocks], len(hit_blocks) * block_size + + # TODO: add note for the two magic number + num_computed_tokens = [self.max_model_len + 100] * len( + self.specialized_managers) + min_computed_tokens = self.max_model_len + + # Use copy to avoid modifying the original block_hashes + block_hashes = [ + block_hashes_dict[g.kv_cache_spec.block_size].copy() + for g in self.kv_cache_config.kv_cache_groups + ] - return computed_blocks, min_computed_tokens + computed_blocks: list[Optional[list[GroupedKVCacheBlock]]] = [ + None for _ in range(self.num_specialized_managers) + ] - def _free_blocks(self, blocks: list[list[KVCacheBlock]]) -> None: - ordered_blocks = [] - for blocks_one_layer in blocks: - ordered_blocks.extend(blocks_one_layer) - self.block_pool.free_blocks(ordered_blocks) + def shrink_length(block_hashes, length): + del block_hashes[length:] + + while max(num_computed_tokens) != min_computed_tokens: + for i, manager in enumerate(self.specialized_managers): + if num_computed_tokens[i] > min_computed_tokens: + shrink_length( + block_hashes[i], + min_computed_tokens // manager.block_size) + computed_blocks_i = (manager.find_longest_cache_hit( + block_hashes[i], computed_blocks[i])) + + num_computed_tokens[i] = len(computed_blocks_i) * \ + manager.block_size + min_computed_tokens = min(min_computed_tokens, + num_computed_tokens[i]) + computed_blocks[i] = computed_blocks_i + shrink_length( + block_hashes[i], + num_computed_tokens[i] // manager.block_size) + + assert all(block is not None and len(block) * + manager.block_size == min_computed_tokens + for block, manager in zip(computed_blocks, + self.specialized_managers)) + return computed_blocks, min_computed_tokens + + def generate_group_manager_map(self) -> list[list[int]]: + type_ids = [ + g.kv_cache_spec.type_id + for g in self.kv_cache_config.kv_cache_groups + ] + assert sorted(type_ids) == type_ids, "type_ids must be sorted" + manager_to_group: list[list[int]] = [] + for i, type_id in enumerate(type_ids): + if type_id == 0: + manager_to_group.append([i]) + else: + if type_id == type_ids[i - 1]: + manager_to_group[-1].append(i) + else: + manager_to_group.append([i]) + return manager_to_group + + def to_block_ids(self, kv_cache_blocks: KVCacheBlocks) -> list[list[int]]: + block_ids: list[list[int]] = [[] + for _ in range(self.num_kv_cache_groups)] + for blocks_one_manager, group_ids in zip(kv_cache_blocks.blocks, + self.manager_to_group): + for blocks in blocks_one_manager: + for blk, group_id in zip(blocks.blocks, group_ids): + block_ids[group_id].append(blk.block_id) + return block_ids diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 7f2bd0f6de75..6d4e2e4b7251 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """KV-Cache Utilities.""" +from contextlib import contextmanager import math import os from collections import defaultdict, deque @@ -126,7 +127,7 @@ class KVCacheBlock: prev_free_block: Optional["KVCacheBlock"] = None next_free_block: Optional["KVCacheBlock"] = None - kv_cache_group_id: int = -1 + manager_id: int = -1 def incr_ref(self): self.ref_cnt += 1 @@ -147,7 +148,7 @@ def 
block_hash(self, block_hash: BlockHashType): def reset_hash(self): """Reset the block hash when the block is evicted.""" self._block_hash = None - self.kv_cache_group_id = -1 + self.manager_id = -1 def __repr__(self) -> str: # Use block_id instead of KVCacheBlock object to avoid calling __repr__ @@ -829,3 +830,33 @@ def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]): kv_cache_config.num_blocks = min_num_blocks return kv_cache_configs + + +@contextmanager +def remove_last_block_hash_for_divisible_prompt_length( + block_hashes: dict[int, list[BlockHashType]], num_tokens: int): + """ + Remove the last block hash for the case where the prompt length is divisible + by the block size and all blocks are cached. + """ + last_block_hashs: dict[int, BlockHashType] = {} + for block_size in block_hashes: + if len(block_hashes[block_size]) * block_size == num_tokens: + last_block_hashs[block_size] = block_hashes[block_size].pop() + yield + for block_size, block_hash in last_block_hashs.items(): + block_hashes[block_size].append(block_hash) + + +# KVCacheBlocks for the same set of token of groups managed by the same manager +@dataclass +class GroupedKVCacheBlock: + blocks: tuple[KVCacheBlock, ...] + block_hash: Optional[BlockHashType] = None + block_id: int = -1 + + @staticmethod + def from_kv_cache_blocks(blocks: tuple[KVCacheBlock, ...]): + return GroupedKVCacheBlock(blocks=blocks, + block_hash=blocks[0].block_hash, + block_id=blocks[0].block_id) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 54e53b5792d1..6db328614bd7 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -251,7 +251,7 @@ def schedule(self) -> SchedulerOutput: # cycle to fill in the bitmask, which could be a big no-op. structured_output_request_ids[request.request_id] = req_index req_to_new_block_ids[request.request_id] = ( - new_blocks.to_block_ids()) + self.kv_cache_manager.to_block_ids(new_blocks)) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens req_index += 1 @@ -321,6 +321,7 @@ def schedule(self) -> SchedulerOutput: # Get already-cached tokens. computed_blocks, num_computed_tokens = \ self.kv_cache_manager.get_computed_blocks(request) + print("num_computed_tokens", num_computed_tokens) # Get externally-cached tokens if using a KVConnector. 
num_external_tokens = ( @@ -397,7 +398,8 @@ def schedule(self) -> SchedulerOutput: if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) req_to_new_block_ids[request.request_id] = ( - computed_blocks + new_blocks).to_block_ids() + self.kv_cache_manager.to_block_ids(computed_blocks + + new_blocks)) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING diff --git a/vllm/v1/core/specialized_manager.py b/vllm/v1/core/specialized_manager.py index ababd79b7cbe..661f440b89eb 100644 --- a/vllm/v1/core/specialized_manager.py +++ b/vllm/v1/core/specialized_manager.py @@ -1,12 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod +from collections import defaultdict from typing import Optional from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock +from vllm.v1.core.kv_cache_utils import BlockHashType, GroupedKVCacheBlock, KVCacheBlock from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, SlidingWindowSpec) +from vllm.v1.request import Request from vllm.v1.utils import ConstantList @@ -20,63 +22,114 @@ def __init__( self, kv_cache_spec: KVCacheSpec, block_pool: BlockPool, - kv_cache_group_id: int, + kv_cache_manager_id: int, + num_kv_cache_groups: int, ) -> None: """ Initializes the SpecializedManager. Args: kv_cache_spec: The kv_cache_spec for this manager. block_pool: The block pool. - kv_cache_group_id: The id of the kv cache group. + kv_cache_manager_id: The id of the kv cache manager. """ self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec self.block_pool = block_pool - self.kv_cache_group_id = kv_cache_group_id - # for caching the intermediate states between multiple calls of - # find_longest_cache_hit - self.req_cached_blocks: dict[str, list[KVCacheBlock]] = {} + self.kv_cache_manager_id = kv_cache_manager_id + self.num_kv_cache_groups = num_kv_cache_groups + # Mapping from request ID to blocks to track the blocks allocated + # for each request, so that we can free the blocks when the request + # is finished. + self.req_to_blocks: defaultdict[ + str, list[GroupedKVCacheBlock]] = defaultdict(list) - def find_longest_cache_hit( - self, - request_id: str, - block_hashes: list[BlockHashType], - ) -> list[KVCacheBlock]: - """ - Find the longest cache hit prefix of the blocks. If no cache hit is - found, return an empty list. + # {req_id: The number of cached blocks for each kv cache group} + # This is used to track the number of cached blocks for each request. + # This is only used to track the RUNNING requests, we do not track the + # data for reempted ones. + self.num_cached_block: dict[str, int] = {} - Args: - request_id: The request id. - block_hashes: The block hashes of the request. - return_const_list: Whether to return a ConstantList. + def get_num_needed_blocks( + self, request_id: str, num_tokens: int, + new_computed_block_list: list[GroupedKVCacheBlock]) -> int: + num_required_blocks = cdiv(num_tokens, self.block_size) + num_new_blocks = max( + num_required_blocks - len(new_computed_block_list) - + len(self.req_to_blocks[request_id]), 0) + # If a computed block of a request is an eviction candidate (in the + # free queue and ref_cnt == 0), it cannot be counted as a free block + # when allocating this request. 
# TODO: update comment + num_evictable_computed_blocks = sum( + blks.blocks[0].ref_cnt == 0 for blks in new_computed_block_list) + return ((num_new_blocks + num_evictable_computed_blocks) * + self.num_kv_cache_groups) + + def allocate_new_blocks(self, request_id: str, + num_tokens: int) -> list[GroupedKVCacheBlock]: + """ + return [group_id][block_of_that_group] """ + # TODO: group? + num_required_blocks = cdiv(num_tokens, self.block_size) + num_new_blocks = max( + num_required_blocks - len(self.req_to_blocks[request_id]), 0) + if num_new_blocks <= 0: + return [] + else: + flat_new_blocks = self.block_pool.get_new_blocks( + num_new_blocks * self.num_kv_cache_groups) + new_blocks = [] + for i in range(num_new_blocks): + blocks = flat_new_blocks[i * self.num_kv_cache_groups:(i + 1) * + self.num_kv_cache_groups] + grouped_block = GroupedKVCacheBlock.from_kv_cache_blocks( + tuple(blocks)) + new_blocks.append(grouped_block) + self.req_to_blocks[request_id].extend(new_blocks) + return new_blocks - if req_cached_blocks := self.req_cached_blocks.pop(request_id, None): - assert len(req_cached_blocks) >= len(block_hashes) + def cache_blocks(self, request: Request, + new_computed_blocks: list[GroupedKVCacheBlock], + block_hashes: list[BlockHashType], + num_computed_tokens: int, num_tokens: int) -> None: + # Use `new_computed_blocks` for a new request, and + # `num_cached_block` for a running request. + num_cached_blocks = self.num_cached_block.get(request.request_id, + len(new_computed_blocks)) + # Speculated tokens might be rejected in the future, so we does + # not cache any speculated tokens. We only cache blocks with + # generated (accepted) tokens. + num_full_blocks_after_append = (num_computed_tokens + num_tokens - len( + request.spec_token_ids)) // self.block_size - req_cached_blocks = self._find_longest_cache_hit( - block_hashes, req_cached_blocks) + self.block_pool.cache_full_blocks( + request=request, + blocks=self.req_to_blocks[request.request_id], + block_hashes=block_hashes, + num_cached_blocks=num_cached_blocks, + num_full_blocks=num_full_blocks_after_append, + block_size=self.block_size, + manager_id=self.kv_cache_manager_id, + ) - return req_cached_blocks + self.num_cached_block[ + request.request_id] = num_full_blocks_after_append - def find_longest_cache_hit_multiple_calls( - self, - request_id: str, - block_hashes: list[BlockHashType], - ) -> ConstantList[KVCacheBlock]: - req_cached_blocks = self.find_longest_cache_hit( - request_id, block_hashes) - self.req_cached_blocks[request_id] = req_cached_blocks - return ConstantList(req_cached_blocks) + def free(self, request_id: str) -> None: + # Default to [] in case a request is freed (aborted) before alloc. + blocks = self.req_to_blocks.pop(request_id, None) + if blocks is not None: + self.block_pool.free_blocks(reversed(blocks)) + + self.num_cached_block.pop(request_id, None) @abstractmethod - def _find_longest_cache_hit( + def find_longest_cache_hit( self, block_hashes: list[BlockHashType], - computed_blocks: Optional[list[KVCacheBlock]], - ) -> list[KVCacheBlock]: + computed_blocks: Optional[list[GroupedKVCacheBlock]] = None, + ) -> list[GroupedKVCacheBlock]: """ # TODO: update comment for multiple calls Get the longest cache hit prefix of the blocks. 
If no cache hit is @@ -97,8 +150,8 @@ def _find_longest_cache_hit( raise NotImplementedError @abstractmethod - def remove_skipped_blocks(self, blocks: list[KVCacheBlock], - num_computed_tokens: int) -> list[KVCacheBlock]: + def remove_skipped_blocks(self, request_id: str, + num_computed_tokens: int) -> None: """ Remove the blocks that are no longer needed from `blocks`. The removed blocks should be replaced by null_block. Return the removed blocks in @@ -116,10 +169,11 @@ def remove_skipped_blocks(self, blocks: list[KVCacheBlock], class FullAttentionManager(SpecializedManager): - def _find_longest_cache_hit( - self, block_hashes: list[BlockHashType], - computed_blocks: Optional[list[KVCacheBlock]] - ) -> list[KVCacheBlock]: + def find_longest_cache_hit( + self, + block_hashes: list[BlockHashType], + computed_blocks: Optional[list[GroupedKVCacheBlock]] = None + ) -> list[GroupedKVCacheBlock]: if computed_blocks is None: computed_blocks = [] for block_hash in block_hashes: @@ -127,7 +181,7 @@ def _find_longest_cache_hit( # not in the cached_block_hash_to_id, the following block hashes # are not computed yet for sure. if cached_block := self.block_pool.get_cached_block( - block_hash, self.kv_cache_group_id): + block_hash, self.kv_cache_manager_id): computed_blocks.append(cached_block) else: break @@ -136,17 +190,18 @@ def _find_longest_cache_hit( del computed_blocks[len(block_hashes):] return computed_blocks - def remove_skipped_blocks(self, blocks: list[KVCacheBlock], - num_computed_tokens: int) -> list[KVCacheBlock]: + def remove_skipped_blocks(self, request_id: str, + num_computed_tokens: int) -> None: # No need to remove blocks for full attention. - return [] + pass class SlidingWindowManager(SpecializedManager): def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, - kv_cache_group_id: int): - super().__init__(kv_cache_spec, block_pool, kv_cache_group_id) + kv_cache_manager_id: int, num_kv_cache_groups: int): + super().__init__(kv_cache_spec, block_pool, kv_cache_manager_id, + num_kv_cache_groups) self.sliding_window = kv_cache_spec.sliding_window # The number of contiguous blocks needed for prefix cache hit. # -1 since the input token itself is also included in the window @@ -154,10 +209,11 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, (kv_cache_spec.sliding_window - 1), self.block_size) self._null_block = block_pool.null_block - def _find_longest_cache_hit( - self, block_hashes: list[BlockHashType], - computed_blocks: Optional[list[KVCacheBlock]] - ) -> list[KVCacheBlock]: + def find_longest_cache_hit( + self, + block_hashes: list[BlockHashType], + computed_blocks: Optional[list[GroupedKVCacheBlock]] = None + ) -> list[GroupedKVCacheBlock]: # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to # optimize the time complexity from O(len(block_hashes)) to # O(len(block_hashes) / sliding_window_contiguous_blocks + @@ -165,7 +221,11 @@ def _find_longest_cache_hit( # which is good for low cache hit rate scenarios. if computed_blocks is None: num_contiguous_blocks = 0 - computed_blocks = [self._null_block] * len(block_hashes) + computed_blocks = [ + GroupedKVCacheBlock.from_kv_cache_blocks( + tuple([self._null_block] * self.num_kv_cache_groups)) + for _ in range(len(block_hashes)) + ] else: if len(computed_blocks) == len(block_hashes): return computed_blocks @@ -179,7 +239,7 @@ def _find_longest_cache_hit( # Search from right to left and early stop when a match is found. 
for i in range(len(block_hashes) - num_contiguous_blocks - 1, -1, -1): if cached_block := self.block_pool.get_cached_block( - block_hashes[i], self.kv_cache_group_id): + block_hashes[i], self.kv_cache_manager_id): computed_blocks[i] = cached_block num_contiguous_blocks += 1 if (num_contiguous_blocks @@ -196,23 +256,25 @@ def _find_longest_cache_hit( del computed_blocks[num_contiguous_blocks:] return computed_blocks - def remove_skipped_blocks(self, blocks: list[KVCacheBlock], - num_computed_tokens: int) -> list[KVCacheBlock]: + def remove_skipped_blocks(self, request_id: str, + num_computed_tokens: int) -> None: # Remove the blocks that are no longer be in the sliding window and # skipped during the attention computation. last_useful_token = num_computed_tokens - self.sliding_window + 1 last_useful_block = last_useful_token // self.block_size + blocks = self.req_to_blocks[request_id] - removed_blocks: list[KVCacheBlock] = [] + removed_blocks: list[GroupedKVCacheBlock] = [] for i in range(last_useful_block - 1, -1, -1): - if blocks[i] == self._null_block: + if blocks[i].blocks[0] == self._null_block: # If the block is already a null block, the blocks before it # should also have been set to null blocks by the previous calls # to this function. break removed_blocks.append(blocks[i]) - blocks[i] = self._null_block - return removed_blocks + blocks[i] = GroupedKVCacheBlock.from_kv_cache_blocks( + tuple([self._null_block] * self.num_kv_cache_groups)) + self.block_pool.free_blocks(removed_blocks) spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = { @@ -222,7 +284,9 @@ def remove_skipped_blocks(self, blocks: list[KVCacheBlock], def get_specialized_manager(kv_cache_spec: KVCacheSpec, block_pool: BlockPool, - kv_cache_group_id: int) -> SpecializedManager: + kv_cache_manager_id: int, + num_kv_cache_groups: int) -> SpecializedManager: manager_class = spec_manager_map[type(kv_cache_spec)] - manager = manager_class(kv_cache_spec, block_pool, kv_cache_group_id) + manager = manager_class(kv_cache_spec, block_pool, kv_cache_manager_id, + num_kv_cache_groups) return manager From 2c81fe6e75fabc98f3670bc9697164ab731e9fee Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 24 Apr 2025 10:11:58 -0700 Subject: [PATCH 12/34] small update Signed-off-by: Chen Zhang --- vllm/v1/core/block_pool.py | 17 +++++++---------- vllm/v1/core/kv_cache_manager.py | 4 ++-- vllm/v1/core/kv_cache_utils.py | 8 +++++++- vllm/v1/core/specialized_manager.py | 5 ----- 4 files changed, 16 insertions(+), 18 deletions(-) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 88897b84ad86..796ff98554c0 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -41,7 +41,7 @@ def __init__(self, num_gpu_blocks: int, enable_caching: bool, # enabled). self.free_block_queue = FreeKVCacheBlockQueue(self.blocks) - # {manager_id: {block_hash: {block ID: list[block]}}}. A cached block is + # {manager_id: {block_hash: {block ID: GroupedKVCacheBlock}}}. A cached block is # a full block with a block hash that can be used for prefix caching. # The cached block may be used by running requests or in the # free_block_queue that could potentially be evicted. 
@@ -219,16 +219,13 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: manager_id = block.manager_id if block_hash and block_hash in self.cached_block_hash_to_block[ manager_id]: - for b in self.cached_block_hash_to_block[manager_id][block_hash][ - block.block_id].blocks: - b.reset_hash() - del self.cached_block_hash_to_block[manager_id][block_hash][ - block.block_id] - - if len(self.cached_block_hash_to_block[manager_id] - [block_hash]) == 0: + cached_blocks = ( + self.cached_block_hash_to_block[manager_id][block_hash]) + assert block.block_id in cached_blocks + cached_blocks[block.block_id].reset_hash() + del cached_blocks[block.block_id] + if len(cached_blocks) == 0: del self.cached_block_hash_to_block[manager_id][block_hash] - return True return False diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 5962d953c04c..9cb47c9e60b1 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -44,8 +44,6 @@ def __init__( caching_hash_algo: str = "builtin", log_stats: bool = False, ) -> None: - # TODO: adjust the name for item in one group, list of items in all - # groups, and reduced item for all groups. self.kv_cache_config = kv_cache_config self.num_gpu_blocks = kv_cache_config.num_blocks self.max_model_len = max_model_len @@ -61,6 +59,8 @@ def __init__( self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) + # the kv cache groups managed by the each manager + # manager_id -> list[kv_cache_group_id] self.manager_to_group = self.generate_group_manager_map() self.num_specialized_managers = len(self.manager_to_group) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 6d4e2e4b7251..dfb4f3cbf96e 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -848,7 +848,8 @@ def remove_last_block_hash_for_divisible_prompt_length( block_hashes[block_size].append(block_hash) -# KVCacheBlocks for the same set of token of groups managed by the same manager +# KVCacheBlocks for the same block of all kv cache groups with the same kv cache +# spec (and belongs to the same manager) @dataclass class GroupedKVCacheBlock: blocks: tuple[KVCacheBlock, ...] @@ -860,3 +861,8 @@ def from_kv_cache_blocks(blocks: tuple[KVCacheBlock, ...]): return GroupedKVCacheBlock(blocks=blocks, block_hash=blocks[0].block_hash, block_id=blocks[0].block_id) + + def reset_hash(self): + for block in self.blocks: + block.reset_hash() + self.block_hash = None diff --git a/vllm/v1/core/specialized_manager.py b/vllm/v1/core/specialized_manager.py index 661f440b89eb..e5e180b3a77f 100644 --- a/vllm/v1/core/specialized_manager.py +++ b/vllm/v1/core/specialized_manager.py @@ -67,10 +67,6 @@ def get_num_needed_blocks( def allocate_new_blocks(self, request_id: str, num_tokens: int) -> list[GroupedKVCacheBlock]: - """ - return [group_id][block_of_that_group] - """ - # TODO: group? num_required_blocks = cdiv(num_tokens, self.block_size) num_new_blocks = max( num_required_blocks - len(self.req_to_blocks[request_id]), 0) @@ -131,7 +127,6 @@ def find_longest_cache_hit( computed_blocks: Optional[list[GroupedKVCacheBlock]] = None, ) -> list[GroupedKVCacheBlock]: """ - # TODO: update comment for multiple calls Get the longest cache hit prefix of the blocks. If no cache hit is found, return an empty list. # TODO: add notes for computed_blocks will not be longer than block_hashes. 
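For hybrid models, the per-manager find_longest_cache_hit results have to agree on a single prefix length; the loop in KVCacheManager.find_longest_cache_hit above does this by repeatedly truncating every manager's candidate prefix to the current minimum until all managers report the same number of computed tokens. The sketch below captures only that convergence idea: per-manager hit lengths are passed in directly instead of doing real block-hash lookups, so it illustrates the control flow rather than the actual implementation.

def agree_on_prefix(hit_tokens: list[int], block_sizes: list[int]) -> int:
    # hit_tokens[i]: how many prefix tokens manager i can serve from cache when
    # queried with the full prompt; block_sizes[i]: that manager's block size.
    current = list(hit_tokens)
    while max(current) != min(current):
        floor = min(current)
        for i, block_size in enumerate(block_sizes):
            if current[i] > floor:
                # Stand-in for re-querying manager i with the prefix capped at
                # `floor` tokens; a real manager may return even fewer tokens
                # (e.g. a sliding-window manager needs contiguous blocks).
                current[i] = (floor // block_size) * block_size
    return min(current)

# Full attention hits 48 prefix tokens, sliding window only 32 (block size 16):
# both are truncated until they agree on 32 computed tokens.
print(agree_on_prefix([48, 32], [16, 16]))  # -> 32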
From 42a82444767fd02f2d09bfc4a9595e56139d3d87 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 24 Apr 2025 10:19:25 -0700 Subject: [PATCH 13/34] small fix Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_manager.py | 3 ++- vllm/v1/engine/core.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 9cb47c9e60b1..a7842a556a6c 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -429,13 +429,14 @@ def generate_group_manager_map(self) -> list[list[int]]: assert sorted(type_ids) == type_ids, "type_ids must be sorted" manager_to_group: list[list[int]] = [] for i, type_id in enumerate(type_ids): - if type_id == 0: + if i == 0: manager_to_group.append([i]) else: if type_id == type_ids[i - 1]: manager_to_group[-1].append(i) else: manager_to_group.append([i]) + print("manager_to_group", manager_to_group) return manager_to_group def to_block_ids(self, kv_cache_blocks: KVCacheBlocks) -> list[list[int]]: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 572e052cdcc2..c2b0dad193e6 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -150,6 +150,7 @@ def _initialize_kv_caches( num_gpu_blocks = kv_cache_configs[0].num_blocks num_cpu_blocks = 0 scheduler_kv_cache_config = kv_cache_configs[0] + print("kv_cache_config", scheduler_kv_cache_config) # Initialize kv cache and warmup the execution self.model_executor.initialize_from_config(kv_cache_configs) From 6493e5e0a3aa780ee19fa25729754a2bb7d3e308 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Apr 2025 03:15:01 -0700 Subject: [PATCH 14/34] fix gemma Signed-off-by: Chen Zhang --- vllm/model_executor/models/gemma2.py | 4 ++-- vllm/model_executor/models/gemma3.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 35e698fca410..7fb2e9948c06 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -145,8 +145,8 @@ def __init__(self, # reference: # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa layer_idx = extract_layer_index(prefix) - use_sliding_window = (layer_idx % 2 == 0 and - config.interleaved_sliding_window is not None) + use_sliding_window = (layer_idx % 2 == 0 and getattr( + config, "interleaved_sliding_window", None) is not None) sliding_window = config.interleaved_sliding_window if \ use_sliding_window else None self.attn = Attention(self.num_heads, diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index 34957300ff9a..4e0d4f84ca6b 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -146,7 +146,9 @@ def __init__(self, # TODO(woosuk): Add reference to the original HF implementation. layer_idx = extract_layer_index(prefix) - self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern) + self.is_sliding = (getattr( + config, "interleaved_sliding_window", None) is not None and bool( + (layer_idx + 1) % config.sliding_window_pattern)) # Initialize the rotary embedding. if self.is_sliding: # Local attention. Override the values in config.json. 
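The gemma2/gemma3 hunks above make the sliding-window decision tolerant of configs that do not define interleaved_sliding_window at all, falling back to full attention in that case. Below is a hedged sketch of that per-layer decision, following the gemma-2 style every-other-layer pattern from the diff; the SimpleNamespace configs and the example window size are illustrative assumptions, not part of the patch.

from types import SimpleNamespace

def uses_sliding_window(config, layer_idx: int) -> bool:
    # Fall back to full attention when the config defines no interleaved
    # sliding window (the condition the fix above adds via getattr).
    if getattr(config, "interleaved_sliding_window", None) is None:
        return False
    # gemma-2 style interleaving: even-indexed layers use the sliding window.
    return layer_idx % 2 == 0

full_only = SimpleNamespace()                                  # no sliding-window field
hybrid = SimpleNamespace(interleaved_sliding_window=512)       # assumed window size

print([uses_sliding_window(full_only, i) for i in range(4)])   # [False, False, False, False]
print([uses_sliding_window(hybrid, i) for i in range(4)])      # [True, False, True, False]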
From c512bc5b1c21c3cbbd57d676b824c0c8e81b640f Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Apr 2025 03:53:40 -0700 Subject: [PATCH 15/34] update attn backends Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/flash_attn.py | 17 +++++----- vllm/v1/attention/backends/flashinfer.py | 37 ++++++++++++---------- vllm/v1/attention/backends/mla/common.py | 30 +++++++++--------- vllm/v1/attention/backends/mla/flashmla.py | 8 ++--- vllm/v1/worker/gpu_model_runner.py | 3 +- 5 files changed, 49 insertions(+), 46 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index d836c6438f8a..d53b3b3eac4b 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -14,7 +14,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv -from vllm.v1.kv_cache_interface import KVCacheSpec +from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8, get_flash_attn_version) @@ -280,8 +280,8 @@ def make_local_attention_virtual_batches( class FlashAttentionMetadataBuilder: - def __init__(self, runner: "GPUModelRunner", kv_cache_spec: KVCacheSpec, - persistent_block_table: BlockTable): + def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, + block_table: BlockTable): model_config = runner.model_config self.runner = runner @@ -293,7 +293,7 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: KVCacheSpec, self.headdim = model_config.get_head_size() self.page_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec - self.persistent_block_table = persistent_block_table + self.block_table = block_table def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -307,7 +307,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, non_blocking=True) seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] seq_lens = seq_lens_cpu.to(self.runner.device, non_blocking=True) - block_table = self.persistent_block_table + block_table = self.block_table block_table_tensor = block_table.get_device_tensor()[:num_reqs] slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].to( self.runner.device, non_blocking=True).long() @@ -333,7 +333,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, local_attn_metadata = None if self.runner.attention_chunk_size is not None: seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \ - virt_block_table = make_local_attention_virtual_batches( + virt_block_table_tensor = make_local_attention_virtual_batches( self.runner.attention_chunk_size, self.runner.query_start_loc_np[:num_reqs + 1], self.runner.seq_lens_np[:num_reqs], @@ -357,14 +357,13 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata( local_query_start_loc=local_query_start_loc, local_seqused_k=local_seqused_k, - local_block_table=virt_block_table, + local_block_table=virt_block_table_tensor, local_max_query_len=local_max_query_len, local_max_seq_len=local_max_seq_len, local_scheduler_metadata=local_scheduler_metadata, ) - # use_cascade = common_prefix_len > 0 - use_cascade = False + use_cascade = common_prefix_len > 0 if use_cascade: cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], diff --git a/vllm/v1/attention/backends/flashinfer.py 
b/vllm/v1/attention/backends/flashinfer.py index 9e2440c5d477..a11c846d302e 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -15,9 +15,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) from vllm.attention.layer import Attention -from vllm.config import VllmConfig, get_current_vllm_config +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import use_cascade_attention +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -206,7 +208,8 @@ def __post_init__(self): class FlashInferMetadataBuilder: - def __init__(self, runner: GPUModelRunner): + def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec, + block_table: BlockTable): self.runner = runner self._workspace_buffer = None self._prefill_wrapper = None # Wrapper for prefill/append @@ -216,7 +219,9 @@ def __init__(self, runner: GPUModelRunner): # Global hyperparameters shared by all attention layers self.global_hyperparameters: Optional[PerLayerParameters] = None - self.vllm_config = get_current_vllm_config() + self.vllm_config = runner.vllm_config + self.kv_cache_spec = kv_cache_spec + self.block_table = block_table def reorder_batch(self, input_batch: InputBatch, scheduler_output: SchedulerOutput) -> bool: @@ -403,15 +408,14 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, assert self._num_decodes + self._num_prefills == num_reqs assert (self._num_decode_tokens + self._num_prefill_tokens == num_actual_tokens) - page_size = self.runner.block_size + page_size = self.kv_cache_spec.block_size device = self.runner.device qo_indptr = self.runner.query_start_loc_cpu[:num_reqs + 1].to( self.runner.device, non_blocking=True) seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device, non_blocking=True) - block_table = ( - self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) - slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( + block_table_tensor = (self.block_table.get_device_tensor()[:num_reqs]) + slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to( self.runner.device, non_blocking=True).long() block_table_bounds = (seq_lens + page_size - 1) // page_size @@ -427,12 +431,13 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks], dtype=torch.int32, device=device) - shared_kv_page_indices = block_table[0, :num_common_kv_blocks] + shared_kv_page_indices = block_table_tensor[ + 0, :num_common_kv_blocks] shared_kv_last_page_len = torch.tensor([page_size], dtype=torch.int32, device=device) # Remove the blocks of the shared prefix from all requests. 
- block_table = block_table[:, num_common_kv_blocks:] + block_table_tensor = block_table_tensor[:, num_common_kv_blocks:] block_table_bounds -= num_common_kv_blocks else: shared_qo_indptr = None @@ -440,11 +445,11 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, shared_kv_page_indices = None shared_kv_last_page_len = None - mask = (torch.arange(block_table.size(1), - dtype=block_table.dtype, - device=block_table.device).unsqueeze(0) + mask = (torch.arange(block_table_tensor.size(1), + dtype=block_table_tensor.dtype, + device=block_table_tensor.device).unsqueeze(0) < block_table_bounds.unsqueeze(1)) - paged_kv_indices = block_table[mask] + paged_kv_indices = block_table_tensor[mask] paged_kv_indptr = torch.cat([ torch.zeros(1, @@ -463,9 +468,9 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, paged_kv_indptr=paged_kv_indptr, paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len, - num_qo_heads=self.runner.num_query_heads, - num_kv_heads=self.runner.num_kv_heads, - head_dim=self.runner.head_size, + num_qo_heads=self.kv_cache_spec.num_query_heads, + num_kv_heads=self.kv_cache_spec.num_kv_heads, + head_dim=self.kv_cache_spec.head_size, page_size=page_size, data_type=self.runner.kv_cache_dtype, q_data_type=self.runner.dtype, diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 71b2558faf0f..264c7aff222b 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -# type: ignore """ This file implements common components for MLA implementations. @@ -205,7 +204,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.platforms import current_platform from vllm.utils import cdiv, round_down -from vllm.v1.kv_cache_interface import KVCacheSpec +from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version @@ -344,8 +343,8 @@ class MLACommonMetadataBuilder(Generic[M]): def __init__(self, runner: "GPUModelRunner", - kv_cache_spec: KVCacheSpec, - persistent_block_table: BlockTable, + kv_cache_spec: AttentionSpec, + block_table: BlockTable, metadata_cls: Optional[type[M]] = None): self.metadata_cls = metadata_cls \ if metadata_cls is not None else MLACommonMetadata @@ -358,10 +357,11 @@ def __init__(self, runner.parallel_config) self.mla_dims = get_mla_dims(model_config) self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3) + self.kv_cache_spec = kv_cache_spec # Dont try to access the runner on AMD if self.aot_schedule: - self.page_size = self.runner.block_size + self.page_size = self.kv_cache_spec.block_size if self.chunked_prefill_enabled: self.chunked_prefill_workspace_size = min( @@ -388,7 +388,7 @@ def __init__(self, device=runner.device, ) self.page_size = kv_cache_spec.block_size - self.persistent_block_table = persistent_block_table + self.block_table = block_table def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -451,10 +451,11 @@ def reorder_batch(self, input_batch: "InputBatch", return modified_batch def _build_decode(self, input_positions: torch.Tensor, - block_table: torch.Tensor, seq_lens: torch.Tensor): + block_table_tensor: torch.Tensor, + seq_lens: torch.Tensor): return MLACommonDecodeMetadata( input_positions=input_positions, - block_table=block_table, + 
block_table=block_table_tensor, seq_lens=seq_lens, ) @@ -466,13 +467,12 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, # function. We should avoid GPU -> CPU sync as much as possible because # it blocks on all previous kernels. device = self.runner.device - block_table = (self.persistent_block_table.block_table. - get_device_tensor()[:num_reqs]) + block_table_tensor = (self.block_table.get_device_tensor()[:num_reqs]) query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( device, non_blocking=True) - slot_mapping = (self.persistent_block_table. - slot_mapping_cpu[:num_actual_tokens].to( - device, non_blocking=True).long()) + slot_mapping = ( + self.block_table.slot_mapping_cpu[:num_actual_tokens].to( + device, non_blocking=True).long()) input_positions = self.runner.positions_cpu[:num_actual_tokens].to( device, non_blocking=True).long() @@ -551,7 +551,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, prefill_metadata = MLACommonPrefillMetadata( input_positions=input_positions[tokens_start:], - block_table=block_table[reqs_start:, ...], + block_table=block_table_tensor[reqs_start:, ...], query_start_loc=prefill_query_start_loc, max_query_len=max_query_len, chunked_context=chunked_context_metadata, @@ -561,7 +561,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, if self._num_decodes > 0: decode_metadata = self._build_decode( input_positions=input_positions[:self._num_decode_tokens], - block_table=block_table[:self._num_decodes, ...], + block_table_tensor=block_table_tensor[:self._num_decodes, ...], seq_lens=seq_lens[:self._num_decodes], ) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 2260761abed8..347f30c9e3ec 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -55,14 +55,14 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): def __init__(self, runner, kv_cache_spec: KVCacheSpec, - persistent_block_table: BlockTable): - super().__init__(runner, kv_cache_spec, persistent_block_table) + block_table: BlockTable): + super().__init__(runner, kv_cache_spec, block_table) self.num_q_heads = self.runner.model_config.get_num_attention_heads( self.runner.parallel_config) def _build_decode(self, input_positions: torch.Tensor, - block_table: torch.Tensor, + block_table_tensor: torch.Tensor, seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: tile_scheduler_metadata, num_splits = \ get_mla_metadata( @@ -73,7 +73,7 @@ def _build_decode(self, input_positions: torch.Tensor, return FlashMLADecodeMetadata( input_positions=input_positions, - block_table=block_table, + block_table=block_table_tensor, seq_lens=seq_lens, tile_scheduler_metadata=tile_scheduler_metadata, num_splits=num_splits, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9926f6c86e77..c98febac3a26 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1754,8 +1754,7 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_spec.dtype, kv_cache_spec.block_size, self.model_config.is_attention_free, - use_mla=(isinstance(kv_cache_spec, AttentionSpec) - and kv_cache_spec.use_mla), + use_mla=kv_cache_spec.use_mla, ) if attn_backend_i is None: error_msg = ( From 4ce3424cc4eb8ef2b92b11a288d8c5ec0231181c Mon Sep 17 00:00:00 2001 From: Chen 
Zhang Date: Fri, 25 Apr 2025 03:57:03 -0700 Subject: [PATCH 16/34] fix flashinfer type Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/flashinfer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index a11c846d302e..7224ec96d1c4 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -# type: ignore """Attention layer with FlashInfer.""" from __future__ import annotations From d17843eacd8ce807cf0a7110fda4537ebdbe97cc Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Apr 2025 04:00:00 -0700 Subject: [PATCH 17/34] fix flashmla type Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/mla/flashmla.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 347f30c9e3ec..e072f74b4978 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -# type: ignore from dataclasses import dataclass from typing import Any, Optional @@ -16,7 +15,7 @@ MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) -from vllm.v1.kv_cache_interface import KVCacheSpec +from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable logger = init_logger(__name__) @@ -54,7 +53,7 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): - def __init__(self, runner, kv_cache_spec: KVCacheSpec, + def __init__(self, runner, kv_cache_spec: AttentionSpec, block_table: BlockTable): super().__init__(runner, kv_cache_spec, block_table) From 47ec1a7d256d1db93358b8af486acb80810c1525 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Apr 2025 04:00:54 -0700 Subject: [PATCH 18/34] fix triton type Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/mla/triton_mla.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 83d2116aa81d..3bae676bc674 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -# type: ignore from typing import Any, Optional import torch From 840675f606b81f956197be6316d88ecc64ecb946 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Apr 2025 04:27:19 -0700 Subject: [PATCH 19/34] clean up slidingwindowspec Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 33 ++++++++++++++++++++++-------- vllm/v1/kv_cache_interface.py | 7 ++----- vllm/v1/worker/gpu_model_runner.py | 9 ++++---- 3 files changed, 31 insertions(+), 18 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index dfb4f3cbf96e..877cbd4b3549 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 """KV-Cache Utilities.""" -from contextlib import contextmanager import math import os from collections import defaultdict, deque from collections.abc import Sequence +from contextlib import contextmanager from dataclasses import dataclass from typing import Any, Callable, NamedTuple, Optional @@ -557,6 +557,26 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, f"`max_model_len` 
when initializing the engine.") +def merge_layer_specs(layer_specs: list[KVCacheSpec]) -> KVCacheSpec: + """ + Merge a list of KVCacheSpec objects into a single KVCacheSpec object. + """ + assert all(layer_spec.type_id == layer_specs[0].type_id + for layer_spec in layer_specs[1:]), ( + "All layers in the same KV cache group must share the same " + "KVCacheSpec.") + layer_spec = layer_specs[0] + if isinstance(layer_spec, FullAttentionSpec): + for spec in layer_specs[1:]: + assert isinstance(spec, FullAttentionSpec) + if spec.sliding_window is not None: + if layer_spec.sliding_window is None: + layer_spec.sliding_window = spec.sliding_window + else: + assert layer_spec.sliding_window == spec.sliding_window + return layer_spec + + def create_kv_cache_group_specs( kv_cache_spec: dict[str, KVCacheSpec], grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]: @@ -577,12 +597,9 @@ def create_kv_cache_group_specs( """ kv_cache_groups = [] for layer_names_one_group in grouped_layer_names: - layer_spec = kv_cache_spec[layer_names_one_group[0]] - assert all( - kv_cache_spec[layer_name] == layer_spec - for layer_name in layer_names_one_group[1:]), ( - "All layers in the same KV cache group must share the same " - "KVCacheSpec.") + layer_spec = merge_layer_specs([ + kv_cache_spec[layer_name] for layer_name in layer_names_one_group + ]) kv_cache_groups.append( KVCacheGroupSpec(layer_names_one_group, layer_spec)) return kv_cache_groups @@ -760,7 +777,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): head_size=spec.head_size, dtype=spec.dtype, use_mla=spec.use_mla, - compute_as_sliding_window=True, + sliding_window=spec.sliding_window, ) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 610e1b85ec05..b44149f79896 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import Optional import torch @@ -9,9 +9,6 @@ from vllm.logger import init_logger from vllm.utils import cdiv, get_dtype_size -if TYPE_CHECKING: - pass - logger = init_logger(__name__) @@ -77,7 +74,7 @@ def page_size_bytes(self) -> int: @dataclass class FullAttentionSpec(AttentionSpec): # TODO: add note - compute_as_sliding_window: bool = False + sliding_window: Optional[int] = None @property def type_id(self) -> str: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c98febac3a26..555fe016586a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -695,9 +695,9 @@ def _compute_cascade_attn_prefix_len( # common_prefix_len should be a multiple of the block size. 
common_prefix_len = (common_prefix_len // kv_cache_spec.block_size * kv_cache_spec.block_size) - use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) - or (isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.compute_as_sliding_window)) + use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or + (isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.sliding_window is not None)) assert isinstance(kv_cache_spec, AttentionSpec) use_cascade = attn_metadata_builder.use_cascade_attention( common_prefix_len=common_prefix_len, @@ -1712,8 +1712,7 @@ def _setup_kv_cache_shapes( assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 num_blocks = (raw_tensor.numel() // kv_cache_spec.page_size_bytes) - if isinstance(kv_cache_spec, - (FullAttentionSpec, SlidingWindowSpec)): + if isinstance(kv_cache_spec, AttentionSpec): kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) From ffcbde8b25a68fc6a2db7d49218d7038c452ac8b Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Apr 2025 04:35:44 -0700 Subject: [PATCH 20/34] clean up block table Signed-off-by: Chen Zhang --- vllm/v1/worker/block_table.py | 29 ++++++++--------------------- vllm/v1/worker/gpu_input_batch.py | 6 +++--- 2 files changed, 11 insertions(+), 24 deletions(-) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 1e8e93ea3ee1..e8d3b823ab8d 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -102,18 +102,22 @@ def get_numpy_array(self) -> np.ndarray: P = ParamSpec("P") -class MultiLayerBlockTable: +class MultiGroupBlockTable: move_row: Callable[P, None] swap_row: Callable[P, None] commit: Callable[P, None] clear: Callable[P, None] - append_row: Callable[Concatenate[list[int], P], None] - add_row: Callable[Concatenate[list[int], P], None] + append_row: Callable[Concatenate[list[list[int]], P], None] + add_row: Callable[Concatenate[list[list[int]], P], None] - def __init__(self, max_num_reqs: int, max_num_blocks_per_req: list[int], + def __init__(self, max_num_reqs: int, max_model_len: int, max_num_tokens: int, pin_memory: bool, device: torch.device, kv_cache_config: KVCacheConfig) -> None: + max_num_blocks_per_req = [ + cdiv(max_model_len, g.kv_cache_spec.block_size) + for g in kv_cache_config.kv_cache_groups + ] self.block_tables = [ BlockTable(max_num_reqs, max_num_blocks_per_req[i], max_num_tokens, pin_memory, device) @@ -147,20 +151,3 @@ def broadcast_func(block_ids: list[int], *args: P.args, def __getitem__(self, idx: int) -> "BlockTable": return self.block_tables[idx] - - -def initialize_block_table( - max_num_reqs: int, - max_model_len: int, - max_num_tokens: int, - pin_memory: bool, - device: torch.device, - kv_cache_config: KVCacheConfig, -) -> MultiLayerBlockTable: - max_num_blocks_per_req = [ - cdiv(max_model_len, g.kv_cache_spec.block_size) - for g in kv_cache_config.kv_cache_groups - ] - return MultiLayerBlockTable(max_num_reqs, max_num_blocks_per_req, - max_num_tokens, pin_memory, device, - kv_cache_config) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 5afc6e7757f1..19ef82775dd3 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -15,7 +15,7 @@ from vllm.v1.outputs import LogprobsTensors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.utils import copy_slice -from vllm.v1.worker.block_table import initialize_block_table 
+from vllm.v1.worker.block_table import MultiGroupBlockTable _SAMPLING_EPS = 1e-5 @@ -99,7 +99,7 @@ def __init__( self.num_computed_tokens_cpu_tensor.numpy() # Block table. - self.block_table = initialize_block_table( + self.block_table = MultiGroupBlockTable( max_num_reqs=max_num_reqs, max_model_len=max_model_len, max_num_tokens=max_num_tokens, @@ -267,7 +267,7 @@ def add_request( self.num_tokens_no_spec[req_index] = request.num_tokens self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens - self.block_table.add_row(request.block_ids, req_index) # type: ignore + self.block_table.add_row(request.block_ids, req_index) sampling_params = request.sampling_params if sampling_params.sampling_type == SamplingType.GREEDY: From 4eebce7e7b726ea9ce96de4aa798f827ff9fe81e Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Apr 2025 06:40:09 -0700 Subject: [PATCH 21/34] clean up runner (WIP) Signed-off-by: Chen Zhang --- vllm/v1/worker/gpu_model_runner.py | 54 +++++++++++++++++++----------- 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 555fe016586a..50d03389b63e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -125,10 +125,9 @@ def __init__( # Sampler self.sampler = Sampler() - # Lazy initialization + # Lazy initializations # self.model: nn.Module # Set after load_model - # init in initialize_kv_cache - self.kv_caches: list[torch.Tensor] = [] + # Initialized in initialize_kv_cache self.kv_cache_config = cast(KVCacheConfig, None) self.attn_backends: list[type[AttentionBackend]] = [] self.attn_metadata_builders: list[type[AttentionMetadataBuilder]] = [] @@ -240,6 +239,32 @@ def __init__( pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool: + """ + Update the order of requests in the batch based on the attention + backend's needs. For example, some attention backends (namely MLA) may + want to separate requests based on if the attention computation will be + compute-bound or memory-bound. + + Args: + scheduler_output: The scheduler output. + + Returns: + True if the batch was reordered, False otherwise. + """ + batch_reordered = self.attn_metadata_builders[0].reorder_batch( + self.input_batch, scheduler_output) + + # For models with multiple KV cache groups, the groups should agree on + # the same order of requests. We ensure this by only allowing the first + # group to reorder the batch. + for kv_cache_group_id in range( + 1, len(self.kv_cache_config.kv_cache_groups)): + assert not self.attn_metadata_builders[ + kv_cache_group_id].reorder_batch(self.input_batch, + scheduler_output) + return batch_reordered + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler output. @@ -395,9 +420,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) - self.input_batch.block_table.append_row( - req_data.new_block_ids, # type: ignore - req_index) + self.input_batch.block_table.append_row(req_data.new_block_ids, + req_index) # Add new_token_ids to token_ids_cpu. 
start_token_index = num_computed_tokens end_token_index = num_computed_tokens + len(req_data.new_token_ids) @@ -437,17 +461,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: if removed_req_indices: self.input_batch.condense(removed_req_indices) - # Some attention backends (namely MLA) may want to separate requests - # based on if the attention computation will be compute-bound or - # memory-bound. This gives them a hook to do that. - batch_reordered = self.attn_metadata_builders[0].reorder_batch( - self.input_batch, scheduler_output) - - for kv_cache_group_id in range( - 1, len(self.kv_cache_config.kv_cache_groups)): - assert not self.attn_metadata_builders[ - kv_cache_group_id].reorder_batch(self.input_batch, - scheduler_output) + batch_reordered = self._may_reorder_batch(scheduler_output) if batch_changed or batch_reordered: self.input_batch.refresh_sampling_metadata() @@ -515,6 +529,7 @@ def _prepare_inputs( torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) + # Calculate the slot mapping for each KV cache group. for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): block_size = kv_cache_group_spec.kv_cache_spec.block_size @@ -565,6 +580,8 @@ def _prepare_inputs( non_blocking=True) attn_metadata: dict[str, FlashAttentionMetadata] = {} + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): @@ -1734,8 +1751,7 @@ def initialize_kv_cache_tensors(self, kv_cache_raw_tensors) bind_kv_cache( kv_caches, - self.vllm_config.compilation_config.static_forward_context, - self.kv_caches) + self.vllm_config.compilation_config.static_forward_context, []) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: # TODO: docstring From 4380fa6f3a5e19fdddb51ee415ac3a08fe86929e Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Apr 2025 07:13:57 -0700 Subject: [PATCH 22/34] add notes Signed-off-by: Chen Zhang --- vllm/v1/worker/gpu_model_runner.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 50d03389b63e..e7413d77e2fa 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -614,7 +614,7 @@ def _prepare_inputs( # from these partial requests, we do so for simplicity. # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. 
- # TODO: add note for attn_metadata_i + # TODO: confirm which attn_metadata should be used logits_indices = attn_metadata_i.query_start_loc[1:] - 1 spec_decode_metadata = None else: @@ -1239,6 +1239,7 @@ def execute_model( target_token_ids = self.input_ids[:num_scheduled_tokens] target_positions = positions[:num_scheduled_tokens] target_hidden_states = hidden_states[:num_scheduled_tokens] + # TODO: confirm which attn_metadata should be used target_slot_mapping = attn_metadata.slot_mapping cu_num_tokens = attn_metadata.query_start_loc else: @@ -1253,6 +1254,7 @@ def execute_model( dtype=torch.int32, device=self.device, ) + # TODO: confirm which attn_metadata should be used cu_num_tokens, token_indices = self.drafter.prepare_inputs( attn_metadata.query_start_loc, num_rejected_tokens, @@ -1260,6 +1262,7 @@ def execute_model( target_token_ids = self.input_ids[token_indices] target_positions = positions[token_indices] target_hidden_states = hidden_states[token_indices] + # TODO: confirm which attn_metadata should be used target_slot_mapping = attn_metadata.slot_mapping[token_indices] draft_token_ids = self.drafter.propose( @@ -1269,6 +1272,7 @@ def execute_model( target_slot_mapping=target_slot_mapping, next_token_ids=next_token_ids, cu_num_tokens=cu_num_tokens, + # TODO: confirm which attn_metadata should be used block_table=attn_metadata.block_table, sampling_metadata=sampling_metadata, ) From b567c5625d8a7ad25f97bca9f4327efefac94bd7 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Apr 2025 08:54:03 -0700 Subject: [PATCH 23/34] clean up attn_metadata read in runner Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/flash_attn.py | 11 +++++----- vllm/v1/attention/backends/flashinfer.py | 10 ++++----- vllm/v1/attention/backends/mla/common.py | 10 ++++----- vllm/v1/attention/backends/utils.py | 13 ++++++++++++ vllm/v1/spec_decode/eagle.py | 2 ++ vllm/v1/worker/gpu_model_runner.py | 27 ++++++++++++++++-------- 6 files changed, 48 insertions(+), 25 deletions(-) create mode 100644 vllm/v1/attention/backends/utils.py diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index d53b3b3eac4b..5e619bd797c6 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -14,6 +14,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv +from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8, @@ -300,13 +301,11 @@ def reorder_batch(self, input_batch: "InputBatch", return False def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int): + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata): max_seq_len = self.runner.seq_lens_np[:num_reqs].max() - query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1] - query_start_loc = query_start_loc_cpu.to(self.runner.device, - non_blocking=True) - seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] - seq_lens = seq_lens_cpu.to(self.runner.device, non_blocking=True) + seq_lens = common_attn_metadata.seq_lens + query_start_loc = common_attn_metadata.query_start_loc block_table = self.block_table block_table_tensor = block_table.get_device_tensor()[:num_reqs] slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].to( diff --git 
a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 7224ec96d1c4..ca11c8d07539 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -17,6 +17,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import use_cascade_attention +from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable @@ -403,16 +404,15 @@ def _plan(self, attn_metadata: FlashInferMetadata): ) def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int): + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata): assert self._num_decodes + self._num_prefills == num_reqs assert (self._num_decode_tokens + self._num_prefill_tokens == num_actual_tokens) page_size = self.kv_cache_spec.block_size device = self.runner.device - qo_indptr = self.runner.query_start_loc_cpu[:num_reqs + 1].to( - self.runner.device, non_blocking=True) - seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device, - non_blocking=True) + qo_indptr = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens block_table_tensor = (self.block_table.get_device_tensor()[:num_reqs]) slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to( self.runner.device, non_blocking=True).long() diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 264c7aff222b..264c29464bf3 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -204,6 +204,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.platforms import current_platform from vllm.utils import cdiv, round_down +from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version @@ -460,7 +461,8 @@ def _build_decode(self, input_positions: torch.Tensor, ) def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int) -> M: + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata) -> M: assert self._num_decodes + self._num_prefills == num_reqs # Note(simon): be careful about the CPU <> GPU memory movement in this @@ -468,16 +470,14 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, # it blocks on all previous kernels. 
device = self.runner.device block_table_tensor = (self.block_table.get_device_tensor()[:num_reqs]) - query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( - device, non_blocking=True) slot_mapping = ( self.block_table.slot_mapping_cpu[:num_actual_tokens].to( device, non_blocking=True).long()) input_positions = self.runner.positions_cpu[:num_actual_tokens].to( device, non_blocking=True).long() - seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] - seq_lens = seq_lens_cpu.to(device, non_blocking=True) + seq_lens = common_attn_metadata.seq_lens + query_start_loc = common_attn_metadata.query_start_loc prefill_metadata = None if self._num_prefills > 0: diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py new file mode 100644 index 000000000000..b9d153da0a86 --- /dev/null +++ b/vllm/v1/attention/backends/utils.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass + +import torch + + +@dataclass +class CommonAttentionMetadata: + """ + Metadata that are same for different layer types. + """ + query_start_loc: torch.Tensor + seq_lens: torch.Tensor diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 95f0c067d406..222c1d0e131a 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -224,6 +224,8 @@ def load_model(self, target_model: nn.Module) -> None: self.model = EagleLlamaForCausalLM( model_config=draft_model_config, start_layer_id=target_layer_num).to(target_device) + # TODO: implement it + self.attn_layer_name = "TODO" self.model.load_weights( loader.get_all_weights( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e7413d77e2fa..4860b8973241 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -32,6 +32,7 @@ GiB_bytes, LazyLoader, cdiv, check_use_alibi, is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheNewTensor, @@ -568,6 +569,12 @@ def _prepare_inputs( # Copy the tensors to the GPU. self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) + query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to( + self.device, non_blocking=True) + seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device, + non_blocking=True) + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc, seq_lens=seq_lens) if self.uses_mrope: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self.mrope_positions[:, :total_num_scheduled_tokens].copy_( @@ -602,7 +609,8 @@ def _prepare_inputs( num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, - common_prefix_len=common_prefix_len)) + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata)) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -615,7 +623,7 @@ def _prepare_inputs( # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. 
# TODO: confirm which attn_metadata should be used - logits_indices = attn_metadata_i.query_start_loc[1:] - 1 + logits_indices = query_start_loc[1:] - 1 spec_decode_metadata = None else: # Get the number of draft tokens for each request. @@ -1043,7 +1051,7 @@ def execute_model( num_input_tokens = num_scheduled_tokens for kv_cache_group_spec in self.kv_cache_config.kv_cache_groups: - # TODO: notes for use layer_names[0] + # TODO: merge https://github.com/vllm-project/vllm/pull/17193 layer_name = kv_cache_group_spec.layer_names[0] attn_metadata[layer_name].num_input_tokens = num_input_tokens @@ -1230,7 +1238,7 @@ def execute_model( next_token_ids = torch.tensor(next_token_ids, dtype=torch.int32, device=self.device) - + eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name] if spec_decode_metadata is None: # input_ids can be None for multimodal models. # We need to slice token_ids, positions, and hidden_states @@ -1240,8 +1248,8 @@ def execute_model( target_positions = positions[:num_scheduled_tokens] target_hidden_states = hidden_states[:num_scheduled_tokens] # TODO: confirm which attn_metadata should be used - target_slot_mapping = attn_metadata.slot_mapping - cu_num_tokens = attn_metadata.query_start_loc + target_slot_mapping = eagle_attn_metadata.slot_mapping + cu_num_tokens = eagle_attn_metadata.query_start_loc else: # TODO(woosuk): Refactor this. num_draft_tokens = spec_decode_metadata.num_draft_tokens @@ -1256,14 +1264,15 @@ def execute_model( ) # TODO: confirm which attn_metadata should be used cu_num_tokens, token_indices = self.drafter.prepare_inputs( - attn_metadata.query_start_loc, + eagle_attn_metadata.query_start_loc, num_rejected_tokens, ) target_token_ids = self.input_ids[token_indices] target_positions = positions[token_indices] target_hidden_states = hidden_states[token_indices] # TODO: confirm which attn_metadata should be used - target_slot_mapping = attn_metadata.slot_mapping[token_indices] + target_slot_mapping = eagle_attn_metadata.slot_mapping[ + token_indices] draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, @@ -1273,7 +1282,7 @@ def execute_model( next_token_ids=next_token_ids, cu_num_tokens=cu_num_tokens, # TODO: confirm which attn_metadata should be used - block_table=attn_metadata.block_table, + block_table=eagle_attn_metadata.block_table, sampling_metadata=sampling_metadata, ) spec_token_ids = draft_token_ids.tolist() From 710d68e13713544e1fe6f8651cb3f370b9976566 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Apr 2025 09:00:18 -0700 Subject: [PATCH 24/34] reorder attn args Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/flash_attn.py | 2 +- vllm/v1/attention/backends/mla/common.py | 2 +- vllm/v1/attention/backends/mla/triton_mla.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 5e619bd797c6..718fb3b06662 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -304,8 +304,8 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata): max_seq_len = self.runner.seq_lens_np[:num_reqs].max() - seq_lens = common_attn_metadata.seq_lens query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens block_table = self.block_table block_table_tensor = block_table.get_device_tensor()[:num_reqs] slot_mapping = 
block_table.slot_mapping_cpu[:num_actual_tokens].to( diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 264c29464bf3..90464b9073a3 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -470,6 +470,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, # it blocks on all previous kernels. device = self.runner.device block_table_tensor = (self.block_table.get_device_tensor()[:num_reqs]) + query_start_loc = common_attn_metadata.query_start_loc slot_mapping = ( self.block_table.slot_mapping_cpu[:num_actual_tokens].to( device, non_blocking=True).long()) @@ -477,7 +478,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, device, non_blocking=True).long() seq_lens = common_attn_metadata.seq_lens - query_start_loc = common_attn_metadata.query_start_loc prefill_metadata = None if self._num_prefills > 0: diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 3bae676bc674..8e7e4f10b81b 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 + from typing import Any, Optional import torch From 2b8ffc443de6ec92ad487f90c2114d51b850a5ce Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Apr 2025 09:06:26 -0700 Subject: [PATCH 25/34] rename max_num_tokens Signed-off-by: Chen Zhang --- vllm/v1/worker/block_table.py | 14 +++++++------- vllm/v1/worker/gpu_input_batch.py | 4 ++-- vllm/v1/worker/gpu_model_runner.py | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index e8d3b823ab8d..bbd536f15a3d 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -18,13 +18,13 @@ def __init__( self, max_num_reqs: int, max_num_blocks_per_req: int, - max_num_tokens: int, # TODO + max_num_batched_tokens: int, pin_memory: bool, device: torch.device, ): self.max_num_reqs = max_num_reqs self.max_num_blocks_per_req = max_num_blocks_per_req - self.max_num_tokens = max_num_tokens + self.max_num_batched_tokens = max_num_batched_tokens self.pin_memory = pin_memory self.device = device @@ -42,7 +42,7 @@ def __init__( self.block_table_np = self.block_table_cpu.numpy() self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) - self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, + self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens, dtype=torch.int32, device="cpu", pin_memory=self.pin_memory) @@ -112,15 +112,15 @@ class MultiGroupBlockTable: add_row: Callable[Concatenate[list[list[int]], P], None] def __init__(self, max_num_reqs: int, max_model_len: int, - max_num_tokens: int, pin_memory: bool, device: torch.device, - kv_cache_config: KVCacheConfig) -> None: + max_num_batched_tokens: int, pin_memory: bool, + device: torch.device, kv_cache_config: KVCacheConfig) -> None: max_num_blocks_per_req = [ cdiv(max_model_len, g.kv_cache_spec.block_size) for g in kv_cache_config.kv_cache_groups ] self.block_tables = [ - BlockTable(max_num_reqs, max_num_blocks_per_req[i], max_num_tokens, - pin_memory, device) + BlockTable(max_num_reqs, max_num_blocks_per_req[i], + max_num_batched_tokens, pin_memory, device) for i in range(len(kv_cache_config.kv_cache_groups)) ] # For methods that just pass the arguments to each BlockTable. 
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 19ef82775dd3..b706a47f6ad4 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -60,7 +60,7 @@ def __init__( self, max_num_reqs: int, max_model_len: int, - max_num_tokens: int, + max_num_batched_tokens: int, device: torch.device, pin_memory: bool, vocab_size: int, @@ -102,7 +102,7 @@ def __init__( self.block_table = MultiGroupBlockTable( max_num_reqs=max_num_reqs, max_model_len=max_model_len, - max_num_tokens=max_num_tokens, + max_num_batched_tokens=max_num_batched_tokens, pin_memory=pin_memory, device=device, kv_cache_config=kv_cache_config, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4860b8973241..d06c0be1a78d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1816,7 +1816,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, max_model_len=self.max_model_len, - max_num_tokens=self.max_num_tokens, + max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), From b50aa146781eebdec939be8ab1e53e02233c3968 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Apr 2025 09:11:33 -0700 Subject: [PATCH 26/34] remove fixed TODO Signed-off-by: Chen Zhang --- vllm/v1/worker/gpu_model_runner.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d06c0be1a78d..3357d5be50da 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -622,7 +622,6 @@ def _prepare_inputs( # from these partial requests, we do so for simplicity. # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. 
- # TODO: confirm which attn_metadata should be used logits_indices = query_start_loc[1:] - 1 spec_decode_metadata = None else: @@ -1247,7 +1246,6 @@ def execute_model( target_token_ids = self.input_ids[:num_scheduled_tokens] target_positions = positions[:num_scheduled_tokens] target_hidden_states = hidden_states[:num_scheduled_tokens] - # TODO: confirm which attn_metadata should be used target_slot_mapping = eagle_attn_metadata.slot_mapping cu_num_tokens = eagle_attn_metadata.query_start_loc else: @@ -1262,7 +1260,6 @@ def execute_model( dtype=torch.int32, device=self.device, ) - # TODO: confirm which attn_metadata should be used cu_num_tokens, token_indices = self.drafter.prepare_inputs( eagle_attn_metadata.query_start_loc, num_rejected_tokens, @@ -1270,7 +1267,6 @@ def execute_model( target_token_ids = self.input_ids[token_indices] target_positions = positions[token_indices] target_hidden_states = hidden_states[token_indices] - # TODO: confirm which attn_metadata should be used target_slot_mapping = eagle_attn_metadata.slot_mapping[ token_indices] @@ -1281,7 +1277,6 @@ def execute_model( target_slot_mapping=target_slot_mapping, next_token_ids=next_token_ids, cu_num_tokens=cu_num_tokens, - # TODO: confirm which attn_metadata should be used block_table=eagle_attn_metadata.block_table, sampling_metadata=sampling_metadata, ) From 136a54c0b1689d059e56ae5839c775f38a854524 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Apr 2025 09:13:53 -0700 Subject: [PATCH 27/34] fix Signed-off-by: Chen Zhang --- vllm/v1/worker/gpu_model_runner.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3357d5be50da..131f4cdad220 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1063,15 +1063,6 @@ def execute_model( else: mm_embeds = [] - # _prepare_inputs may reorder the batch, so we must gather multi - # modal outputs after that to ensure the correct order - if self.is_multimodal_model: - # Run the multimodal encoder if any. - self._execute_mm_encoder(scheduler_output) - mm_embeds = self._gather_mm_embeddings(scheduler_output) - else: - mm_embeds = [] - if self.is_multimodal_model: # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) From 765d9ed22af74c8dc4908575c57d1bb4eb4aefb1 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Apr 2025 09:20:06 -0700 Subject: [PATCH 28/34] add note Signed-off-by: Chen Zhang --- vllm/v1/kv_cache_interface.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index b44149f79896..449afdacde00 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -73,7 +73,10 @@ def page_size_bytes(self) -> int: @dataclass class FullAttentionSpec(AttentionSpec): - # TODO: add note + # Some layers may be regarded as full attention layers in KV cache manager ( + # blocks are allocated for all tokens), while computed as sliding window + # attention. In this case, we use FullAttentionSpec and record the + # sliding window size. 
sliding_window: Optional[int] = None @property From 216a079abcc521c255c19ac673a2fd19d229b0dc Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Apr 2025 20:18:58 -0700 Subject: [PATCH 29/34] group partition strategy Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 877cbd4b3549..666028d168a9 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 """KV-Cache Utilities.""" -import math import os from collections import defaultdict, deque from collections.abc import Sequence @@ -706,23 +705,28 @@ def _get_kv_cache_config_uniform_page_size( The generated KVCacheConfig """ # Group all layers by type_id. - # E.g., 2 full attention layers and 4 sliding window attention layers, - # -> (full.0, full.1), (sw.0, sw.1, sw.2, sw.3). + # E.g., 2 full attention layers and 3 sliding window attention layers, + # -> (full.0, full.1), (sw.0, sw.1, sw.2). same_type_layers: dict[str, list[str]] = defaultdict(list) for layer_name, layer_spec in kv_cache_spec.items(): same_type_layers[layer_spec.type_id].append(layer_name) - # Split each group into smaller groups, to make the number of layers in - # each group identical. - # E.g., (full.0, full.1), (sw.0, sw.1, sw.2, sw.3), group_size_gcd is 2, + # Split each group into smaller groups, to make the number of layers in each + # group identical. Add padding to the last group of each type if necessary. + # E.g., (full.0, full.1), (sw.0, sw.1, sw.2) # split to 3 groups with 2 layers each: - # (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3). - group_size_gcd = math.gcd( - *[len(layers) for layers in same_type_layers.values()]) + # (full.0, full.1), (sw.0, sw.1), (sw.2, padding). + group_size = min([len(layers) for layers in same_type_layers.values()]) grouped_layers = [] for layers in same_type_layers.values(): - for i in range(0, len(layers), group_size_gcd): - grouped_layers.append(layers[i:i + group_size_gcd]) + num_padding_layers = len(layers) % group_size + if num_padding_layers > 0: + logger.warning( + "Add %d padding layers, may waste at most %.2f%% KV cache memory", # noqa + num_padding_layers, + num_padding_layers / len(layers) * 100) + for i in range(0, len(layers), group_size): + grouped_layers.append(layers[i:i + group_size]) # Divide the available memory equally among all layers in the first group. # The memory layout in the example will be: @@ -738,12 +742,12 @@ def _get_kv_cache_config_uniform_page_size( # Reuse the KV cache tensors of the first group for the other groups. # The memory layout in the example will be: # full.0, sw.0, sw.2: share a Tensor with size=available_memory//2 - # full.1, sw.1, sw.3: share another Tensor with size=available_memory//2 + # full.1, sw.1: share another Tensor with size=available_memory//2 # Layers of different groups have different block table, so they will # use different parts of the shared Tensor. 
for layers in grouped_layers[1:]: - for layer_name, layer_name_first_group in zip(layers, - grouped_layers[0]): + for layer_name, layer_name_first_group in zip( + layers, grouped_layers[0][:len(layers)]): kv_cache_config.tensors[layer_name] = KVCacheReuseTensor( reused_layer_name=layer_name_first_group) From 37c449435bd473095cc0d20762531b81e0dffd7c Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Apr 2025 22:59:00 -0700 Subject: [PATCH 30/34] support eagle Signed-off-by: Chen Zhang --- vllm/v1/spec_decode/eagle.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 275bd8e34662..9ac03f5c82ca 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -215,6 +215,8 @@ def load_model(self, target_model: nn.Module) -> None: loader = get_model_loader(self.vllm_config.load_config) target_layer_num = self.vllm_config.model_config.get_num_layers( self.vllm_config.parallel_config) + target_attn_layer_names = set( + self.vllm_config.compilation_config.static_forward_context.keys()) draft_model_config = \ self.vllm_config.speculative_config.draft_model_config @@ -234,8 +236,12 @@ def load_model(self, target_model: nn.Module) -> None: self.model = Eagle3LlamaForCausalLM( model_config=draft_model_config, start_layer_id=target_layer_num).to(target_device) - # TODO: implement it - self.attn_layer_name = "TODO" + + draft_attn_layer_names = ( + self.vllm_config.compilation_config.static_forward_context.keys() - + target_attn_layer_names) + assert len(draft_attn_layer_names) == 1 + self.attn_layer_name = iter(draft_attn_layer_names).__next__() loaded_weights = self.model.load_weights( loader.get_all_weights( From 84280fc0b016d32859fa705650203f570f68fd76 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Apr 2025 23:34:42 -0700 Subject: [PATCH 31/34] get a specific type of layer from forward context Signed-off-by: Chen Zhang --- vllm/attention/backends/flashinfer.py | 6 ++---- vllm/config.py | 13 +++++++++++++ vllm/v1/attention/backends/flashinfer.py | 7 +++---- vllm/v1/worker/gpu_model_runner.py | 14 ++++---------- vllm/v1/worker/tpu_model_runner.py | 7 +++---- 5 files changed, 25 insertions(+), 22 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 718b15e58785..ce7ab7b176e5 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -37,7 +37,7 @@ is_block_tables_empty) from vllm.attention.layer import Attention from vllm.attention.ops.paged_attn import PagedAttention -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_layers_from_config from vllm.logger import init_logger from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, make_tensor_with_pad) @@ -128,12 +128,10 @@ def get_per_layer_parameters( to use during `plan`. """ - layers = vllm_config.compilation_config.static_forward_context + layers = get_layers_from_config(vllm_config, Attention) per_layer_params: Dict[str, PerLayerParameters] = {} for key, layer in layers.items(): - assert isinstance(layer, Attention) - impl = layer.impl assert isinstance(impl, FlashInferImpl) diff --git a/vllm/config.py b/vllm/config.py index 0ac3cc46b063..67b18f6d2e01 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4075,3 +4075,16 @@ def assert_hashable(text): f"vLLM tried to hash some configs that may have Python objects ids " f"in them. This is a bug, please file an issue. 
" f"Text being hashed: {text}") + + +T = TypeVar("T") + + +def get_layers_from_config(vllm_config: VllmConfig, + layer_type: type[T]) -> dict[str, T]: + return { + layer_name: layer + for layer_name, layer in + vllm_config.compilation_config.static_forward_context.items() + if isinstance(layer, layer_type) + } diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 17341ecfa4fe..2533faf9bec9 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -14,7 +14,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) from vllm.attention.layer import Attention -from vllm.config import VllmConfig, get_current_vllm_config +from vllm.config import (VllmConfig, get_current_vllm_config, + get_layers_from_config) from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import use_cascade_attention @@ -81,12 +82,10 @@ def get_per_layer_parameters( to use during `plan`. """ - layers = vllm_config.compilation_config.static_forward_context + layers = get_layers_from_config(vllm_config, Attention) per_layer_params: dict[str, PerLayerParameters] = {} for key, layer in layers.items(): - assert isinstance(layer, Attention) - impl = layer.impl assert isinstance(impl, FlashInferImpl) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7910481762ef..d1ac30f1660c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -12,13 +12,12 @@ from vllm.attention import AttentionType, get_attn_backend from vllm.attention.layer import Attention -from vllm.config import CompilationLevel, VllmConfig +from vllm.config import CompilationLevel, VllmConfig, get_layers_from_config from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) from vllm.distributed.parallel_state import get_pp_group, graph_capture from vllm.forward_context import set_forward_context from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY @@ -1736,17 +1735,12 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: format. Layers that do not need KV cache are not included. 
""" - forward_ctx = self.vllm_config.compilation_config.static_forward_context + layers = get_layers_from_config(self.vllm_config, Attention) block_size = self.vllm_config.cache_config.block_size use_mla = self.vllm_config.model_config.use_mla kv_cache_spec: dict[str, KVCacheSpec] = {} - for layer_name, attn_module in forward_ctx.items(): - if isinstance(attn_module, FusedMoE): - continue - - # TODO: Support other attention modules, e.g., sliding window, - # cross-attention - assert isinstance(attn_module, Attention) + for layer_name, attn_module in layers.items(): + # TODO: Support other attention modules, e.g., cross-attention if attn_module.attn_type == AttentionType.DECODER: if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index e9cb0dbe8b5e..c183b04cc644 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -17,7 +17,7 @@ from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_layers_from_config from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model @@ -430,11 +430,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: format. Layers that do not need KV cache are not included. """ - forward_ctx = self.vllm_config.compilation_config.static_forward_context + layers = get_layers_from_config(self.vllm_config, Attention) block_size = self.vllm_config.cache_config.block_size kv_cache_spec: dict[str, KVCacheSpec] = {} - for layer_name, attn_module in forward_ctx.items(): - assert isinstance(attn_module, Attention) + for layer_name, attn_module in layers.items(): if attn_module.attn_type == AttentionType.DECODER: if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( From 51ffeb68939bbc5ee0609f9605ae3b383421e045 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Apr 2025 23:38:22 -0700 Subject: [PATCH 32/34] fix Signed-off-by: Chen Zhang --- vllm/attention/backends/flashinfer.py | 4 ++-- vllm/config.py | 4 ++-- vllm/v1/attention/backends/flashinfer.py | 4 ++-- vllm/v1/worker/gpu_model_runner.py | 5 +++-- vllm/v1/worker/tpu_model_runner.py | 4 ++-- 5 files changed, 11 insertions(+), 10 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index ce7ab7b176e5..1d78295a9781 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -37,7 +37,7 @@ is_block_tables_empty) from vllm.attention.layer import Attention from vllm.attention.ops.paged_attn import PagedAttention -from vllm.config import VllmConfig, get_layers_from_config +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, make_tensor_with_pad) @@ -128,7 +128,7 @@ def get_per_layer_parameters( to use during `plan`. 
""" - layers = get_layers_from_config(vllm_config, Attention) + layers = get_layers_from_vllm_config(vllm_config, Attention) per_layer_params: Dict[str, PerLayerParameters] = {} for key, layer in layers.items(): diff --git a/vllm/config.py b/vllm/config.py index 67b18f6d2e01..6a3e37992e7d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4080,8 +4080,8 @@ def assert_hashable(text): T = TypeVar("T") -def get_layers_from_config(vllm_config: VllmConfig, - layer_type: type[T]) -> dict[str, T]: +def get_layers_from_vllm_config(vllm_config: VllmConfig, + layer_type: type[T]) -> dict[str, T]: return { layer_name: layer for layer_name, layer in diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 2533faf9bec9..bce446bd2b82 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -15,7 +15,7 @@ AttentionType) from vllm.attention.layer import Attention from vllm.config import (VllmConfig, get_current_vllm_config, - get_layers_from_config) + get_layers_from_vllm_config) from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import use_cascade_attention @@ -82,7 +82,7 @@ def get_per_layer_parameters( to use during `plan`. """ - layers = get_layers_from_config(vllm_config, Attention) + layers = get_layers_from_vllm_config(vllm_config, Attention) per_layer_params: dict[str, PerLayerParameters] = {} for key, layer in layers.items(): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d1ac30f1660c..29104f70b5c0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -12,7 +12,8 @@ from vllm.attention import AttentionType, get_attn_backend from vllm.attention.layer import Attention -from vllm.config import CompilationLevel, VllmConfig, get_layers_from_config +from vllm.config import (CompilationLevel, VllmConfig, + get_layers_from_vllm_config) from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) from vllm.distributed.parallel_state import get_pp_group, graph_capture @@ -1735,7 +1736,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: format. Layers that do not need KV cache are not included. """ - layers = get_layers_from_config(self.vllm_config, Attention) + layers = get_layers_from_vllm_config(self.vllm_config, Attention) block_size = self.vllm_config.cache_config.block_size use_mla = self.vllm_config.model_config.use_mla kv_cache_spec: dict[str, KVCacheSpec] = {} diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index c183b04cc644..20a3b60172c2 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -17,7 +17,7 @@ from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import VllmConfig, get_layers_from_config +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model @@ -430,7 +430,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: format. Layers that do not need KV cache are not included. 
""" - layers = get_layers_from_config(self.vllm_config, Attention) + layers = get_layers_from_vllm_config(self.vllm_config, Attention) block_size = self.vllm_config.cache_config.block_size kv_cache_spec: dict[str, KVCacheSpec] = {} for layer_name, attn_module in layers.items(): From 1da28d96dca5d3db3ef537c53c7da7a2219b24d7 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Apr 2025 23:48:24 -0700 Subject: [PATCH 33/34] update eagle Signed-off-by: Chen Zhang --- vllm/v1/spec_decode/eagle.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 9ac03f5c82ca..a84627b847e4 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -4,7 +4,9 @@ import triton import triton.language as tl -from vllm.config import VllmConfig, set_current_vllm_config +from vllm.attention.layer import Attention +from vllm.config import (VllmConfig, get_layers_from_vllm_config, + set_current_vllm_config) from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader.loader import get_model_loader @@ -216,7 +218,7 @@ def load_model(self, target_model: nn.Module) -> None: target_layer_num = self.vllm_config.model_config.get_num_layers( self.vllm_config.parallel_config) target_attn_layer_names = set( - self.vllm_config.compilation_config.static_forward_context.keys()) + get_layers_from_vllm_config(self.vllm_config, Attention).keys()) draft_model_config = \ self.vllm_config.speculative_config.draft_model_config @@ -238,7 +240,7 @@ def load_model(self, target_model: nn.Module) -> None: start_layer_id=target_layer_num).to(target_device) draft_attn_layer_names = ( - self.vllm_config.compilation_config.static_forward_context.keys() - + get_layers_from_vllm_config(self.vllm_config, Attention).keys() - target_attn_layer_names) assert len(draft_attn_layer_names) == 1 self.attn_layer_name = iter(draft_attn_layer_names).__next__() From e5cb02e6755c4183eea37e693a808c13a45c9747 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Apr 2025 23:57:46 -0700 Subject: [PATCH 34/34] only enable cuda platform Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 666028d168a9..4254f66945b3 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -801,7 +801,8 @@ def get_kv_cache_config(vllm_config: VllmConfig, The generated KVCacheConfigs """ check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory) - if vllm_config.cache_config.disable_hybrid_allocator: + if (vllm_config.cache_config.disable_hybrid_allocator + or vllm_config.device_config.device.type != "cuda"): unify_hybrid_kv_cache_specs(kv_cache_spec) if is_kv_cache_type_uniform(kv_cache_spec): # KV cache of all layers are the same, which is true for