Merged
Commits
25 commits
7913145
support local chunked attention for hybrid
luccafong Jun 2, 2025
38f90b5
revert gemma change
luccafong Jun 12, 2025
78385bc
fix and add comment
Jun 16, 2025
320ab71
add attention_chunk_size in full attention spec
Jun 16, 2025
1bbff13
Merge remote-tracking branch 'origin/main' into hybrid-local-attn
Jun 16, 2025
076ab27
Merge remote-tracking branch 'origin/main' into hybrid-local-attn
Jun 16, 2025
8b7b409
fix use_cascade_attention signature
Jun 16, 2025
f2887f6
add assertions and merge impl
Jun 18, 2025
f7b6961
fix issue of local attention start idx based on num computed tokens
Jun 18, 2025
f60bea5
Merge remote-tracking branch 'origin/main' into hybrid-local-attn
Jun 19, 2025
2c4c8c4
address comments
Jun 24, 2025
a234791
Merge remote-tracking branch 'origin/main' into hybrid-local-attn
Jun 24, 2025
472cb24
fix format
Jun 24, 2025
58f82e2
Merge remote-tracking branch 'main/main' into hybrid-local-attn
luccafong Jul 11, 2025
23c7fba
fix the block skipping issue
luccafong Jul 14, 2025
5a0ccdf
adding example and comment
luccafong Jul 14, 2025
5d83152
remove prints
luccafong Jul 14, 2025
fcdee50
Merge remote-tracking branch 'origin/main' into hybrid-local-attn
luccafong Jul 14, 2025
4655102
fix format
luccafong Jul 14, 2025
326c0e5
disable hybrid kv cache for eagle + chunked local attention
luccafong Jul 16, 2025
ad1169c
Merge remote-tracking branch 'vllm/main' into hybrid-local-attn
luccafong Jul 16, 2025
427a219
fix the ci test
luccafong Jul 17, 2025
07f3353
check if local attention through backend
luccafong Jul 17, 2025
1b41951
fix to use impl
luccafong Jul 17, 2025
d48fb3c
Merge remote-tracking branch 'vllm/main' into hybrid-local-attn
luccafong Jul 18, 2025
160 changes: 158 additions & 2 deletions tests/v1/core/test_specialized_manager.py
@@ -1,13 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import random

import torch

from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
KVCacheBlock)
from vllm.v1.core.single_type_kv_cache_manager import SlidingWindowManager
from vllm.v1.kv_cache_interface import SlidingWindowSpec
from vllm.v1.core.single_type_kv_cache_manager import (
ChunkedLocalAttentionManager, SlidingWindowManager)
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
SlidingWindowSpec)


def get_sliding_window_manager(sliding_window_spec, block_pool):
@@ -17,6 +21,78 @@ def get_sliding_window_manager(sliding_window_spec, block_pool):
kv_cache_group_id=0)


def get_chunked_local_attention_manager(chunked_local_attention_spec,
block_pool):
return ChunkedLocalAttentionManager(chunked_local_attention_spec,
block_pool,
caching_hash_fn=lambda x: x,
kv_cache_group_id=0)


def test_chunked_local_attention_possible_cached_prefix():
block_size = 2
chunked_local_attention_spec = ChunkedLocalAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
attention_chunk_size=4,
use_mla=False,
)

block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
manager = get_chunked_local_attention_manager(chunked_local_attention_spec,
block_pool)

def run_one_case(block_is_cached, tail_token, expect_length):
block_hash_list = [
BlockHash(i, ()) for i in range(len(block_is_cached))
]

block_pool.cached_block_hash_to_block.clear()

# Mock the block pool with the cached blocks
for i, (block_hash,
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
if is_cached:
block_pool.cached_block_hash_to_block[BlockHashWithGroupId(
block_hash, 0)] = {
i: block_pool.blocks[i + 10],
}

computed_blocks = manager.find_longest_cache_hit(
block_hashes=block_hash_list,
max_length=len(block_hash_list) * block_size + tail_token,
kv_cache_group_ids=[0],
block_pool=block_pool,
kv_cache_spec=chunked_local_attention_spec,
use_eagle=False)[0]
assert len(computed_blocks) == expect_length

assert all(block == block_pool.null_block
for block in computed_blocks[:(expect_length - 1) // 2])

run_one_case([True], 0, 1)
run_one_case([True], 1, 1)
run_one_case([True, False], 0, 1)
run_one_case([True, False], 1, 2)
run_one_case([True, True], 0, 2)
run_one_case([True, True], 1, 2)
run_one_case([True, True, False], 0, 2)
run_one_case([True, True, False], 1, 2)
run_one_case([True, True, True], 0, 3)
run_one_case([True, True, True], 1, 3)
run_one_case([True, True, True, False], 0, 3)
run_one_case([True, True, True, False], 1, 4)
run_one_case([random.choice([True, False])] * 8 + [True, True], 1, 10)
run_one_case([random.choice([True, False])] * 8 + [True, False], 0, 9)
run_one_case([random.choice([True, False])] * 8 + [True, False], 1, 10)
run_one_case([random.choice([True, False])] * 8 + [False, True], 0, 8)
run_one_case([random.choice([True, False])] * 8 + [False, True], 1, 10)
run_one_case([random.choice([True, False])] * 8 + [False, False], 0, 8)
run_one_case([random.choice([True, False])] * 8 + [False, False], 1, 10)


def test_sliding_window_possible_cached_prefix():
block_size = 2
sliding_window_spec = SlidingWindowSpec(
@@ -84,6 +160,63 @@ def run_one_case(block_is_cached, expect_length):
], 8)


def test_chunked_local_attention_remove_skipped_blocks():
attention_spec = ChunkedLocalAttentionSpec(
block_size=2,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
attention_chunk_size=4,
use_mla=False,
)

block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)

manager = get_chunked_local_attention_manager(attention_spec, block_pool)

null_block_id = block_pool.null_block.block_id

def id_to_block_table(ids) -> list[KVCacheBlock]:
return [
KVCacheBlock(id_)
if id_ != null_block_id else block_pool.null_block for id_ in ids
]

def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]):
for block, id_ in zip(block_table, ids):
if id_ == null_block_id:
assert block == block_pool.null_block
else:
assert block.block_id == id_

original_block_ids = [
1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010
]
block_table = id_to_block_table(original_block_ids)
manager.req_to_blocks["test"] = block_table

manager.remove_skipped_blocks("test", 0)
assert_block_id(block_table, original_block_ids)

# 4 tokens are computed. no token is out of the local attention window.
manager.remove_skipped_blocks("test", 4)
assert_block_id(block_table, original_block_ids)
Reviewer: In this test, token 4 doesn't need the KV cache of tokens [0-3], so why keep them?

Author: Token 4 (if 1-indexed) needs the KV cache of tokens [0-4].

# 5 tokens are computed. tokens 0 - 3 are out of the local attention window,
# so their blocks can be freed.
manager.remove_skipped_blocks("test", 5)
assert_block_id(block_table, [null_block_id])

# 6 tokens are computed. tokens 4 - 5 are in the local attention window,
# tokens 0 - 3 are out, so 2 blocks can be removed.
manager.remove_skipped_blocks("test", 6)
assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:])
# 11 tokens are computed. tokens 8 - 10 are in the local attention window,
# tokens 0 - 7 are out, so 4 blocks can be removed.
manager.remove_skipped_blocks("test", 11)
assert_block_id(block_table, [null_block_id] * 4 + original_block_ids[4:])


def test_sliding_window_remove_skipped_blocks():
sliding_window_spec = SlidingWindowSpec(
block_size=2,
@@ -172,3 +305,26 @@ def test_get_num_blocks_to_allocate():
cached_blocks_1) == 20
assert manager.get_num_blocks_to_allocate("2", 20 * block_size,
cached_blocks_2) == 15


def test_chunked_local_attention_get_num_blocks_to_allocate():
block_size = 2
attention_spec = ChunkedLocalAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
attention_chunk_size=4, # Placeholder value, not related to test result
use_mla=False,
)

block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
manager = get_chunked_local_attention_manager(attention_spec, block_pool)
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
cached_blocks_2 = [block_pool.null_block for _ in range(5)
] + [KVCacheBlock(i + 1) for i in range(5)]

assert manager.get_num_blocks_to_allocate("1", 20 * block_size,
cached_blocks_1) == 20
assert manager.get_num_blocks_to_allocate("2", 20 * block_size,
cached_blocks_2) == 15
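
As a quick cross-check of the expectations in test_chunked_local_attention_remove_skipped_blocks above, here is a minimal standalone sketch of the window arithmetic (not part of the diff). The helper name first_useful_block is invented for illustration; it mirrors the formula that ChunkedLocalAttentionManager.remove_skipped_blocks uses further down in this PR, with the test's block_size=2 and attention_chunk_size=4.

def first_useful_block(num_computed_tokens: int,
                       attention_chunk_size: int,
                       block_size: int) -> int:
    """Index of the first block still needed by chunked local attention."""
    if num_computed_tokens == 0:
        return 0
    # The last computed token has index num_computed_tokens - 1; its local
    # window starts at the beginning of the chunk containing that index.
    local_attention_start_idx = ((num_computed_tokens - 1)
                                 // attention_chunk_size
                                 * attention_chunk_size)
    return local_attention_start_idx // block_size

# Blocks before the returned index can be swapped for the null block.
assert first_useful_block(4, attention_chunk_size=4, block_size=2) == 0
assert first_useful_block(6, attention_chunk_size=4, block_size=2) == 2
assert first_useful_block(11, attention_chunk_size=4, block_size=2) == 4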
1 change: 1 addition & 0 deletions vllm/attention/layer.py
@@ -140,6 +140,7 @@ def __init__(
kv_sharing_target_layer_name, **extra_impl_args)
self.backend = backend_name_to_enum(attn_backend.get_name())
self.dtype = dtype
self.use_irope = extra_impl_args.get("use_irope", False)

# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# torch.compile works by registering the attention as one giant
1 change: 1 addition & 0 deletions vllm/v1/attention/backends/flash_attn.py
@@ -715,6 +715,7 @@ def use_cascade_attention(
num_kv_heads: int,
use_alibi: bool,
use_sliding_window: bool,
use_local_attention: bool,
num_sms: int,
) -> bool:
"""Decide whether to use cascade attention.
19 changes: 16 additions & 3 deletions vllm/v1/core/kv_cache_utils.py
@@ -11,7 +11,8 @@
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, cdiv, sha256
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
KVCacheTensor, SlidingWindowSpec)
from vllm.v1.metrics.stats import PrefixCacheStats
@@ -845,6 +846,7 @@ def _get_kv_cache_config_uniform_page_size(
# full.0, sw.0, sw.2: share a Tensor with size=available_memory//2
# full.1, sw.1: share another Tensor with size=available_memory//2
page_size = get_uniform_page_size(kv_cache_spec)
# print(f"{page_size=}, {group_size=}")
Reviewer (suggested change): delete the commented-out debug print above.

num_blocks = get_num_blocks(vllm_config, group_size, available_memory,
page_size)
per_memory_pool_size = page_size * num_blocks
@@ -904,7 +906,11 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values())
has_sliding_window = any(
isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values())
if has_full_attention and has_sliding_window:
has_chunked_local_attention = any(
isinstance(spec, ChunkedLocalAttentionSpec)
for spec in kv_cache_spec.values())
if has_full_attention and (has_sliding_window
or has_chunked_local_attention):
for layer_name, spec in kv_cache_spec.items():
if isinstance(spec, SlidingWindowSpec):
kv_cache_spec[layer_name] = FullAttentionSpec(
@@ -915,6 +921,14 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
use_mla=spec.use_mla,
sliding_window=spec.sliding_window,
)
elif isinstance(spec, ChunkedLocalAttentionSpec):
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=spec.block_size,
num_kv_heads=spec.num_kv_heads,
head_size=spec.head_size,
dtype=spec.dtype,
use_mla=spec.use_mla,
)

if is_hybrid(kv_cache_spec):
raise ValueError("Hybrid KV cache manager is disabled but failed to "
@@ -938,7 +952,6 @@ def get_kv_cache_config(
The generated KVCacheConfigs
"""
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)

if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
unify_hybrid_kv_cache_specs(kv_cache_spec)

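To illustrate the fallback added above (a sketch only, not the PR's code path): when the hybrid KV cache manager is disabled, a chunked-local layer's spec is downgraded to a plain full-attention spec so every layer shares one KV cache layout. The sizes and dtype below are made-up example values.

import torch

from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
                                         FullAttentionSpec)

chunked_spec = ChunkedLocalAttentionSpec(block_size=16,
                                         num_kv_heads=8,
                                         head_size=128,
                                         dtype=torch.bfloat16,
                                         attention_chunk_size=8192,
                                         use_mla=False)

# Mirrors the new elif branch: drop attention_chunk_size and keep the rest.
unified_spec = FullAttentionSpec(block_size=chunked_spec.block_size,
                                 num_kv_heads=chunked_spec.num_kv_heads,
                                 head_size=chunked_spec.head_size,
                                 dtype=chunked_spec.dtype,
                                 use_mla=chunked_spec.use_mla)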
93 changes: 90 additions & 3 deletions vllm/v1/core/single_type_kv_cache_manager.py
@@ -7,7 +7,8 @@
from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
FullAttentionSpec, KVCacheSpec,
SlidingWindowSpec)
from vllm.v1.request import Request

@@ -384,15 +385,101 @@ def get_num_common_prefix_blocks(self, request_id: str,
"""
NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
So it's not correct to count ref_cnt like FullAttentionManager. Return
0 here for correctness. Need to support cascade attention + sliding
window in the future.
0 here for correctness. Need to support cascade attention in the future.
"""
return 0


class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):

def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec,
block_pool: BlockPool, **kwargs) -> None:
super().__init__(kv_cache_spec, block_pool, **kwargs)
self.attention_chunk_size = kv_cache_spec.attention_chunk_size
self._null_block = block_pool.null_block

@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, ChunkedLocalAttentionSpec), (
"ChunkedLocalAttentionManager can only be used for " +
"chunked local attentiongroups")
max_num_blocks = max_length // kv_cache_spec.block_size
if max_length > 0:
local_attention_start_idx = \
(max_length-1) // kv_cache_spec.attention_chunk_size \
Reviewer: Why -1 here?

Author: We need the actual token index, not the length, to compute the attending window. For example, with a max length of 128 and chunk size 64, the context is chunked as [0, 63] and [64, 127]; token 127 should attend to the window [64, 127], whose start index is 64 = 127 // 64 * 64, not 2 * 64 = 128.

* kv_cache_spec.attention_chunk_size
else:
local_attention_start_idx = 0
# [ block 0, ..., block x(x_start<=first_attention_token),
# block x+1, .., block N (N_end <=max_len), ...]
Reviewer: Why do you need this comment? What is x for?

local_attention_start_block_idx = \
local_attention_start_idx // kv_cache_spec.block_size
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[block_pool.null_block] * local_attention_start_block_idx
for _ in range(len(kv_cache_group_ids)))

for i in range(local_attention_start_block_idx, max_num_blocks):
Reviewer: Can you explain the cache-hit rule? For example, with block_size 1 and chunk_size 2, what is the expected result for the following cases?

  1. [miss miss] [miss miss] [miss miss]. Should it be 0 or 6?
  2. [miss miss] [hit miss] [miss miss]. Should it be 3 or 6?

Please also add comments describing the expected behavior.

Author: For the current token, we check for cache hits starting from the first block that contains its attention window, stopping at the first miss. The computed blocks are the preceding unattended blocks plus the number of hit blocks, so even with zero hits we still return the unattended prefix. For your cases:

  1. It returns 4, since the last window missed.
  2. Still 4, since the last window missed.

For a case like [miss miss] [miss miss] [hit miss] it returns 5. I will add more comments to explain.

(A standalone sketch of this counting rule appears after this file's diff.)

block_hash = block_hashes[i]
if cached_block := block_pool.get_cached_block(
block_hash, kv_cache_group_ids):
for computed, cached in zip(computed_blocks, cached_block):
computed.append(cached)
else:
break
if use_eagle and computed_blocks[0]:
for computed in computed_blocks:
computed.pop()
Reviewer: With eagle, we can't simply pop the last block. For example, with chunk size 2 and block size 1: [miss miss] [miss miss] -> cache_hit_length 4. If we remove the 3rd block (0-indexed), the cache_hit_length becomes 3, but [miss miss] [miss] is not a valid cache-hit prefix. I think we should return cache_hit_length 2 in this case.

return computed_blocks

def remove_skipped_blocks(self, request_id: str,
num_computed_tokens: int) -> None:
# Remove the blocks that are no longer in the chunked attention window
# and are skipped during the attention computation.
Reviewer: Can you wrap the comment to ~80 characters per line?


# (N-1) // chunk_size * chunk_size
# [chunk 0][chunk 1]local_attention_start_idx ... current

local_attention_start_idx = (
num_computed_tokens -
1) // self.attention_chunk_size * self.attention_chunk_size
# 1024-> 0, 1025-> 1024
Reviewer: Why 1024 -> 0? Does the attention of the 1024th token (the first token of the next chunk) need tokens 0-1023?

Author: Here num_computed_tokens = 1024, so the last computed token has index 1023, whose local attention window starts from 0.

Reviewer: Can you update the comment?

first_useful_block_idx = local_attention_start_idx // self.block_size
# block size =128, 0 -> block 0, 1024 -> block 8, 372 -> block 2
Reviewer (suggested change): reword the comment above to
# if block size = 128, 0 -> block 0, 1024 -> block 8, 372 -> block 2

blocks = self.req_to_blocks[request_id]
removed_blocks: list[KVCacheBlock] = []
blockids = []
Reviewer: Why do you need this blockids list?

for i in range(first_useful_block_idx - 1, -1, -1):
if blocks[i] == self._null_block:
# If the block is already a null block, the blocks before it
# should also have been set to null blocks by the previous calls
# to this function.
break
removed_blocks.append(blocks[i])
blockids.append(i)
blocks[i] = self._null_block
self.block_pool.free_blocks(removed_blocks)

def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int:
"""
cascade attention is not supported by chunked local attention.
"""
return 0


spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
}
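
The cache-hit counting rule discussed in the review thread above can be reproduced with a small standalone sketch (assumptions: block_size 1, attention_chunk_size 2, six tokens; the helper count_computed_blocks is invented for illustration and is not the PR's implementation):

def count_computed_blocks(block_is_cached: list[bool],
                          max_length: int,
                          attention_chunk_size: int,
                          block_size: int) -> int:
    """Unattended prefix blocks plus the run of cache hits that starts at
    the first block of the current token's local attention window."""
    if max_length > 0:
        # Index of the last token, rounded down to its chunk boundary.
        local_attention_start_idx = ((max_length - 1)
                                     // attention_chunk_size
                                     * attention_chunk_size)
    else:
        local_attention_start_idx = 0
    start_block_idx = local_attention_start_idx // block_size
    computed = start_block_idx  # blocks before the window count as null blocks
    for hit in block_is_cached[start_block_idx:]:
        if not hit:
            break
        computed += 1
    return computed

# The reviewer's examples, block_size=1 and chunk_size=2:
assert count_computed_blocks([False] * 6, 6, 2, 1) == 4
assert count_computed_blocks([False, False, True, False, False, False], 6, 2, 1) == 4
assert count_computed_blocks([False, False, False, False, True, False], 6, 2, 1) == 5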

