Commits (39)
c6a2d25  copy manager code (heheda12345, Apr 1, 2025)
4b27c82  save (heheda12345, Apr 5, 2025)
4dea38d  can run (heheda12345, Apr 5, 2025)
55720e0  can pass e2e tests (heheda12345, Apr 5, 2025)
273dd44  run precommit (heheda12345, Apr 5, 2025)
0bfec8d  can run again (heheda12345, Apr 5, 2025)
df31d7a  quick copy (heheda12345, Apr 6, 2025)
1ce3023  Merge branch 'main' of github.com:vllm-project/vllm into hybrid_mem (heheda12345, Apr 23, 2025)
6aee98d  a runable version (heheda12345, Apr 23, 2025)
7f19466  fix bug (heheda12345, Apr 23, 2025)
34ba571  1 hash per block_size (heheda12345, Apr 23, 2025)
18245e3  one manager for each type (heheda12345, Apr 24, 2025)
2c81fe6  small update (heheda12345, Apr 24, 2025)
42a8244  small fix (heheda12345, Apr 24, 2025)
8af9ace  Merge branch 'main' of github.com:vllm-project/vllm into hybrid_mem (heheda12345, Apr 25, 2025)
6493e5e  fix gemma (heheda12345, Apr 25, 2025)
fa224f2  Merge branch 'fix_gemma' of github.com:heheda12345/vllm into hybrid_mem (heheda12345, Apr 25, 2025)
c512bc5  update attn backends (heheda12345, Apr 25, 2025)
4ce3424  fix flashinfer type (heheda12345, Apr 25, 2025)
d17843e  fix flashmla type (heheda12345, Apr 25, 2025)
47ec1a7  fix triton type (heheda12345, Apr 25, 2025)
840675f  clean up slidingwindowspec (heheda12345, Apr 25, 2025)
ffcbde8  clean up block table (heheda12345, Apr 25, 2025)
4eebce7  clean up runner (WIP) (heheda12345, Apr 25, 2025)
4380fa6  add notes (heheda12345, Apr 25, 2025)
b567c56  clean up attn_metadata read in runner (heheda12345, Apr 25, 2025)
710d68e  reorder attn args (heheda12345, Apr 25, 2025)
2b8ffc4  rename max_num_tokens (heheda12345, Apr 25, 2025)
b50aa14  remove fixed TODO (heheda12345, Apr 25, 2025)
136a54c  fix (heheda12345, Apr 25, 2025)
765d9ed  add note (heheda12345, Apr 25, 2025)
216a079  group partition strategy (heheda12345, Apr 26, 2025)
1c66541  Merge branch 'main' of github.com:vllm-project/vllm into hybrid_mem (heheda12345, Apr 26, 2025)
37c4494  support eagle (heheda12345, Apr 26, 2025)
84280fc  get a specific type of layer from forward context (heheda12345, Apr 26, 2025)
51ffeb6  fix (heheda12345, Apr 26, 2025)
7d03c88  Merge branch 'filter_fwd_ctx' of github.com:heheda12345/vllm into hyb… (heheda12345, Apr 26, 2025)
1da28d9  update eagle (heheda12345, Apr 26, 2025)
e5cb02e  only enable cuda platform (heheda12345, Apr 26, 2025)
8 changes: 8 additions & 0 deletions vllm/attention/layer.py
@@ -206,6 +206,8 @@ def forward(
         if self.use_direct_call:
             forward_context: ForwardContext = get_forward_context()
             attn_metadata = forward_context.attn_metadata
+            if isinstance(attn_metadata, dict):
+                attn_metadata = attn_metadata[self.layer_name]
             self_kv_cache = self.kv_cache[forward_context.virtual_engine]
             self.impl.forward(self,
                               query,
@@ -222,6 +224,8 @@ def forward(
         if self.use_direct_call:
             forward_context = get_forward_context()
             attn_metadata = forward_context.attn_metadata
+            if isinstance(attn_metadata, dict):
+                attn_metadata = attn_metadata[self.layer_name]
             self_kv_cache = self.kv_cache[forward_context.virtual_engine]
             return self.impl.forward(self, query, key, value,
                                      self_kv_cache, attn_metadata)
@@ -337,6 +341,8 @@ def unified_attention(
 ) -> torch.Tensor:
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
     self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
     return self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
@@ -369,6 +375,8 @@ def unified_attention_with_output(
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
     self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(self,
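All four hunks above repeat the same lookup: when `attn_metadata` is a dict (the v1 per-layer form), the layer indexes it by its own name; otherwise the single v0 object is used as-is. A minimal standalone sketch of the pattern, with an illustrative `resolve_attn_metadata` helper that is not part of the diff:

```python
def resolve_attn_metadata(attn_metadata, layer_name: str):
    """Return the metadata this layer should use.

    v0 passes one shared metadata object; v1 passes a dict keyed by the
    layer's registered name, as in the hunks above.
    """
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    return attn_metadata


# v0-style: a single shared object is returned untouched.
assert resolve_attn_metadata("shared", "model.layers.0.self_attn.attn") == "shared"

# v1-style: each layer gets the entry built for its KV cache group.
per_layer = {"model.layers.0.self_attn.attn": "layer-0 metadata"}
assert resolve_attn_metadata(per_layer, "model.layers.0.self_attn.attn") == "layer-0 metadata"
```

Keeping the `isinstance` check at each call site lets v0 code paths, which still pass a single metadata object, run unchanged.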
4 changes: 4 additions & 0 deletions vllm/config.py
@@ -1119,6 +1119,8 @@ class CacheConfig:
         sliding_window: Sliding window size for the KV cache.
         enable_prefix_caching: Whether to enable prefix caching.
         cpu_offload_gb: Size of the CPU offload buffer in GiB.
+        disable_hybrid_allocator: Whether to disable the hybrid allocator (Only
+            affects v1).
     """

     def compute_hash(self) -> str:
@@ -1153,6 +1155,7 @@ def __init__(
         prefix_caching_hash_algo: str = "builtin",
         cpu_offload_gb: float = 0,
         calculate_kv_scales: Optional[bool] = None,
+        disable_hybrid_allocator: bool = False,
     ) -> None:
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
@@ -1165,6 +1168,7 @@ def __init__(
         self.prefix_caching_hash_algo = prefix_caching_hash_algo
         self.cpu_offload_gb = cpu_offload_gb
         self.calculate_kv_scales = calculate_kv_scales
+        self.disable_hybrid_allocator = disable_hybrid_allocator
         self._verify_args()
         self._verify_cache_dtype()
         self._verify_prefix_caching()
8 changes: 8 additions & 0 deletions vllm/engine/arg_utils.py
@@ -205,6 +205,7 @@ class EngineArgs:
     model_impl: str = "auto"

     calculate_kv_scales: Optional[bool] = None
+    disable_hybrid_allocator: bool = False

     additional_config: Optional[Dict[str, Any]] = None
     enable_reasoning: Optional[bool] = None
@@ -964,6 +965,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
            'be loaded from the model checkpoint if available. '
            'Otherwise, the scales will default to 1.0.')

+        parser.add_argument(
+            "--disable-hybrid-allocator",
+            action="store_true",
+            default=False,
+            help="Disable the hybrid allocator. This only affects v1.")
+
         parser.add_argument(
             "--additional-config",
             type=json.loads,
@@ -1173,6 +1180,7 @@ def create_engine_config(
             prefix_caching_hash_algo=self.prefix_caching_hash_algo,
             cpu_offload_gb=self.cpu_offload_gb,
             calculate_kv_scales=self.calculate_kv_scales,
+            disable_hybrid_allocator=self.disable_hybrid_allocator,
         )

         # Get the current placement group if Ray is initialized and
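For reference, a hedged sketch of using the new flag; the flag itself is defined in the `add_cli_args` hunk above, while the `vllm serve` invocation and the model name are only illustrative:

```python
# Assumed CLI usage (model name is just an example):
#   vllm serve meta-llama/Llama-3.1-8B-Instruct --disable-hybrid-allocator

from vllm.engine.arg_utils import EngineArgs

# Programmatic equivalent: the new dataclass field is forwarded to
# CacheConfig.disable_hybrid_allocator by create_engine_config()
# (third hunk in this file).
engine_args = EngineArgs(
    model="meta-llama/Llama-3.1-8B-Instruct",
    disable_hybrid_allocator=True,
)
```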
11 changes: 8 additions & 3 deletions vllm/forward_context.py
@@ -4,7 +4,7 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, Union

 import torch
 import torch.distributed as dist
@@ -34,8 +34,13 @@ class DPMetadata:
 class ForwardContext:
     # copy from vllm_config.compilation_config.static_forward_context
     no_compile_layers: dict[str, Any]
-    # TODO: extend to support per-layer dynamic forward context
-    attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
+    """
+    Type AttentionMetadata for v0,
+    Type Dict[str, AttentionMetadata] for v1, mapping from layer_name to
+    AttentionMetadata of that layer
+    set dynamically for each forward pass
+    """
+    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]
     # TODO: remove after making all virtual_engines share the same kv cache
     virtual_engine: int  # set dynamically for each forward pass
     # set dynamically for each forward pass
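A hedged sketch of how a v1 runner could populate this field, assuming one metadata object per KV cache group shared by every layer in that group; the `group` object and its attributes are illustrative, not the runner's real code, and the `build()` keywords mirror the builder signature updated in `flash_attn.py` below:

```python
from typing import Any


def build_per_layer_attn_metadata(kv_cache_groups: list[Any],
                                  num_reqs: int,
                                  num_actual_tokens: int,
                                  max_query_len: int) -> dict[str, Any]:
    """Map every attention layer name to the metadata built for its group.

    Layers in the same KV cache group (e.g. all full-attention layers, or
    all sliding-window layers) share one metadata object, so the dict
    typically holds many keys pointing at a few distinct values.
    """
    attn_metadata: dict[str, Any] = {}
    for group in kv_cache_groups:
        # Hypothetical group attributes: one builder and one block table per group.
        metadata = group.metadata_builder.build(
            num_reqs=num_reqs,
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            common_prefix_len=0,
            block_table=group.block_table,
        )
        for layer_name in group.layer_names:
            attn_metadata[layer_name] = metadata
    return attn_metadata
```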
14 changes: 8 additions & 6 deletions vllm/v1/attention/backends/flash_attn.py
@@ -14,6 +14,8 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
+from vllm.v1.kv_cache_interface import KVCacheSpec
+from vllm.v1.worker.block_table import BlockTable
 from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
                                            get_flash_attn_version)

@@ -99,23 +101,23 @@ class FlashAttentionMetadata:

 class FlashAttentionMetadataBuilder:

-    def __init__(self, runner: "GPUModelRunner"):
+    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: KVCacheSpec,
+                 persistent_block_table: BlockTable):
         self.runner = runner

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False

     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int):
+              common_prefix_len: int, block_table: BlockTable):
         max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
         query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
             self.runner.device, non_blocking=True)
         seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device,
                                                           non_blocking=True)
-        block_table = (
-            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
+        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
+        slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].to(
             self.runner.device, non_blocking=True).long()

         use_cascade = common_prefix_len > 0
@@ -142,7 +144,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
-            block_table=block_table,
+            block_table=block_table_tensor,
             slot_mapping=slot_mapping,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
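In `build()`, the block table tensor and the slot mapping are now taken from the `BlockTable` argument instead of from the runner. A minimal stand-in that models only the two members used above (`get_device_tensor()` and `slot_mapping_cpu`); the constructor parameters and the private field name are assumptions, not vLLM's real `BlockTable`:

```python
import torch


class MiniBlockTable:
    """Toy stand-in for vllm.v1.worker.block_table.BlockTable."""

    def __init__(self, max_num_reqs: int, max_num_blocks_per_req: int,
                 max_num_tokens: int, device: torch.device):
        # Device-resident block table: one row of block ids per request.
        self._device_tensor = torch.zeros(max_num_reqs,
                                          max_num_blocks_per_req,
                                          dtype=torch.int32,
                                          device=device)
        # CPU-side slot mapping, one entry per scheduled token.
        self.slot_mapping_cpu = torch.zeros(max_num_tokens, dtype=torch.int64)

    def get_device_tensor(self) -> torch.Tensor:
        return self._device_tensor


# Mirrors the slicing in the build() hunk above.
bt = MiniBlockTable(max_num_reqs=8, max_num_blocks_per_req=16,
                    max_num_tokens=1024, device=torch.device("cpu"))
num_reqs, num_actual_tokens = 2, 7
block_table_tensor = bt.get_device_tensor()[:num_reqs]
slot_mapping = bt.slot_mapping_cpu[:num_actual_tokens].long()
```

Passing the table into `build()` rather than reading `self.runner.input_batch.block_table` appears to be what lets the runner keep one persistent `BlockTable` per KV cache group and hand each builder its own.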
16 changes: 11 additions & 5 deletions vllm/v1/attention/backends/mla/common.py
@@ -203,6 +203,8 @@
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
+from vllm.v1.kv_cache_interface import KVCacheSpec
+from vllm.v1.worker.block_table import BlockTable
 from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version

 try:
@@ -342,6 +344,8 @@ class MLACommonMetadataBuilder(Generic[M]):

     def __init__(self,
                  runner: "GPUModelRunner",
+                 kv_cache_spec: KVCacheSpec,
+                 persistent_block_table: BlockTable,
                  metadata_cls: Optional[type[M]] = None):
         self.metadata_cls = metadata_cls \
             if metadata_cls is not None else MLACommonMetadata
@@ -375,7 +379,8 @@ def __init__(self,
             dtype=model_config.dtype,
             device=runner.device,
         )
-        self.page_size = self.runner.block_size
+        self.page_size = kv_cache_spec.block_size
+        self.persistent_block_table = persistent_block_table

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -455,12 +460,13 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         # function. We should avoid GPU -> CPU sync as much as possible because
         # it blocks on all previous kernels.
         device = self.runner.device
-        block_table = (
-            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
+        block_table = (self.persistent_block_table.block_table.
+                       get_device_tensor()[:num_reqs])
         query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
             device, non_blocking=True)
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
-            device, non_blocking=True).long()
+        slot_mapping = (self.persistent_block_table.
+                        slot_mapping_cpu[:num_actual_tokens].to(
+                            device, non_blocking=True).long())
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()
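One detail specific to the MLA builder: the page size now comes from the group's `KVCacheSpec` (`kv_cache_spec.block_size`) instead of the runner's single `block_size`, so each builder reads the block size of its own KV cache group. A tiny hedged sketch with illustrative dataclasses (not the real `KVCacheSpec` hierarchy):

```python
from dataclasses import dataclass


@dataclass
class ToyKVCacheSpec:
    block_size: int


@dataclass
class ToyMLABuilder:
    kv_cache_spec: ToyKVCacheSpec

    @property
    def page_size(self) -> int:
        # Mirrors `self.page_size = kv_cache_spec.block_size` in the hunk above.
        return self.kv_cache_spec.block_size


full_attn_builder = ToyMLABuilder(ToyKVCacheSpec(block_size=16))
sliding_window_builder = ToyMLABuilder(ToyKVCacheSpec(block_size=32))
assert (full_attn_builder.page_size, sliding_window_builder.page_size) == (16, 32)
```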
7 changes: 5 additions & 2 deletions vllm/v1/attention/backends/mla/flashmla.py
@@ -16,6 +16,8 @@
                                                    MLACommonImpl,
                                                    MLACommonMetadata,
                                                    MLACommonMetadataBuilder)
+from vllm.v1.kv_cache_interface import KVCacheSpec
+from vllm.v1.worker.block_table import BlockTable

 logger = init_logger(__name__)

@@ -52,8 +54,9 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):

 class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):

-    def __init__(self, runner):
-        super().__init__(runner)
+    def __init__(self, runner, kv_cache_spec: KVCacheSpec,
+                 persistent_block_table: BlockTable):
+        super().__init__(runner, kv_cache_spec, persistent_block_table)

         self.num_q_heads = self.runner.model_config.get_num_attention_heads(
             self.runner.parallel_config)
6 changes: 5 additions & 1 deletion vllm/v1/core/block_pool.py
@@ -82,6 +82,7 @@ def cache_full_blocks(
         num_full_blocks: int,
         block_size: int,
         hash_fn: Callable,
+        kv_cache_group_id: int = -1,
     ) -> None:
         """Cache a list of full blocks for prefix caching.
         This function takes a list of blocks that will have their block hash
@@ -101,6 +102,8 @@ def cache_full_blocks(
                 be cached after this function.
             block_size: Number of tokens in each block.
             hash_fn: The hash function to use for block hashes.
+            kv_cache_group_id: The id of the kv cache group. -1 means no kv
+                cache group.
         """
         if num_cached_blocks == num_full_blocks:
             return
@@ -143,7 +146,8 @@ def cache_full_blocks(
                 # we reach to this branch only when the block is completed with
                 # generated tokens, we only need to consider the last mm input.
                 extra_keys, _ = generate_block_hash_extra_keys(
-                    request, start_token_idx, end_token_idx, -1)
+                    request, start_token_idx, end_token_idx, -1,
+                    kv_cache_group_id)

                 # Compute the hash of the current block.
                 block_hash = hash_block_tokens(hash_fn, prev_block_hash_value,
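The new `kv_cache_group_id` is folded into the extra keys that feed the block hash. A toy illustration of the apparent effect (this is not vLLM's `hash_block_tokens`, and the motivation is inferred from the diff rather than stated in it): identical token blocks belonging to different KV cache groups get distinct prefix-cache entries, so a block cached for one group is not reused by another.

```python
import hashlib
from typing import Optional


def toy_block_hash(prev_hash: Optional[str], token_ids: tuple[int, ...],
                   kv_cache_group_id: int = -1) -> str:
    """Toy hash that folds the KV cache group id into the key, mirroring how
    the group id is appended to the extra keys above. -1 means "no group"."""
    key = (prev_hash, token_ids, kv_cache_group_id)
    return hashlib.sha256(repr(key).encode()).hexdigest()


tokens = (17, 23, 42, 99)
h_group0 = toy_block_hash(None, tokens, kv_cache_group_id=0)
h_group1 = toy_block_hash(None, tokens, kv_cache_group_id=1)
assert h_group0 != h_group1  # same tokens, different groups, distinct cache entries
```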