27 changes: 21 additions & 6 deletions tests/kernels/attention/test_cache.py
@@ -39,6 +39,8 @@
# We assume fp8 is always enabled for testing.
KV_CACHE_DTYPE = ["auto", "fp8"]

RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"]


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@@ -223,6 +225,7 @@ def test_reshape_and_cache(
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS)
@torch.inference_mode()
def test_reshape_and_cache_flash(
kv_cache_factory_flashinfer,
@@ -236,9 +239,13 @@ def test_reshape_and_cache_flash(
device: str,
kv_cache_dtype: str,
kv_cache_layout: str,
implementation: str,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
assert implementation in ["cuda", "triton"]
if implementation == "triton" and kv_cache_layout == "HND":
pytest.skip("Triton implementation only supports NHD layout.")

    # fp8 conversion requires a contiguous memory buffer. Reduce the number of
# blocks and tokens to consume less memory.
@@ -298,12 +305,20 @@ def permute_and_compact(x):
cloned_key_cache = key_cache_compact.clone()
cloned_value_cache = value_cache_compact.clone()
# Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]))
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale, v_scale)
if implementation == "cuda":
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
(key, value, key_cache, value_cache, slot_mapping,
kv_cache_dtype, k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]))
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale,
v_scale)
elif implementation == "triton":
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash)
triton_reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale,
v_scale)
key_cache_compact = permute_and_compact(key_cache)
value_cache_compact = permute_and_compact(value_cache)

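As a quick reviewer aid, below is a minimal standalone sketch of what the new `implementation` parameter exercises: the Triton path checked against the existing CUDA op on an `"auto"` (non-fp8), NHD-layout cache. Shapes, dtypes, and the `ops` alias for `vllm._custom_ops` are illustrative assumptions, not taken from the test.

```python
# Minimal sketch (assumed shapes/dtypes): compare the Triton path against the
# existing CUDA op on an "auto" (non-fp8) NHD cache.
import torch
from vllm import _custom_ops as ops
from vllm.attention.ops.triton_reshape_and_cache_flash import (
    triton_reshape_and_cache_flash)

num_tokens, num_heads, head_size = 8, 4, 128
num_blocks, block_size = 16, 16
key = torch.randn(num_tokens, num_heads, head_size,
                  dtype=torch.float16, device="cuda")
value = torch.randn_like(key)
# NHD layout: [num_blocks, block_size, num_heads, head_size]
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size,
                        dtype=torch.float16, device="cuda")
value_cache = torch.zeros_like(key_cache)
key_cache_ref, value_cache_ref = key_cache.clone(), value_cache.clone()
slot_mapping = torch.randperm(
    num_blocks * block_size, device="cuda")[:num_tokens].to(torch.int64)
k_scale = torch.ones(1, dtype=torch.float32, device="cuda")
v_scale = torch.ones(1, dtype=torch.float32, device="cuda")

ops.reshape_and_cache_flash(key, value, key_cache_ref, value_cache_ref,
                            slot_mapping, "auto", k_scale, v_scale)
triton_reshape_and_cache_flash(key, value, key_cache, value_cache,
                               slot_mapping, "auto", k_scale, v_scale)
torch.testing.assert_close(key_cache, key_cache_ref)
torch.testing.assert_close(value_cache, value_cache_ref)
```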
176 changes: 176 additions & 0 deletions vllm/attention/ops/triton_reshape_and_cache_flash.py
@@ -0,0 +1,176 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import triton
import triton.language as tl

from vllm.platforms import current_platform


@triton.jit
def reshape_and_cache_kernel_flash(
key_ptr, # [num_tokens, num_heads, head_size]
value_ptr, # [num_tokens, num_heads, head_size]
key_cache_ptr, # [num_blocks, block_size, num_heads, head_size]
value_cache_ptr, # [num_blocks, block_size, num_heads, head_size]
slot_mapping_ptr, # [num_tokens]
k_scale, # float32
v_scale, # float32
# strides
key_stride: tl.int64,
value_stride: tl.int64,
block_stride: tl.int64,
page_stride: tl.int64,
num_heads: tl.constexpr,
head_size: tl.constexpr,
block_size: tl.constexpr,
# FP8 flags
FP8_KV_CACHE: tl.constexpr,
# tune parameters
TILE_SIZE: tl.constexpr,
):

token_idx = tl.program_id(axis=0)
slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
if slot_idx < 0:
# Padding token that should be ignored.
return

tile_i = tl.program_id(axis=1)
tile_offs = tl.arange(0, TILE_SIZE)
tile_pos = tile_i * TILE_SIZE + tile_offs

block_idx = slot_idx // block_size
block_offset = slot_idx % block_size

src_key_idx = token_idx * key_stride
src_value_idx = token_idx * value_stride

tgt_idx = block_idx * block_stride + block_offset * page_stride

# [TILE_SIZE]
key_load = tl.load(key_ptr + src_key_idx + tile_pos,
mask=tile_pos < (num_heads * head_size))
if FP8_KV_CACHE:
if key_load.dtype.is_fp8():
key_tile = key_load
else:
# tl.store will do the correct implicit cast to fp8,
# based on the key_cache_ptr.dtype.element_ty
key_tile = key_load / tl.load(k_scale)
else:
key_tile = key_load

# [TILE_SIZE]
value_load = tl.load(value_ptr + src_value_idx + tile_pos,
mask=tile_pos < (num_heads * head_size))
if FP8_KV_CACHE:
if value_load.dtype.is_fp8():
value_tile = value_load
else:
# tl.store will do the correct implicit cast to fp8,
# based on the value_cache_ptr.dtype.element_ty
value_tile = value_load / tl.load(v_scale)
else:
value_tile = value_load

tl.store(
key_cache_ptr + tgt_idx + tile_pos,
key_tile,
mask=tile_pos < (num_heads * head_size),
)
tl.store(
value_cache_ptr + tgt_idx + tile_pos,
value_tile,
mask=tile_pos < (num_heads * head_size),
)
return


def triton_reshape_and_cache_flash(
key: torch.Tensor, # [num_tokens, num_heads, head_size]
value: torch.Tensor, # [num_tokens, num_heads, head_size]
# [num_blocks, block_size, num_heads, head_size]
key_cache: torch.Tensor,
# [num_blocks, block_size, num_heads, head_size]
value_cache: torch.Tensor,
slot_mapping: torch.Tensor, # [num_tokens]
kv_cache_dtype: str, # "auto", "fp8"
k_scale: torch.Tensor, # float32
v_scale: torch.Tensor, # float32
):
num_tokens = key.shape[0]
num_heads = key.shape[1]
head_size = key.shape[2]
block_size = key_cache.shape[1]
n = num_heads * head_size

key_stride = key.stride()[0]
value_stride = value.stride()[0]
block_stride = key_cache.stride()[0]
page_stride = key_cache.stride()[1]

head_stride = key_cache.stride()[2]
    assert head_stride == head_size, "only contiguous heads are supported"

assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), \
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
kv_cache_torch_dtype = current_platform.fp8_dtype() if \
kv_cache_dtype.startswith("fp8") else key_cache.dtype

if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith(
"fp8"):
        # to avoid an erroneous implicit cast in the triton kernel (tl.store to uint8)
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
key_cache = key_cache.view(kv_cache_torch_dtype)
value_cache = value_cache.view(kv_cache_torch_dtype)
        assert kv_cache_torch_dtype != torch.uint8, \
            "explicit fp8 cast and store to uint8 is not supported by " \
            "triton reshape_and_cache_flash"

FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
torch.float8_e4m3fn, torch.float8_e5m2, torch.uint8,
torch.float8_e4m3fnuz], \
"unsupported dtype of KV cache tensor, got "\
"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, " \
"fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."

# heuristics instead of autotuning
TILE_SIZE = min(2048, triton.next_power_of_2(n))
if torch.version.hip:
num_stages = 4
num_warps = 8
else: # cuda
num_stages = 10
num_warps = 16
if torch.cuda.get_device_capability(key.device)[0] < 9:
TILE_SIZE = min(512, TILE_SIZE)

# TODO(ngl): maybe replace with static launch grid to avoid overhead if
# using cudagraphs
grid = lambda meta: (int(num_tokens), triton.cdiv(n, meta["TILE_SIZE"]))

reshape_and_cache_kernel_flash[grid](
key_ptr=key,
value_ptr=value,
key_cache_ptr=key_cache,
value_cache_ptr=value_cache,
slot_mapping_ptr=slot_mapping,
k_scale=k_scale,
v_scale=v_scale,
# strides
key_stride=key_stride,
value_stride=value_stride,
block_stride=block_stride,
page_stride=page_stride,
num_heads=num_heads,
head_size=head_size,
block_size=block_size,
# FP8 flags
FP8_KV_CACHE=FP8_KV_CACHE,
# autotune parameters
TILE_SIZE=TILE_SIZE,
num_warps=num_warps,
num_stages=num_stages,
)
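A short worked example of the `TILE_SIZE` heuristic and launch grid defined above; the head configuration below is an assumption for illustration only.

```python
# Worked example of the launch heuristics above (assumed head configuration).
import triton

num_tokens, num_heads, head_size = 256, 8, 128
n = num_heads * head_size                         # 1024 contiguous elements per token
TILE_SIZE = min(2048, triton.next_power_of_2(n))  # -> 1024: one tile covers a token
grid = (num_tokens, triton.cdiv(n, TILE_SIZE))    # -> (256, 1) kernel programs
print(TILE_SIZE, grid)                            # 1024 (256, 1)
```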
15 changes: 12 additions & 3 deletions vllm/v1/attention/backends/triton_attn.py
@@ -13,6 +13,8 @@
from vllm.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -346,7 +348,13 @@ def forward(
layer._v_scale,
)
else:
ops.reshape_and_cache_flash(
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
# triton kernel does not support uint8 kv_cache
# (because some explicit casts (e.g. float8_e4m3fnuz)
# are not supported)
triton_reshape_and_cache_flash(
key,
value,
key_cache,
@@ -358,8 +366,9 @@
)

if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
if key_cache.dtype != self.fp8_dtype:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
num_tokens, num_heads, head_size = query.shape
assert layer._q_scale == 1.0, \
"A non 1.0 q_scale is not currently supported."
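For context, the `.view(self.fp8_dtype)` calls above reinterpret the cache's element type in place rather than copying. A small sketch of that pattern; the shapes and the fp8 dtype are assumptions for illustration.

```python
# Reinterpret a uint8-allocated cache as fp8 without copying (assumed dtype/shape).
import torch

raw_cache = torch.zeros(16, 16, 8, 128, dtype=torch.uint8, device="cuda")
fp8_cache = raw_cache.view(torch.float8_e4m3fn)   # same storage, fp8 element type
assert fp8_cache.data_ptr() == raw_cache.data_ptr()
assert fp8_cache.shape == raw_cache.shape         # 1-byte to 1-byte: shape unchanged
```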