diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py index a9e1898df934..b67c05bd7ac1 100644 --- a/tests/v1/core/test_specialized_manager.py +++ b/tests/v1/core/test_specialized_manager.py @@ -1,13 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random + import torch from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, KVCacheBlock) -from vllm.v1.core.single_type_kv_cache_manager import SlidingWindowManager -from vllm.v1.kv_cache_interface import SlidingWindowSpec +from vllm.v1.core.single_type_kv_cache_manager import ( + ChunkedLocalAttentionManager, SlidingWindowManager) +from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, + SlidingWindowSpec) def get_sliding_window_manager(sliding_window_spec, block_pool): @@ -17,6 +21,80 @@ def get_sliding_window_manager(sliding_window_spec, block_pool): kv_cache_group_id=0) +def get_chunked_local_attention_manager(chunked_local_attention_spec, + block_pool): + return ChunkedLocalAttentionManager(chunked_local_attention_spec, + block_pool, + caching_hash_fn=lambda x: x, + kv_cache_group_id=0) + + +def test_chunked_local_attention_possible_cached_prefix(): + block_size = 2 + chunked_local_attention_spec = ChunkedLocalAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + attention_chunk_size=4, + use_mla=False, + ) + + block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) + manager = get_chunked_local_attention_manager(chunked_local_attention_spec, + block_pool) + + def run_one_case(block_is_cached, tail_token, expect_length): + block_hash_list = [ + BlockHash(i, ()) for i in range(len(block_is_cached)) + ] + + block_pool.cached_block_hash_to_block.clear() + + # Mock the block pool with the cached blocks + for i, (block_hash, + is_cached) in enumerate(zip(block_hash_list, block_is_cached)): + if is_cached: + block_pool.cached_block_hash_to_block[BlockHashWithGroupId( + block_hash, 0)] = { + i: block_pool.blocks[i + 10], + } + + computed_blocks = manager.find_longest_cache_hit( + block_hashes=block_hash_list, + max_length=len(block_hash_list) * block_size + tail_token, + kv_cache_group_ids=[0], + block_pool=block_pool, + kv_cache_spec=chunked_local_attention_spec, + use_eagle=False)[0] + assert len(computed_blocks) == expect_length + + assert all(block == block_pool.null_block + for block in computed_blocks[:(expect_length - 1) // 2]) + + run_one_case([True], 0, 1) + run_one_case([True], 1, 1) + run_one_case([True, False], 0, 2) + run_one_case([True, False], 1, 2) + run_one_case([True, True], 0, 2) + run_one_case([True, True], 1, 2) + run_one_case([True, True, False], 0, 2) + run_one_case([True, True, False], 1, 2) + run_one_case([True, True, True], 0, 3) + run_one_case([True, True, True], 1, 3) + run_one_case([True, True, True, False], 0, 4) + run_one_case([True, True, True, False], 1, 4) + run_one_case([random.choice([True, False])] * 8 + [True], 1, 9) + run_one_case([random.choice([True, False])] * 8 + [False], 1, 8) + run_one_case([random.choice([True, False])] * 8 + [True, True], 1, 10) + run_one_case([random.choice([True, False])] * 8 + [True, False], 0, 10) + run_one_case([random.choice([True, False])] * 8 + [True, False], 1, 10) + run_one_case([random.choice([True, False])] * 8 + [False, True], 0, 10) + run_one_case([random.choice([True, False])] * 8 + [False, True], 1, 10) + 
run_one_case([random.choice([True, False])] * 8 + [False, False], 0, 10)
+    run_one_case([random.choice([True, False])] * 8 + [False, False], 1, 10)
+
+
 def test_sliding_window_possible_cached_prefix():
     block_size = 2
     sliding_window_spec = SlidingWindowSpec(
@@ -84,6 +162,58 @@ def run_one_case(block_is_cached, expect_length):
     ], 8)

+
+def test_chunked_local_attention_remove_skipped_blocks():
+    attention_spec = ChunkedLocalAttentionSpec(
+        block_size=2,
+        num_kv_heads=1,
+        head_size=1,
+        dtype=torch.float32,
+        attention_chunk_size=4,
+        use_mla=False,
+    )
+
+    block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
+
+    manager = get_chunked_local_attention_manager(attention_spec, block_pool)
+
+    null_block_id = block_pool.null_block.block_id
+
+    def id_to_block_table(ids) -> list[KVCacheBlock]:
+        return [
+            KVCacheBlock(id_)
+            if id_ != null_block_id else block_pool.null_block for id_ in ids
+        ]
+
+    def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]):
+        for block, id_ in zip(block_table, ids):
+            if id_ == null_block_id:
+                assert block == block_pool.null_block
+            else:
+                assert block.block_id == id_
+
+    original_block_ids = [
+        1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010
+    ]
+    block_table = id_to_block_table(original_block_ids)
+    manager.req_to_blocks["test"] = block_table
+
+    manager.remove_skipped_blocks("test", 0)
+    assert_block_id(block_table, original_block_ids)
+
+    # For the 4th token (0-indexed), tokens 0-3 are outside the local window.
+    manager.remove_skipped_blocks("test", 4)
+    assert_block_id(block_table, [null_block_id] * 2)
+
+    # For the 6th token (0-indexed), tokens 4-6 are inside the local window
+    # and tokens 0-3 are out, so 2 blocks can be removed.
+    manager.remove_skipped_blocks("test", 6)
+    assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:])
+    # For the 12th token (0-indexed),
+    # tokens 0-11 are out, so 6 blocks can be removed.
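+    # With attention_chunk_size=4 and block_size=2 this follows directly from
+    # the manager's arithmetic: local_attention_start_idx = 12 // 4 * 4 = 12,
+    # first_useful_block_idx = 12 // 2 = 6, so blocks 0-5 are freed and
+    # replaced by the null block.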
+ manager.remove_skipped_blocks("test", 12) + assert_block_id(block_table, [null_block_id] * 6) + + def test_sliding_window_remove_skipped_blocks(): sliding_window_spec = SlidingWindowSpec( block_size=2, @@ -172,3 +302,26 @@ def test_get_num_blocks_to_allocate(): cached_blocks_1) == 20 assert manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2) == 15 + + +def test_chunked_local_attention_get_num_blocks_to_allocate(): + block_size = 2 + attention_spec = ChunkedLocalAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + attention_chunk_size=4, # Placeholder value, not related to test result + use_mla=False, + ) + + block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) + manager = get_chunked_local_attention_manager(attention_spec, block_pool) + cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)] + cached_blocks_2 = [block_pool.null_block for _ in range(5) + ] + [KVCacheBlock(i + 1) for i in range(5)] + + assert manager.get_num_blocks_to_allocate("1", 20 * block_size, + cached_blocks_1) == 20 + assert manager.get_num_blocks_to_allocate("2", 20 * block_size, + cached_blocks_2) == 15 diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index f9c2d4f49835..b28c20d7b54e 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -143,6 +143,7 @@ def __init__( kv_sharing_target_layer_name, **extra_impl_args) self.backend = backend_name_to_enum(attn_backend.get_name()) self.dtype = dtype + self.use_irope = extra_impl_args.get("use_irope", False) # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how # torch.compile works by registering the attention as one giant diff --git a/vllm/config.py b/vllm/config.py index f94c08c32536..6b84ab3921cd 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4709,6 +4709,13 @@ def __post_init__(self): if self.kv_events_config is not None: # Hybrid KV cache manager is not compatible with KV events. self.scheduler_config.disable_hybrid_kv_cache_manager = True + if self.model_config is not None and \ + self.model_config.attention_chunk_size is not None and \ + self.speculative_config is not None and \ + self.speculative_config.use_eagle(): + # Hybrid KV cache manager is not yet supported with chunked + # local attention + eagle. + self.scheduler_config.disable_hybrid_kv_cache_manager = True def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list: diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index d5b30ac685ac..a37bf2a7115b 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -538,6 +538,7 @@ def use_cascade_attention( num_kv_heads: int, use_alibi: bool, use_sliding_window: bool, + use_local_attention: bool, num_sms: int, ) -> bool: """Decide whether to use cascade attention. @@ -553,7 +554,7 @@ def use_cascade_attention( if common_prefix_len < 256: return False # Cascade attention is currently not supported with these variants. - if use_alibi or use_sliding_window: + if use_alibi or use_sliding_window or use_local_attention: return False # Too few queries. Probably not worth using cascade attention. # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold. 
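
# A minimal illustration (not part of the patch; the toy function below is an
# assumption, not vLLM's API) of the guard extended above: cascade attention
# is declined whenever the layer uses ALiBi, sliding window, or chunked local
# attention.
def _toy_use_cascade_attention(common_prefix_len: int, use_alibi: bool,
                               use_sliding_window: bool,
                               use_local_attention: bool) -> bool:
    if common_prefix_len < 256:
        return False
    # Cascade attention is currently not supported with these variants.
    if use_alibi or use_sliding_window or use_local_attention:
        return False
    return True


assert _toy_use_cascade_attention(4096, use_alibi=False,
                                  use_sliding_window=False,
                                  use_local_attention=True) is False
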
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index b6a06b17bca2..65c3baa6784f 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -120,6 +120,7 @@ def use_cascade_attention( num_kv_heads: int, use_alibi: bool, use_sliding_window: bool, + use_local_attention: bool, num_sms: int, ) -> bool: return False diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 6067a127e97f..35602fde9b06 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -11,7 +11,8 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, +from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, + FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, KVCacheTensor, SlidingWindowSpec) from vllm.v1.metrics.stats import PrefixCacheStats @@ -934,7 +935,11 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values()) has_sliding_window = any( isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values()) - if has_full_attention and has_sliding_window: + has_chunked_local_attention = any( + isinstance(spec, ChunkedLocalAttentionSpec) + for spec in kv_cache_spec.values()) + if has_full_attention and (has_sliding_window + or has_chunked_local_attention): for layer_name, spec in kv_cache_spec.items(): if isinstance(spec, SlidingWindowSpec): kv_cache_spec[layer_name] = FullAttentionSpec( @@ -945,6 +950,15 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: use_mla=spec.use_mla, sliding_window=spec.sliding_window, ) + elif isinstance(spec, ChunkedLocalAttentionSpec): + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=spec.block_size, + num_kv_heads=spec.num_kv_heads, + head_size=spec.head_size, + dtype=spec.dtype, + use_mla=spec.use_mla, + attention_chunk_size=spec.attention_chunk_size, + ) if is_hybrid(kv_cache_spec): raise ValueError("Hybrid KV cache manager is disabled but failed to " @@ -968,7 +982,6 @@ def get_kv_cache_config( The generated KVCacheConfigs """ check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory) - if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager: unify_hybrid_kv_cache_specs(kv_cache_spec) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 1560406c9004..65a196e044ab 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -394,6 +394,129 @@ def get_num_common_prefix_blocks(self, request_id: str, return 0 +class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): + + def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec, + block_pool: BlockPool, **kwargs) -> None: + super().__init__(kv_cache_spec, block_pool, **kwargs) + self.attention_chunk_size = kv_cache_spec.attention_chunk_size + self._null_block = block_pool.null_block + + @classmethod + def find_longest_cache_hit( + cls, + block_hashes: list[BlockHash], + max_length: int, + kv_cache_group_ids: list[int], + block_pool: BlockPool, + kv_cache_spec: KVCacheSpec, + use_eagle: bool, + ) -> tuple[list[KVCacheBlock], ...]: + """ + For chunked local attention, we need to find the longest cache hit + prefix of the blocks that is not longer than `max_length`. 
The prefix should be a common prefix hit for all the kv cache groups in
+        `kv_cache_group_ids`. If no cache hit is found, return an empty list.
+        Note that a block is marked as computed and set to the null block
+        when it lies entirely outside the local attention window. Examples:
+
+        1. Attention chunk size of 8, block size of 4, max length of 15:
+        for the next token at position 15 (zero-indexed), tokens 8-14 are
+        in the window (and need a cache lookup), tokens 0-7 are outside it,
+        so they are already marked as computed. We look up the complete
+        block covering tokens 8-11: if it is a cache hit, we return
+        [null, null, hit block]; otherwise, we return [null, null].
+
+        2. Attention chunk size of 8, block size of 4, max length of 16:
+        for the next token at position 16 (zero-indexed), tokens 0-15 are
+        outside the window, so they are already marked as computed and
+        we return 4 null blocks: [null, null, null, null].
+
+        Args:
+            block_hashes: The block hashes of the request.
+            max_length: The maximum length of the cache hit prefix.
+            kv_cache_group_ids: The ids of the kv cache groups.
+            block_pool: The block pool.
+            kv_cache_spec: The kv cache spec.
+            use_eagle: Whether to use eagle.
+
+        Returns:
+            A tuple of lists (one per kv cache group) of computed blocks.
+        """
+        assert isinstance(kv_cache_spec, ChunkedLocalAttentionSpec), (
+            "ChunkedLocalAttentionManager can only be used for " +
+            "chunked local attention groups")
+        assert use_eagle is False, ("Hybrid KV cache is not supported for " +
+                                    "eagle + chunked local attention.")
+        max_num_blocks = max_length // kv_cache_spec.block_size
+        if max_length > 0:
+            local_attention_start_idx = (max_length //
+                                         kv_cache_spec.attention_chunk_size *
+                                         kv_cache_spec.attention_chunk_size)
+        else:
+            local_attention_start_idx = 0
+        # We mark blocks outside the window as computed (using null blocks)
+        # and blocks inside the window according to the cache lookup result:
+        # [null] [null] ... [null] [hit block 1 (the first block of the
+        # current window)] [hit block 2] ... [hit block x]
+        local_attention_start_block_idx = (local_attention_start_idx //
+                                           kv_cache_spec.block_size)
+        computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
+            [block_pool.null_block] * local_attention_start_block_idx
+            for _ in range(len(kv_cache_group_ids)))
+        for i in range(local_attention_start_block_idx, max_num_blocks):
+            block_hash = block_hashes[i]
+            if cached_block := block_pool.get_cached_block(
+                    block_hash, kv_cache_group_ids):
+                for computed, cached in zip(computed_blocks, cached_block):
+                    computed.append(cached)
+            else:
+                break
+        return computed_blocks
+
+    def remove_skipped_blocks(self, request_id: str,
+                              num_computed_tokens: int) -> None:
+        # Remove the blocks that are no longer inside the chunked attention
+        # window and are skipped during the attention computation.
+
+        # [chunk 0][chunk 1]local_attention_start_idx ... current
+        # We count the number of completed chunks to get the starting offset
+        # of the current chunk window.
+        # E.g. with a chunk size of 1024: after 1024 computed tokens, the
+        # 1024th token (0-indexed) is in the second chunk, there is one
+        # previous chunk, and the start idx is 1024; for 1023 tokens it is 0.
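+        # For the concrete numbers used in the unit tests
+        # (attention_chunk_size=4, block_size=2): num_computed_tokens=6 gives
+        # start idx 6 // 4 * 4 = 4 and a first useful block of 4 // 2 = 2,
+        # so blocks 0 and 1 are replaced by the null block.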
+        num_cached_block = self.num_cached_block.get(request_id, 0)
+        local_attention_start_idx = (
+            num_computed_tokens
+        ) // self.attention_chunk_size * self.attention_chunk_size
+        first_useful_block_idx = local_attention_start_idx // self.block_size
+        if num_cached_block > 0:
+            # Make sure we don't delete the last cached block
+            first_useful_block_idx = min(first_useful_block_idx,
+                                         num_cached_block - 1)
+        # E.g. with block_size = 128: start idx 0 -> block 0, 1024 (= 128 * 8)
+        # -> block 8, 372 (= 128 * 2 + 116) -> block 2.
+        blocks = self.req_to_blocks[request_id]
+        removed_blocks: list[KVCacheBlock] = []
+        # We need to keep the last cached block to get the previous hash key.
+        for i in range(first_useful_block_idx - 1, -1, -1):
+            if blocks[i] == self._null_block:
+                # If the block is already a null block, the blocks before it
+                # should also have been set to null blocks by the previous calls
+                # to this function.
+                break
+            removed_blocks.append(blocks[i])
+            blocks[i] = self._null_block
+        self.block_pool.free_blocks(removed_blocks)
+
+    def get_num_common_prefix_blocks(self, request_id: str,
+                                     num_running_requests: int) -> int:
+        """
+        Cascade attention is not supported by chunked local attention.
+        """
+        return 0
+
+
 class MambaManager(SingleTypeKVCacheManager):

     @classmethod
@@ -435,8 +558,8 @@ def allocate_new_blocks(self, request_id: str,


 spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
     FullAttentionSpec: FullAttentionManager,
-    ChunkedLocalAttentionSpec: FullAttentionManager,
     SlidingWindowSpec: SlidingWindowManager,
+    ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
     MambaSpec: MambaManager,
 }
diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py
index 6726709955f7..bec31a7a058d 100644
--- a/vllm/v1/kv_cache_interface.py
+++ b/vllm/v1/kv_cache_interface.py
@@ -87,6 +87,7 @@ def page_size_bytes(self) -> int:
 @dataclass
 class FullAttentionSpec(AttentionSpec):
     sliding_window: Optional[int] = None
+    attention_chunk_size: Optional[int] = None
     """
     When hybrid allocator is disabled and the model contains both full
     attention layers and sliding window attention layers, sliding
@@ -105,6 +106,17 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         max_model_len = vllm_config.model_config.max_model_len
         return cdiv(max_model_len, self.block_size) * self.page_size_bytes

+    @classmethod
+    def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]:
+        if len(window_sizes) == 0:
+            return None
+        elif len(window_sizes) == 1:
+            return window_sizes.pop()
+        else:
+            raise ValueError(
+                "All attention layers in the same KV cache group must have the "
+                "same window size.")
+
     @classmethod
     def merge(cls, specs: list[Self]) -> Self:
         """
@@ -114,14 +126,17 @@ def merge(cls, specs: list[Self]) -> Self:
         merged_spec = super().merge(specs)
         sliding_window = set(spec.sliding_window for spec in specs
                              if spec.sliding_window is not None)
-        if len(sliding_window) == 0:
-            merged_spec.sliding_window = None
-        elif len(sliding_window) == 1:
-            merged_spec.sliding_window = sliding_window.pop()
-        else:
-            raise ValueError(
-                "All sliding window layers in the same KV cache group "
-                "must have the same window size.")
+        attention_chunk_size = set(spec.attention_chunk_size for spec in specs
+                                   if spec.attention_chunk_size is not None)
+
+        merged_spec.sliding_window = cls.merge_window_sizes(sliding_window)
+        merged_spec.attention_chunk_size = (
+            cls.merge_window_sizes(attention_chunk_size))
+        assert (
+            (merged_spec.sliding_window is not None) +
+            (merged_spec.attention_chunk_size is not None) <= 1
+        ), ("Models with both sliding window layers and chunked local "
+            "attention layers are not supported.")
         return merged_spec
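
# Usage sketch for the merge_window_sizes helper added above (assumes a build
# with this patch applied; not part of the diff itself):
from vllm.v1.kv_cache_interface import FullAttentionSpec

assert FullAttentionSpec.merge_window_sizes(set()) is None
assert FullAttentionSpec.merge_window_sizes({8192}) == 8192
try:
    FullAttentionSpec.merge_window_sizes({4096, 8192})
except ValueError:
    pass  # two different window sizes in one KV cache group are rejected
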
@@ -129,16 +144,26 @@ class ChunkedLocalAttentionSpec(AttentionSpec):
     attention_chunk_size: int

-    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
-        max_model_len = vllm_config.model_config.max_model_len
-        return cdiv(max_model_len, self.block_size) * self.page_size_bytes
-
     @property
     def type_id(self) -> str:
         return (
             f"local_attention_{self.attention_chunk_size}_{self.block_size}_{self.page_size_bytes}"
         ) # noqa

+    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
+        max_model_len = vllm_config.model_config.max_model_len
+        max_num_batched_tokens = (
+            vllm_config.scheduler_config.max_num_batched_tokens)
+
+        # During chunked prefill, we allocate KV cache for at most
+        # `self.attention_chunk_size` computed tokens plus the newly scheduled
+        # tokens. And we won't allocate KV cache for more than `max_model_len`
+        # tokens.
+        num_tokens = min(self.attention_chunk_size + max_num_batched_tokens,
+                         max_model_len)
+
+        return cdiv(num_tokens, self.block_size) * self.page_size_bytes
+

 @dataclass
 class SlidingWindowSpec(AttentionSpec):
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index c3eeb6c2e390..180be7155e9e 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -862,6 +862,10 @@ def _compute_cascade_attn_prefix_len(
         use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or
                               (isinstance(kv_cache_spec, FullAttentionSpec)
                                and kv_cache_spec.sliding_window is not None))
+        use_local_attention = (
+            isinstance(kv_cache_spec, ChunkedLocalAttentionSpec)
+            or (isinstance(kv_cache_spec, FullAttentionSpec)
+                and kv_cache_spec.attention_chunk_size is not None))
         assert isinstance(kv_cache_spec, AttentionSpec)
         use_cascade = attn_metadata_builder.use_cascade_attention(
             common_prefix_len=common_prefix_len,
@@ -870,6 +874,7 @@ def _compute_cascade_attn_prefix_len(
             num_kv_heads=kv_cache_spec.num_kv_heads,
             use_alibi=self.use_alibi,
             use_sliding_window=use_sliding_window,
+            use_local_attention=use_local_attention,
             num_sms=self.num_sms,
         )
         return common_prefix_len if use_cascade else 0
@@ -2637,6 +2642,9 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                         dtype=self.kv_cache_dtype,
                         sliding_window=attn_module.sliding_window,
                         use_mla=use_mla)
+                    assert not use_local_attention, (
+                        "attention module cannot use both local attention "
+                        "and sliding window")
                 elif use_local_attention:
                     kv_cache_spec[layer_name] = (ChunkedLocalAttentionSpec(
                         block_size=block_size,
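
# A worked example (hypothetical numbers, not taken from the patch) of the
# bound computed by ChunkedLocalAttentionSpec.max_memory_usage_bytes above:
# with attention_chunk_size=8192, max_num_batched_tokens=8192,
# max_model_len=131072 and block_size=16,
#     num_tokens = min(8192 + 8192, 131072) = 16384
#     blocks     = cdiv(16384, 16)          = 1024
# so at most 1024 pages (1024 * page_size_bytes) are reserved for such a
# layer, versus the cdiv(131072, 16) = 8192 pages a full-attention layer
# with the same max_model_len would need.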