174 changes: 174 additions & 0 deletions benchmarks/kernels/benchmark_reshape_and_cache.py
@@ -0,0 +1,174 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

import random
import time

import torch
from tabulate import tabulate

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (
STR_DTYPE_TO_TORCH_DTYPE,
FlexibleArgumentParser,
create_kv_caches_with_random,
)

logger = init_logger(__name__)


@torch.inference_mode()
def run_benchmark(
num_tokens: int,
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
kv_cache_dtype: str,
num_iters: int,
benchmark_mode: str,
device: str = "cuda",
) -> float:
"""Return latency (seconds) for given num_tokens."""

if kv_cache_dtype == "fp8" and head_size % 16:
raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")

current_platform.seed_everything(42)
torch.set_default_device(device)

# create random key / value tensors [T, H, D].
key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
value = torch.randn_like(key)

# prepare the slot mapping.
# each token is assigned a unique slot in the KV-cache.
num_slots = block_size * num_blocks
if num_tokens > num_slots:
raise ValueError("num_tokens cannot exceed the total number of cache slots")
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

key_caches, value_caches = create_kv_caches_with_random(
num_blocks,
block_size,
1, # num_layers
num_heads,
head_size,
kv_cache_dtype,
dtype,
device=device,
)
key_cache, value_cache = key_caches[0], value_caches[0]
    # only layer 0's caches are used; drop the lists to free memory
del key_caches, value_caches

    # compute per-tensor scaling factors for fp8 conversion (if used).
k_scale = (key.amax() / 64.0).to(torch.float32)
v_scale = (value.amax() / 64.0).to(torch.float32)

    function_under_test = lambda: ops.reshape_and_cache(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )

if benchmark_mode == "cudagraph":
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
function_under_test()
torch.cuda.synchronize()
function_under_test = lambda: g.replay()

def run_cuda_benchmark(n_iters: int) -> float:
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(n_iters):
function_under_test()
torch.cuda.synchronize()
end = time.perf_counter()
return (end - start) / n_iters

# warm-up
run_cuda_benchmark(3)

lat = run_cuda_benchmark(num_iters)

# free tensors to mitigate OOM when sweeping
del key, value, key_cache, value_cache, slot_mapping
torch.cuda.empty_cache()

return lat


def main(args):
rows = []
for exp in range(1, 17):
n_tok = 2**exp
lat = run_benchmark(
num_tokens=n_tok,
num_heads=args.num_heads,
head_size=args.head_size,
block_size=args.block_size,
num_blocks=args.num_blocks,
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
kv_cache_dtype=args.kv_cache_dtype,
num_iters=args.iters,
benchmark_mode=args.mode,
device="cuda",
)
rows.append([n_tok, "HDN", f"{lat * 1e6:.3f}"])

print(f"Benchmark results for implementation cuda (measuring with {args.mode}):")
print(tabulate(rows, headers=["num_tokens", "layout", "latency (µs)"]))


if __name__ == "__main__":
parser = FlexibleArgumentParser()

parser.add_argument("--num-heads", type=int, default=128)
parser.add_argument(
"--head-size",
type=int,
choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128,
)
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--num-blocks", type=int, default=128 * 512)

parser.add_argument(
"--dtype",
type=str,
choices=["half", "bfloat16", "float"],
default="bfloat16",
)

parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8"],
default="auto",
)

parser.add_argument("--iters", type=int, default=100)

parser.add_argument(
"--mode",
type=str,
choices=["cudagraph", "no_graph"],
default="cudagraph",
)

args = parser.parse_args()

main(args)
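For a single data point without the full power-of-two sweep, `run_benchmark` can also be driven directly. A minimal sketch, assuming a CUDA device and that the script's directory (`benchmarks/kernels/`) is importable; the argument values simply mirror the parser defaults above:

```python
# Sketch: time one (num_tokens, num_heads, head_size) point directly,
# bypassing the argparse sweep in main(). Assumes benchmark_reshape_and_cache.py
# is importable, e.g. the current working directory is benchmarks/kernels/.
import torch

from benchmark_reshape_and_cache import run_benchmark

lat = run_benchmark(
    num_tokens=1024,
    num_heads=128,
    head_size=128,
    block_size=16,
    num_blocks=128 * 512,
    dtype=torch.bfloat16,
    kv_cache_dtype="auto",
    num_iters=100,
    benchmark_mode="cudagraph",
    device="cuda",
)
print(f"latency: {lat * 1e6:.3f} µs")
```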
102 changes: 56 additions & 46 deletions csrc/cache_kernels.cu
@@ -209,73 +209,82 @@ void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
 
 namespace vllm {
 
+// Used by vectorization_utils to copy/convert one element
+template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
+struct CopyWithScaleOp {
+  float scale;
+
+  __device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
+    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+      dst = static_cast<OutT>(src);
+    } else {
+      dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
+    }
+  }
+};
+
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_kernel(
-    const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
-    const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
-    cache_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size/x,
-                                         // block_size, x]
-    cache_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size,
-                                         // block_size]
+    const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size//x,
+                                         // x]
+    const scalar_t* __restrict__ value,  // [num_tokens, num_heads,
+                                         // head_size//x, x]
+    cache_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size//x,
+                                         // block_size, x]
+    cache_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size//x,
+                                         // x, block_size]
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int key_stride, const int value_stride, const int num_heads,
     const int head_size, const int block_size, const int x,
     const float* k_scale, const float* v_scale) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   if (slot_idx < 0) {
     // Padding token that should be ignored.
     return;
   }
 
   const int64_t block_idx = slot_idx / block_size;
   const int64_t block_offset = slot_idx % block_size;
+  const int h_block_count = head_size / x;  // head_size//x
 
-  const int n = num_heads * head_size;
-  for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int64_t src_key_idx = token_idx * key_stride + i;
-    const int64_t src_value_idx = token_idx * value_stride + i;
-
-    const int head_idx = i / head_size;
-    const int head_offset = i % head_size;
-    const int x_idx = head_offset / x;
-    const int x_offset = head_offset % x;
-
-    const int64_t tgt_key_idx =
-        block_idx * num_heads * (head_size / x) * block_size * x +
-        head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
-        block_offset * x + x_offset;
-    const int64_t tgt_value_idx =
-        block_idx * num_heads * head_size * block_size +
-        head_idx * head_size * block_size + head_offset * block_size +
-        block_offset;
-    scalar_t tgt_key = key[src_key_idx];
-    scalar_t tgt_value = value[src_value_idx];
-    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-      key_cache[tgt_key_idx] = tgt_key;
-      value_cache[tgt_value_idx] = tgt_value;
-    } else {
-      key_cache[tgt_key_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
-      value_cache[tgt_value_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
-    }
-  }
-}
-
-// Used by vectorization_utils to copy/convert one element
-template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
-struct CopyWithScaleOp {
-  float scale;
-
-  __device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
-    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-      dst = static_cast<OutT>(src);
-    } else {
-      dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
-    }
-  }
-};
+  const int h_block_idx = threadIdx.x;
+  if (h_block_idx >= num_heads * h_block_count) {
+    return;
+  }
+
+  const int head_idx = h_block_idx / h_block_count;
+  const int h_block = h_block_idx % h_block_count;
+
+  const scalar_t* __restrict__ key_src =
+      key + token_idx * key_stride + head_idx * head_size + h_block * x;
+  const int64_t src_value_start =
+      token_idx * value_stride + head_idx * head_size + h_block * x;
+
+  cache_t* __restrict__ key_dst =
+      key_cache + block_idx * num_heads * h_block_count * block_size * x +
+      head_idx * h_block_count * block_size * x + h_block * block_size * x +
+      block_offset * x;
+  const int64_t tgt_value_start =
+      block_idx * num_heads * h_block_count * x * block_size +
+      head_idx * h_block_count * x * block_size + h_block * x * block_size +
+      block_offset;
+
+  float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
+  constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
+  CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
+  vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, x, 0, 1, k_op);
+
+#pragma unroll
+  for (int i = 0; i < x; i++) {
+    scalar_t value_val = value[src_value_start + i];
+    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+      value_cache[tgt_value_start + i * block_size] = value_val;
+    } else {
+      value_cache[tgt_value_start + i * block_size] =
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(value_val, *v_scale);
+    }
+  }
+}
 
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_flash_kernel(
@@ -431,9 +440,10 @@ void reshape_and_cache(
 
   int key_stride = key.stride(0);
   int value_stride = value.stride(0);
+  int head_div_x = head_size / x;
 
   dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * head_size, 512));
+  dim3 block(std::min(num_heads * head_div_x, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
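The rewrite replaces the old per-element loop (one thread per scalar) with one thread per (head, x-chunk): each thread copies x contiguous key elements through vectorize_with_alignment and writes x strided value elements. A minimal plain-Python sanity check (not part of the PR; the shape constants are illustrative) that the new chunked addressing reproduces the old per-element target index for the key cache layout [num_blocks, num_heads, head_size//x, block_size, x]:

```python
# Sketch: verify new per-(head, x-chunk) addressing == old per-element indexing.
num_heads, head_size, block_size, x = 4, 128, 16, 8
h_block_count = head_size // x  # number of x-element chunks per head

def old_tgt_key_idx(block_idx, block_offset, i):
    # Old kernel: one thread per element i in [0, num_heads * head_size).
    head_idx, head_offset = divmod(i, head_size)
    x_idx, x_offset = divmod(head_offset, x)
    return (block_idx * num_heads * h_block_count * block_size * x
            + head_idx * h_block_count * block_size * x
            + x_idx * block_size * x
            + block_offset * x + x_offset)

def new_key_dst_base(block_idx, block_offset, h_block_idx):
    # New kernel: one thread per (head, x-chunk); it copies x contiguous
    # key elements starting at this base offset.
    head_idx, h_block = divmod(h_block_idx, h_block_count)
    return (block_idx * num_heads * h_block_count * block_size * x
            + head_idx * h_block_count * block_size * x
            + h_block * block_size * x
            + block_offset * x)

for i in range(num_heads * head_size):
    # Element i belongs to chunk i // x and sits at offset i % x within it.
    base = new_key_dst_base(2, 5, i // x)
    assert old_tgt_key_idx(2, 5, i) == base + i % x
print("chunked addressing matches the old per-element indexing")
```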