Merged
103 changes: 63 additions & 40 deletions vllm/attention/ops/chunked_prefill_paged_decode.py
@@ -1,9 +1,10 @@
# SPDX-License-Identifier: Apache-2.0

# Authors:
# - Burkhard Ringlein
# - Jan van Lunteren
# - Thomas Parnell
# - Burkhard Ringlein <[email protected]>
# - Jan van Lunteren <[email protected]>
# - Chih-Chieh Yang <[email protected]>
# - Thomas Parnell <[email protected]>

import torch
import triton
@@ -31,6 +32,7 @@ def kernel_paged_attention_2d(
v_scale, # float32
num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int
num_queries_per_kv_padded: tl.constexpr, # int
block_table_stride: tl.constexpr, # int
query_stride_0: tl.constexpr, # int
query_stride_1: tl.constexpr, # int, should be equal to head_size
@@ -55,8 +57,7 @@ def kernel_paged_attention_2d(
query_start_len_ptr, # [num_seqs+1]
):
seq_idx = tl.program_id(0)
query_head_idx = tl.program_id(1)
kv_head_idx = query_head_idx // num_queries_per_kv
kv_head_idx = tl.program_id(1)

if filter_by_query_len:
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
@@ -69,31 +70,40 @@ def kernel_paged_attention_2d(
else:
cur_batch_in_all_start_index = seq_idx

query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(
0, num_queries_per_kv_padded)

query_offset = (cur_batch_in_all_start_index * query_stride_0 +
query_head_idx * query_stride_1)
query_head_idx[:, None] * query_stride_1)

head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv
head_mask = head_mask & (query_head_idx < num_query_heads)

dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
0).to(tl.int1)

# Q : (HEAD_SIZE,)
# Q : (num_queries_per_kv, HEAD_SIZE,)
Q = tl.load(
query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED),
mask=dim_mask,
query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
mask=dim_mask[None, :] & head_mask[:, None],
other=0.0,
)

block_table_offset = seq_idx * block_table_stride

M = tl.full([1], float("-inf"), dtype=tl.float32)
L = tl.full([1], 1.0, dtype=tl.float32)
acc = tl.zeros([HEAD_SIZE_PADDED], dtype=tl.float32)
M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
dtype=tl.float32)

# sequence len for this particular sequence
seq_len = tl.load(seq_lens_ptr + seq_idx)

# alibi slope for this head
if USE_ALIBI_SLOPES:
alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx)
alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx,
mask=head_mask,
other=0.0)

num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)

@@ -107,8 +117,8 @@ def kernel_paged_attention_2d(

v_offset = (physical_block_idx * stride_v_cache_0 +
kv_head_idx * stride_v_cache_1 +
offs_d[:, None] * stride_v_cache_2 +
offs_n[None, :] * stride_v_cache_3)
offs_d[None, :] * stride_v_cache_2 +
offs_n[:, None] * stride_v_cache_3)

k_offset = (physical_block_idx * stride_k_cache_0 +
kv_head_idx * stride_k_cache_1 +
@@ -126,61 +136,69 @@ def kernel_paged_attention_2d(
else:
K = K_load

# V : (HEAD_SIZE, BLOCK_SIZE)
# V : (BLOCK_SIZE, HEAD_SIZE)
V_load = tl.load(value_cache_ptr + v_offset,
mask=dim_mask[:, None],
mask=dim_mask[None, :],
other=0.0)

if V_load.dtype.is_fp8():
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
else:
V = V_load

tmp = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
mask_new = tmp < boundary
# S : (BLOCK_SIZE,)
S = tl.where(mask_new, 0.0, float("-inf")).to(tl.float32)
S += scale * tl.sum(K * Q[:, None], axis=0)
seq_mask = seq_offset[None, :] < boundary

# S : (num_queries_per_kv, BLOCK_SIZE,)
S = tl.where(head_mask[:, None] & seq_mask, 0.0,
float("-inf")).to(tl.float32)
S += scale * tl.dot(Q, K)

context_len = seq_len - 1

if SLIDING_WINDOW > 0:
S = tl.where((seq_len - 1 - tmp) < SLIDING_WINDOW, S, -10000)
S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S,
-10000)

if USE_ALIBI_SLOPES:
S += alibi_slope * (tmp - seq_len + 1)
S += alibi_slope[:, None] * (seq_offset - context_len)

# compute running maximum
# m_j : (1,)
m_j = tl.maximum(M, tl.max(S, axis=0))
# m_j : (num_queries_per_kv,)
m_j = tl.maximum(M, tl.max(S, axis=1))

# P : (BLOCK_SIZE,)
P = tl.exp(S - m_j)
# P : (num_queries_per_kv, BLOCK_SIZE,)
P = tl.exp(S - m_j[:, None])

# l_j : (1,)
l_j = tl.sum(P, axis=0)
# l_j : (num_queries_per_kv,)
l_j = tl.sum(P, axis=1)

# alpha : (1, )
# alpha : (num_queries_per_kv, )
alpha = tl.exp(M - m_j)

# acc : (BLOCK_SIZE,)
acc = acc * alpha
# acc : (num_queries_per_kv, BLOCK_SIZE,)
acc = acc * alpha[:, None]

# update constants
L = L * alpha + l_j
M = m_j

# acc : (BLOCK_SIZE,)
acc += tl.sum(V * P[None, :], axis=1)
# acc : (num_queries_per_kv, BLOCK_SIZE,)
acc += tl.dot(P.to(V.dtype), V)

# epilogue
acc = acc / L
acc = acc / L[:, None]

output_offset = (cur_batch_in_all_start_index * output_stride_0 +
query_head_idx * output_stride_1)

tl.store(output_ptr + output_offset + tl.arange(0, HEAD_SIZE_PADDED),
acc,
mask=dim_mask)
tl.store(
output_ptr + output_offset[:, None] +
tl.arange(0, HEAD_SIZE_PADDED)[None, :],
acc,
mask=dim_mask[None, :] & head_mask[:, None],
)
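The core of the kernel change above: each program now covers one (sequence, KV head) pair and all num_queries_per_kv query heads that share that KV head, so the per-head tl.sum reductions turn into tl.dot matrix multiplies over a (num_queries_per_kv, HEAD_SIZE) query tile, with the online-softmax state (M, L, acc) gaining the same leading dimension. The snippet below is a minimal NumPy sketch of what one such program computes; the names and the zero-padded KV tiles are illustrative simplifications, not the kernel's actual variables, and the head-padding/masking details are omitted.

```python
import numpy as np

def grouped_paged_attention_reference(q_group, k_tiles, v_tiles, scale,
                                      seq_len, block_size):
    """Reference for one (sequence, KV head) program after this change.

    q_group: (num_queries_per_kv, head_size) query rows sharing this KV head.
    k_tiles: list of (head_size, block_size) key tiles from the paged cache.
    v_tiles: list of (block_size, head_size) value tiles from the paged cache.
    """
    num_q, head_size = q_group.shape
    M = np.full(num_q, -np.inf)            # running row maxima
    L = np.ones(num_q)                     # running softmax denominators
    acc = np.zeros((num_q, head_size))     # running weighted value sums

    for j, (K, V) in enumerate(zip(k_tiles, v_tiles)):
        S = scale * (q_group @ K)          # (num_q, block_size); tl.dot(Q, K) in the kernel
        cols = j * block_size + np.arange(block_size)
        S = np.where(cols[None, :] < seq_len, S, -np.inf)  # mask past seq_len

        m_j = np.maximum(M, S.max(axis=1))   # updated running maxima
        P = np.exp(S - m_j[:, None])         # unnormalized probabilities
        alpha = np.exp(M - m_j)              # rescales the previous accumulator
        acc = acc * alpha[:, None] + P @ V   # tl.dot(P, V) in the kernel
        L = L * alpha + P.sum(axis=1)
        M = m_j

    return acc / L[:, None]                  # (num_q, head_size)

# Quick check against ordinary softmax attention for one KV head.
rng = np.random.default_rng(0)
num_q, head_size, block_size, seq_len = 4, 64, 16, 40
num_blocks = -(-seq_len // block_size)       # cdiv, as in the kernel
q = rng.standard_normal((num_q, head_size))
k = rng.standard_normal((head_size, seq_len))
v = rng.standard_normal((seq_len, head_size))
k_cache = np.zeros((head_size, num_blocks * block_size))
k_cache[:, :seq_len] = k
v_cache = np.zeros((num_blocks * block_size, head_size))
v_cache[:seq_len] = v
k_tiles = [k_cache[:, i * block_size:(i + 1) * block_size] for i in range(num_blocks)]
v_tiles = [v_cache[i * block_size:(i + 1) * block_size] for i in range(num_blocks)]

out = grouped_paged_attention_reference(q, k_tiles, v_tiles, 1.0, seq_len, block_size)
scores = q @ k
weights = np.exp(scores - scores.max(axis=1, keepdims=True))
expected = (weights / weights.sum(axis=1, keepdims=True)) @ v
assert np.allclose(out, expected)
```

Expressing the score and output updates as small matrix multiplies is what makes tl.dot usable even in the decode case, where each query head contributes only a single row per KV head.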


def chunked_prefill_paged_decode(
@@ -234,6 +252,7 @@ def chunked_prefill_paged_decode(
block_size = value_cache.shape[3]
Contributor:
Just to understand it right: should there be a return after the call to context_attention_fwd? Otherwise, for max_query_len > 1, you are calling two kernels that might compute the same thing.

Member Author:
Right. We added this option to context_attention_fwd on main which, if enabled, skips the sequences in the batch with query_length=1. We then launch another kernel concurrently to handle the ones that were skipped.
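In other words, the two launches partition the batch rather than duplicating work: context_attention_fwd handles the query_length > 1 sequences with the skip option enabled, and kernel_paged_attention_2d, launched with filter_by_query_len=True, handles the query_length == 1 sequences. A toy sketch of that partition follows; the stub names and the skip flag's exact name are placeholders, not the vLLM API.

```python
# Toy stand-ins for the two launches described above. The function and flag
# names are placeholders for illustration only, not the actual vLLM API.
def prefill_kernel_stub(query_lens, skip_decode=True):
    # context_attention_fwd: with the skip option it ignores decode sequences.
    return [i for i, qlen in enumerate(query_lens)
            if not (skip_decode and qlen == 1)]

def paged_decode_kernel_stub(query_lens):
    # kernel_paged_attention_2d with filter_by_query_len=True: only
    # query_length == 1 sequences do real work.
    return [i for i, qlen in enumerate(query_lens) if qlen == 1]

query_lens = [7, 1, 3, 1]                          # mixed chunked-prefill/decode batch
prefill_ids = prefill_kernel_stub(query_lens)      # -> [0, 2]
decode_ids = paged_decode_kernel_stub(query_lens)  # -> [1, 3]
assert sorted(prefill_ids + decode_ids) == list(range(len(query_lens)))
```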

num_seqs = len(seq_lens)
num_query_heads = query.shape[1]
num_kv_heads = key.shape[1]
num_queries_per_kv = query.shape[1] // key.shape[1]
head_size = query.shape[2]

@@ -253,9 +272,12 @@ def chunked_prefill_paged_decode(
key_cache = key_cache.view(target_dtype)
value_cache = value_cache.view(target_dtype)

num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
16)

kernel_paged_attention_2d[(
num_seqs,
num_query_heads,
num_kv_heads,
)](
output_ptr=output,
query_ptr=query,
Expand All @@ -269,6 +291,7 @@ def chunked_prefill_paged_decode(
v_scale=v_scale,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
num_queries_per_kv_padded=num_queries_per_kv_padded,
block_table_stride=block_table.stride(0),
query_stride_0=query.stride(0),
query_stride_1=query.stride(1),
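One launch-side detail worth noting: num_queries_per_kv_padded rounds the query-head group size up to a power of two, with a floor of 16. The power-of-two rounding matches tl.arange's requirement that range lengths be powers of two, and the floor of 16 lines up with tl.dot's minimum block-dimension requirement; the surplus rows are switched off inside the kernel via head_mask. A plain-Python illustration of the padding rule:

```python
# Plain-Python mirror of the launch-side padding:
#   num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), 16)
# Rows beyond the real num_queries_per_kv are masked out by head_mask in the kernel.
def next_power_of_2(n: int) -> int:
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

for num_queries_per_kv in (1, 2, 4, 5, 8, 16, 32):
    padded = max(next_power_of_2(num_queries_per_kv), 16)
    print(f"{num_queries_per_kv:2d} query heads per KV head -> padded group of {padded}")
```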