6 changes: 5 additions & 1 deletion vllm/v1/attention/backends/flash_attn.py
@@ -62,7 +62,11 @@ def get_supported_head_sizes(cls) -> list[int]:

@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
return [MultipleOf(16)]
# NOTE(tdoublep): while in principle, FA supports
# MultipleOf(16), these are the block sizes that do not
# suffer from the NaN propagation problem described here:
# https://github.com/Dao-AILab/flash-attention/issues/1974
return [16, 32, 64]

@classmethod
def validate_head_size(cls, head_size: int) -> None:
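As a rough illustration of how the fixed block-size list above can be consumed, here is a minimal sketch of matching a KV-cache block size against a backend's supported kernel block sizes. The `MultipleOf` stand-in and the `pick_kernel_block_size` helper are assumptions for illustration only, not vLLM's actual resolution logic.

```python
# Sketch only: matching a cache block size against the kernel block sizes a
# backend reports. `MultipleOf` is a minimal stand-in for vLLM's class.
from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    base: int


def pick_kernel_block_size(
    cache_block_size: int, supported: list[int | MultipleOf]
) -> int:
    """Return the largest supported kernel block size that evenly divides
    the KV-cache block size (assumed policy, for illustration)."""
    candidates = []
    for entry in supported:
        if isinstance(entry, MultipleOf):
            # Any multiple of entry.base that divides the cache block works;
            # the cache block size itself is the largest such choice.
            if cache_block_size % entry.base == 0:
                candidates.append(cache_block_size)
        elif cache_block_size % entry == 0:
            candidates.append(entry)
    if not candidates:
        raise ValueError(f"no supported kernel block size for {cache_block_size}")
    return max(candidates)


# With the FA list above, a 256-token cache block maps to a 64-token kernel block.
assert pick_kernel_block_size(256, [16, 32, 64]) == 64
```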
8 changes: 7 additions & 1 deletion vllm/v1/kv_cache_interface.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
from dataclasses import dataclass, fields
from dataclasses import dataclass, fields, replace
from math import prod

import torch
@@ -44,6 +44,12 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
"""
raise NotImplementedError

def copy_with_new_block_size(self, block_size: int) -> Self:
"""
Create a new KVCacheSpec from self, replacing only the block size.
"""
return replace(self, block_size=block_size)

@classmethod
def merge(cls, specs: list[Self]) -> Self:
"""
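The new `copy_with_new_block_size` helper relies on `dataclasses.replace`, which shallow-copies a dataclass while swapping selected fields. A minimal standalone sketch of that pattern, using a hypothetical `ToySpec` rather than the real `KVCacheSpec`:

```python
# Minimal illustration of the dataclasses.replace pattern used above;
# ToySpec is hypothetical, not vLLM's real KVCacheSpec.
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToySpec:
    block_size: int
    num_kv_heads: int
    head_size: int


spec = ToySpec(block_size=256, num_kv_heads=8, head_size=128)
kernel_spec = replace(spec, block_size=64)  # shallow copy, one field swapped

assert kernel_spec.block_size == 64
assert kernel_spec.num_kv_heads == spec.num_kv_heads  # everything else unchanged
```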
32 changes: 26 additions & 6 deletions vllm/v1/worker/gpu_model_runner.py
@@ -3982,16 +3982,11 @@ def create_attn_groups(
) -> list[AttentionGroup]:
attn_groups: list[AttentionGroup] = []
for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
attn_group = AttentionGroup.create_with_metadata_builders(
attn_group = AttentionGroup(
attn_backend,
layer_names,
kv_cache_spec,
self.vllm_config,
self.device,
kv_cache_group_id,
num_metadata_builders=1
if not self.parallel_config.enable_dbo
else 2,
)

attn_groups.append(attn_group)
@@ -4010,7 +4005,28 @@ def create_attn_groups(
for i, attn_backend_map in enumerate(attention_backend_maps):
self.attn_groups.append(create_attn_groups(attn_backend_map, i))

def initialize_metadata_builders(
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
) -> None:
"""
Create the metadata builders for all KV cache groups and attn groups.
"""
for kv_cache_group_id in range(len(kv_cache_config.kv_cache_groups)):
if kv_cache_group_id == len(kernel_block_sizes):
# There may be a last group for layers without kv cache.
continue
for attn_group in self.attn_groups[kv_cache_group_id]:
attn_group.create_metadata_builders(
self.vllm_config,
self.device,
kernel_block_sizes[kv_cache_group_id],
num_metadata_builders=1
if not self.parallel_config.enable_dbo
else 2,
)
# Calculate reorder batch threshold (if needed)
# Note (tdoublep): do this *after* constructing builders,
# because some of them change the threshold at init time.
self.calculate_reorder_batch_threshold()

def _check_and_update_cudagraph_mode(
@@ -4576,6 +4592,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
# kernel_block_size 64 and split the 256-token-block to 4 blocks with 64
# tokens each.
kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)

# create metadata builders
self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes)

# Reinitialization needs to happen after initialize_attn_backend
self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes)
kv_caches = self.initialize_kv_cache_tensors(
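The comment above describes splitting a 256-token cache block into four 64-token kernel blocks. A small sketch of that index arithmetic, assuming a hypothetical `split_block_ids` helper rather than vLLM's actual block-table conversion:

```python
# Sketch of the block-size split: a logical cache block addressed with a
# smaller kernel block size expands into several consecutive kernel blocks.
def split_block_ids(
    cache_block_ids: list[int], cache_block_size: int, kernel_block_size: int
) -> list[int]:
    assert cache_block_size % kernel_block_size == 0
    factor = cache_block_size // kernel_block_size
    # Cache block b covers kernel blocks [b * factor, ..., b * factor + factor - 1].
    return [b * factor + i for b in cache_block_ids for i in range(factor)]


# A sequence stored in 256-token cache blocks [3, 7] maps to kernel blocks
# [12..15, 28..31] when the kernel works on 64-token blocks.
assert split_block_ids([3, 7], 256, 64) == [12, 13, 14, 15, 28, 29, 30, 31]
```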
42 changes: 23 additions & 19 deletions vllm/v1/worker/utils.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import torch
@@ -134,31 +134,35 @@ def reset_cache(self) -> None:
@dataclass
class AttentionGroup:
backend: type[AttentionBackend]
# When ubatching is enabled we will have a metadata builder for each ubatch
# so that if they use internal persistant buffers for cudagraphs, and they
# won't have to worry about conflicting with the other ubatches.
metadata_builders: list[AttentionMetadataBuilder]
layer_names: list[str]
kv_cache_spec: KVCacheSpec
kv_cache_group_id: int
# When ubatching is enabled we will have a metadata builder for each ubatch,
# so that if they use internal persistent buffers for cudagraphs, they
# won't have to worry about conflicting with the other ubatches.
metadata_builders: list[AttentionMetadataBuilder] = field(
default_factory=lambda: []
)

@staticmethod
def create_with_metadata_builders(
backend: type[AttentionBackend],
layer_names: list[str],
kv_cache_spec: KVCacheSpec,
vllm_config: VllmConfig,
device: torch.device,
kv_cache_group_id: int,
def create_metadata_builders(
self,
vllm_config,
device,
kernel_block_size: int,
num_metadata_builders: int = 1,
) -> "AttentionGroup":
metadata_builders = [
backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device)
):
kv_cache_spec_kernel = self.kv_cache_spec.copy_with_new_block_size(
kernel_block_size
)
self.metadata_builders = [
self.backend.get_builder_cls()(
kv_cache_spec_kernel,
self.layer_names,
vllm_config,
device,
)
for _ in range(num_metadata_builders)
]
return AttentionGroup(
backend, metadata_builders, layer_names, kv_cache_spec, kv_cache_group_id
)

def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
assert len(self.metadata_builders) > ubatch_id
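Putting the pieces together, the PR moves `AttentionGroup` to a two-phase setup: construct the group first, then attach metadata builders once the kernel block size is known, with one builder per ubatch when DBO is enabled. A self-contained sketch of that pattern with toy stand-ins for the backend and builder types (assumptions, not vLLM's real classes):

```python
# Toy sketch of the two-phase construction pattern introduced in this PR.
from dataclasses import dataclass, field


class ToyBuilder:
    def __init__(self, block_size: int) -> None:
        self.block_size = block_size


@dataclass
class ToyAttentionGroup:
    layer_names: list[str]
    block_size: int
    # Builders are attached later, once the kernel block size is known.
    metadata_builders: list[ToyBuilder] = field(default_factory=list)

    def create_metadata_builders(
        self, kernel_block_size: int, num_metadata_builders: int = 1
    ) -> None:
        self.metadata_builders = [
            ToyBuilder(kernel_block_size) for _ in range(num_metadata_builders)
        ]

    def get_metadata_builder(self, ubatch_id: int = 0) -> ToyBuilder:
        assert len(self.metadata_builders) > ubatch_id
        return self.metadata_builders[ubatch_id]


enable_dbo = True
group = ToyAttentionGroup(layer_names=["model.layers.0.attn"], block_size=256)
group.create_metadata_builders(
    kernel_block_size=64, num_metadata_builders=2 if enable_dbo else 1
)
assert group.get_metadata_builder(ubatch_id=1).block_size == 64
```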