[Core] Support Local Chunked Attention for Hybrid KV Cache #19351
Changes from 2 commits
Changes to the specialized KV cache manager tests:

```diff
@@ -1,13 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import random
+
 import torch
 
 from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
                                          KVCacheBlock)
-from vllm.v1.core.single_type_kv_cache_manager import SlidingWindowManager
-from vllm.v1.kv_cache_interface import SlidingWindowSpec
+from vllm.v1.core.single_type_kv_cache_manager import (
+    ChunkedLocalAttentionManager, SlidingWindowManager)
+from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
+                                        SlidingWindowSpec)
 
 
 def get_sliding_window_manager(sliding_window_spec, block_pool):
@@ -17,6 +21,78 @@ def get_sliding_window_manager(sliding_window_spec, block_pool):
                                 kv_cache_group_id=0)
 
 
+def get_chunked_local_attention_manager(chunked_local_attention_spec,
+                                        block_pool):
+    return ChunkedLocalAttentionManager(chunked_local_attention_spec,
+                                        block_pool,
+                                        caching_hash_fn=lambda x: x,
+                                        kv_cache_group_id=0)
+
+
+def test_chunked_local_attention_possible_cached_prefix():
+    block_size = 2
+    chunked_local_attention_spec = ChunkedLocalAttentionSpec(
+        block_size=block_size,
+        num_kv_heads=1,
+        head_size=1,
+        dtype=torch.float32,
+        attention_chunk_size=4,
+        use_mla=False,
+    )
+
+    block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
+    manager = get_chunked_local_attention_manager(chunked_local_attention_spec,
+                                                  block_pool)
+
+    def run_one_case(block_is_cached, tail_token, expect_length):
+        block_hash_list = [
+            BlockHash(i, ()) for i in range(len(block_is_cached))
+        ]
+
+        block_pool.cached_block_hash_to_block.clear()
+
+        # Mock the block pool with the cached blocks
+        for i, (block_hash,
+                is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
+            if is_cached:
+                block_pool.cached_block_hash_to_block[BlockHashWithGroupId(
+                    block_hash, 0)] = {
+                        i: block_pool.blocks[i + 10],
+                    }
+
+        computed_blocks = manager.find_longest_cache_hit(
+            block_hashes=block_hash_list,
+            max_length=len(block_hash_list) * block_size + tail_token,
+            kv_cache_group_ids=[0],
+            block_pool=block_pool,
+            kv_cache_spec=chunked_local_attention_spec,
+            use_eagle=False)[0]
+        assert len(computed_blocks) == expect_length
+
+        assert all(block == block_pool.null_block
+                   for block in computed_blocks[:(expect_length - 1) // 2])
+
+    run_one_case([True], 0, 1)
+    run_one_case([True], 1, 1)
+    run_one_case([True, False], 0, 1)
+    run_one_case([True, False], 1, 2)
+    run_one_case([True, True], 0, 2)
+    run_one_case([True, True], 1, 2)
+    run_one_case([True, True, False], 0, 2)
+    run_one_case([True, True, False], 1, 2)
+    run_one_case([True, True, True], 0, 3)
+    run_one_case([True, True, True], 1, 3)
+    run_one_case([True, True, True, False], 0, 3)
+    run_one_case([True, True, True, False], 1, 4)
+    run_one_case([random.choice([True, False])] * 8 + [True, True], 1, 10)
+    run_one_case([random.choice([True, False])] * 8 + [True, False], 0, 9)
+    run_one_case([random.choice([True, False])] * 8 + [True, False], 1, 10)
+    run_one_case([random.choice([True, False])] * 8 + [False, True], 0, 8)
+    run_one_case([random.choice([True, False])] * 8 + [False, True], 1, 10)
+    run_one_case([random.choice([True, False])] * 8 + [False, False], 0, 8)
+    run_one_case([random.choice([True, False])] * 8 + [False, False], 1, 10)
+
+
 def test_sliding_window_possible_cached_prefix():
     block_size = 2
     sliding_window_spec = SlidingWindowSpec(
```
The diff continues with a test for freeing blocks that fall out of the local attention window:

```diff
@@ -84,6 +160,63 @@ def run_one_case(block_is_cached, expect_length):
     ], 8)
 
 
+def test_chunked_local_attention_remove_skipped_blocks():
+    attention_spec = ChunkedLocalAttentionSpec(
+        block_size=2,
+        num_kv_heads=1,
+        head_size=1,
+        dtype=torch.float32,
+        attention_chunk_size=4,
+        use_mla=False,
+    )
+
+    block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
+
+    manager = get_chunked_local_attention_manager(attention_spec, block_pool)
+
+    null_block_id = block_pool.null_block.block_id
+
+    def id_to_block_table(ids) -> list[KVCacheBlock]:
+        return [
+            KVCacheBlock(id_)
+            if id_ != null_block_id else block_pool.null_block for id_ in ids
+        ]
+
+    def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]):
+        for block, id_ in zip(block_table, ids):
+            if id_ == null_block_id:
+                assert block == block_pool.null_block
+            else:
+                assert block.block_id == id_
+
+    original_block_ids = [
+        1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010
+    ]
+    block_table = id_to_block_table(original_block_ids)
+    manager.req_to_blocks["test"] = block_table
+
+    manager.remove_skipped_blocks("test", 0)
+    assert_block_id(block_table, original_block_ids)
+
+    # 4 tokens are computed. No token is out of the local attention window.
+    manager.remove_skipped_blocks("test", 4)
+    assert_block_id(block_table, original_block_ids)
+
+    # 5 tokens are computed. The window now starts at token 4, so tokens 0-3
+    # are out of it and the leading blocks are freed (block 0 checked here).
+    manager.remove_skipped_blocks("test", 5)
+    assert_block_id(block_table, [null_block_id])
+
+    # 6 tokens are computed. Tokens 4-5 are in the local attention window,
+    # tokens 0-3 are out, so 2 blocks can be removed.
+    manager.remove_skipped_blocks("test", 6)
+    assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:])
+    # 11 tokens are computed. Tokens 8-10 are in the local attention window,
+    # tokens 0-7 are out, so 4 blocks can be removed.
+    manager.remove_skipped_blocks("test", 11)
+    assert_block_id(block_table, [null_block_id] * 4 + original_block_ids[4:])
+
+
 def test_sliding_window_remove_skipped_blocks():
     sliding_window_spec = SlidingWindowSpec(
         block_size=2,
```
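The freed-block counts above come from the window-start arithmetic: with `N` computed tokens, the last computed token (index `N - 1`) only attends inside its own chunk, so every block before `((N - 1) // chunk_size * chunk_size) // block_size` can be returned to the pool. A standalone check of the numbers used in the test (chunk_size=4, block_size=2; not vLLM code):

```python
def first_useful_block(num_computed_tokens, chunk_size=4, block_size=2):
    # Start of the chunk containing the last computed token.
    window_start = (num_computed_tokens - 1) // chunk_size * chunk_size
    return window_start // block_size  # blocks before this index are freed

assert first_useful_block(4) == 0    # window starts at token 0: free nothing
assert first_useful_block(5) == 2    # window starts at token 4: blocks 0-1 freed
assert first_useful_block(11) == 4   # window starts at token 8: blocks 0-3 freed
```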
And a test for block-allocation accounting:

```diff
@@ -172,3 +305,26 @@ def test_get_num_blocks_to_allocate():
                                              cached_blocks_1) == 20
     assert manager.get_num_blocks_to_allocate("2", 20 * block_size,
                                               cached_blocks_2) == 15
+
+
+def test_chunked_local_attention_get_num_blocks_to_allocate():
+    block_size = 2
+    attention_spec = ChunkedLocalAttentionSpec(
+        block_size=block_size,
+        num_kv_heads=1,
+        head_size=1,
+        dtype=torch.float32,
+        attention_chunk_size=4,  # placeholder value, not related to test result
+        use_mla=False,
+    )
+
+    block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
+    manager = get_chunked_local_attention_manager(attention_spec, block_pool)
+    cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
+    cached_blocks_2 = [block_pool.null_block for _ in range(5)
+                       ] + [KVCacheBlock(i + 1) for i in range(5)]
+
+    assert manager.get_num_blocks_to_allocate("1", 20 * block_size,
+                                              cached_blocks_1) == 20
+    assert manager.get_num_blocks_to_allocate("2", 20 * block_size,
+                                              cached_blocks_2) == 15
```
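The expected values follow from simple block accounting. The rule below is my reading of the test rather than a quote of the implementation: a request of `20 * block_size` tokens needs 20 blocks in total; cached blocks that are real but currently unreferenced still have to be reserved, while null blocks (never attended) cost nothing.

```python
# Assumed accounting rule behind the assertions above (not vLLM code):
# blocks_to_allocate = new blocks + cached-but-unreferenced blocks,
# where the shared null block never counts.
def blocks_to_allocate(num_tokens, cached, block_size=2):
    required = -(-num_tokens // block_size)        # cdiv(num_tokens, block_size)
    new_blocks = required - len(cached)            # blocks not cached at all
    chargeable_cached = sum(1 for b in cached if b != "null")
    return new_blocks + chargeable_cached

assert blocks_to_allocate(40, ["real"] * 10) == 20
assert blocks_to_allocate(40, ["null"] * 5 + ["real"] * 5) == 15
```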
Changes to the KV cache configuration utilities:

```diff
@@ -11,7 +11,8 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import GiB_bytes, cdiv, sha256
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
+                                        FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheSpec,
                                         KVCacheTensor, SlidingWindowSpec)
 from vllm.v1.metrics.stats import PrefixCacheStats
@@ -845,6 +846,7 @@ def _get_kv_cache_config_uniform_page_size(
     # full.0, sw.0, sw.2: share a Tensor with size=available_memory//2
     # full.1, sw.1: share another Tensor with size=available_memory//2
     page_size = get_uniform_page_size(kv_cache_spec)
+    # print(f"{page_size=}, {group_size=}")
```

heheda12345 marked a review conversation on the added debug comment as resolved.
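The sharing scheme described in the context comment can be pictured with a toy sketch (illustrative only, not vLLM code; the layer names come from the comment, and the 16 GiB budget is made up): each "column" of layers drawn from different groups shares one tensor, and each shared tensor receives an equal slice of the available memory.

```python
# Toy illustration of the tensor-sharing comment above (not vLLM code).
available_memory = 16 * 2**30           # hypothetical memory budget
columns = [["full.0", "sw.0", "sw.2"],  # each column shares one tensor
           ["full.1", "sw.1"]]
tensor_size = available_memory // len(columns)  # available_memory // 2
layer_to_tensor = {layer: i for i, col in enumerate(columns) for layer in col}

assert layer_to_tensor["sw.2"] == 0     # full.0, sw.0, sw.2 -> tensor 0
assert tensor_size == 8 * 2**30         # each tensor gets half the budget
```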
Changes to the single-type KV cache manager:

```diff
@@ -7,7 +7,8 @@
 from vllm.utils import cdiv
 from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
+from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
+                                        FullAttentionSpec, KVCacheSpec,
                                         SlidingWindowSpec)
 from vllm.v1.request import Request
```
```diff
@@ -384,15 +385,101 @@ def get_num_common_prefix_blocks(self, request_id: str,
         """
         NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
         So it's not correct to count ref_cnt like FullAttentionManager. Return
-        0 here for correctness. Need to support cascade attention + sliding
-        window in the future.
+        0 here for correctness. Need to support cascade attention in the future.
         """
         return 0
 
 
+class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
+
+    def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec,
+                 block_pool: BlockPool, **kwargs) -> None:
+        super().__init__(kv_cache_spec, block_pool, **kwargs)
+        self.attention_chunk_size = kv_cache_spec.attention_chunk_size
+        self._null_block = block_pool.null_block
+
+    @classmethod
+    def find_longest_cache_hit(
+        cls,
+        block_hashes: list[BlockHash],
+        max_length: int,
+        kv_cache_group_ids: list[int],
+        block_pool: BlockPool,
+        kv_cache_spec: KVCacheSpec,
+        use_eagle: bool,
+    ) -> tuple[list[KVCacheBlock], ...]:
+        assert isinstance(kv_cache_spec, ChunkedLocalAttentionSpec), (
+            "ChunkedLocalAttentionManager can only be used for " +
+            "chunked local attention groups")
+        max_num_blocks = max_length // kv_cache_spec.block_size
+        if max_length > 0:
+            local_attention_start_idx = \
+                (max_length - 1) // kv_cache_spec.attention_chunk_size \
+                * kv_cache_spec.attention_chunk_size
+        else:
+            local_attention_start_idx = 0
+        # [block 0, ..., block x (whose start <= the first attended token),
+        #  block x+1, ..., block N (whose end <= max_len), ...]
+        local_attention_start_block_idx = \
+            local_attention_start_idx // kv_cache_spec.block_size
+        # Blocks before the local attention window are never attended,
+        # so they are returned as null blocks.
+        computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
+            [block_pool.null_block] * local_attention_start_block_idx
+            for _ in range(len(kv_cache_group_ids)))
+
+        for i in range(local_attention_start_block_idx, max_num_blocks):
```
**Reviewer (collaborator):** Can you explain the rule of cache hit? For example, with block_size 1 and chunk_size 2, what is the expected result of the following cases?

**Author:** It marks computed blocks = previously unattended blocks + number of hit blocks, so even with zero hits it returns the previously unattended blocks. So for your questions here: a case like [miss, miss][miss, miss][hit, miss] returns 5. I will add more comments to explain.
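A quick standalone check of that reply (block_size=1, chunk_size=2, pattern [miss, miss][miss, miss][hit, miss]; illustrative only, not vLLM code):

```python
blocks = [False, False, False, False, True, False]  # hit pattern, 6 blocks
block_size, chunk_size = 1, 2
max_length = len(blocks) * block_size               # 6 tokens
# Blocks before the chunk of the last token count as computed (null blocks).
start = ((max_length - 1) // chunk_size * chunk_size) // block_size  # 4
n = start
for hit in blocks[start:]:
    if not hit:
        break
    n += 1
assert n == 5  # 4 unattended blocks + 1 cache hit
```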
The diff continues:

```diff
+            block_hash = block_hashes[i]
+            if cached_block := block_pool.get_cached_block(
+                    block_hash, kv_cache_group_ids):
+                for computed, cached in zip(computed_blocks, cached_block):
+                    computed.append(cached)
+            else:
+                break
+        if use_eagle and computed_blocks[0]:
+            for computed in computed_blocks:
+                computed.pop()
+        return computed_blocks
+
+    def remove_skipped_blocks(self, request_id: str,
+                              num_computed_tokens: int) -> None:
+        # Remove the blocks that are no longer in the chunked attention
+        # window and are skipped during the attention computation.
+        # Window start: (N - 1) // chunk_size * chunk_size
+        # [chunk 0][chunk 1] local_attention_start_idx ... current token
+        local_attention_start_idx = (
+            num_computed_tokens -
+            1) // self.attention_chunk_size * self.attention_chunk_size
+        # e.g. with chunk_size 1024: 1024 -> 0, 1025 -> 1024
+        first_useful_block_idx = local_attention_start_idx // self.block_size
+        # block size =128, 0 -> block 0, 1024 -> block 8, 372 -> block 2
```
Suggested change (review):

```diff
-        # block size =128, 0 -> block 0, 1024 -> block 8, 372 -> block 2
+        # if block size = 128, 0 -> block 0, 1024 -> block 8, 372 -> block 2
```

A reviewer also asked, on a now-outdated revision: why do you need these block ids?
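To summarize the distinction this PR encodes between the two specialized managers, here is a hedged sketch (a hypothetical helper, not part of the PR) of which tokens stay visible under each policy, assuming a sliding window of size `w` covers the last `w` tokens including the current one:

```python
def visible_range(token_idx, *, sliding_window=None, chunk_size=None):
    """First and last token index a query token can attend to."""
    if chunk_size is not None:           # chunked local attention
        return (token_idx // chunk_size * chunk_size, token_idx)
    assert sliding_window is not None    # sliding window attention
    return (max(0, token_idx - sliding_window + 1), token_idx)

assert visible_range(5, chunk_size=4) == (4, 5)      # window resets per chunk
assert visible_range(5, sliding_window=4) == (2, 5)  # fixed look-back span
```

This is why `remove_skipped_blocks` differs between the two managers: the chunked local window jumps forward a whole chunk at a time, while the sliding window advances one token at a time.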