diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index f7a4114a0a70..30e5cafe0c84 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -13,221 +13,202 @@ AttentionType, MultipleOf, ) +from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + split_decodes_prefills_and_extends, ) from vllm.v1.kv_cache_interface import AttentionSpec _PARTITION_SIZE_ROCM = 256 +_CP_TOKENS_PER_ITER_ROCM = 32 * 1024 if current_platform.is_rocm(): import aiter + from aiter.ops.triton.utils.device_info import get_num_sms from vllm.triton_utils import tl, triton - from vllm.utils.torch_utils import direct_register_custom_op + + def block_size(x, head_dim): + return min(65536 // x.element_size(), triton.next_power_of_2(head_dim)) + + def num_programs(head_dim): + return min(head_dim, get_num_sms()) @triton.jit - def _vllm_layout_trans_kernel( - k_buffer_ptr, - v_buffer_ptr, - k_values_ptr, - v_values_ptr, - b_query_lens_loc, - b_seq_lens_loc, - block_table, - block_table_stride_0, - k_scale, - v_scale, - output_dtype: tl.constexpr, - E_DIM: tl.constexpr, + def cp_mha_gather_cache_kernel( + key_cache_ptr, # [num_blocks, page_size, num_head, head_size] + value_cache_ptr, # [num_blocks, page_size, num_head, head_size] + key_ptr, # [num_tokens, num_heads, head_size] + value_ptr, # [num_tokens, num_heads, head_size] + block_table_ptr, # [num_batches, max_block_num] + cu_seqlens_kv_ptr, # [num_batches + 1] + token_to_batch_ptr, # [max_cum_tokens] + seq_start_ptr, # [num_batches] + k_scale_ptr, + v_scale_ptr, + num_heads, + head_size, + x, + max_block_num, + num_tokens, + DEQUANT: tl.constexpr, + PAGE_SIZE: tl.constexpr, + CACHE_FORMAT: tl.constexpr, BLOCK_SIZE: tl.constexpr, + NUM_PRGMS: tl.constexpr, ): - batch_idx = tl.program_id(0) - block_idx = tl.program_id(1) - - batch_query_indexes = tl.load(b_query_lens_loc + batch_idx + tl.arange(0, 2)) - batch_query_start, batch_query_end = tl.split(batch_query_indexes) - query_len = batch_query_end - batch_query_start - - if query_len <= 1: - return - - batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx + tl.arange(0, 2)) - batch_token_start, batch_token_end = tl.split(batch_token_indexes) - seq_len = batch_token_end - batch_token_start - - if block_idx * BLOCK_SIZE < seq_len: - block_mask = ( - block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[:, None] - ) < seq_len - - kv_idx = tl.load( - block_table + batch_idx * block_table_stride_0 + block_idx - ).to(tl.int64) - - kv_buffer_off = ( - kv_idx * BLOCK_SIZE * E_DIM - + tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM - + tl.arange(0, E_DIM)[None, :] + bid = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_SIZE) + if DEQUANT: + k_scale = tl.load(k_scale_ptr) + v_scale = tl.load(v_scale_ptr) + + for token_id in tl.range(bid, num_tokens, NUM_PRGMS): + key_ptr_offset = key_ptr + token_id * head_size * num_heads + value_ptr_offset = value_ptr + token_id * head_size * num_heads + batch_idx = tl.load(token_to_batch_ptr + token_id) + batch_start = tl.load(seq_start_ptr + batch_idx) + token_start = tl.load(cu_seqlens_kv_ptr + batch_idx) + batch_offset = token_id - token_start + batch_start + block_offset = batch_offset // PAGE_SIZE + block_id = 
tl.load( + block_table_ptr + max_block_num * batch_idx + block_offset ) - k_vals = tl.load(k_buffer_ptr + kv_buffer_off, mask=block_mask, other=0.0) - if k_vals.dtype.is_fp8(): - k_vals = (k_vals.to(tl.float32) * tl.load(k_scale)).to(output_dtype) - else: - k_vals = k_vals.to(output_dtype) + slot_id = batch_offset % PAGE_SIZE + + if CACHE_FORMAT == "NHD": + # for kv cache layout as + # K: [num_blocks, page_size, num_head, head_dim] + # V: [num_blocks, page_size, num_head, head_dim] + key_cache_ptr_offset = ( + key_cache_ptr + + block_id * num_heads * head_size * PAGE_SIZE + + slot_id * num_heads * head_size + ) + value_cache_ptr_offset = ( + value_cache_ptr + + block_id * num_heads * head_size * PAGE_SIZE + + slot_id * num_heads * head_size + ) - v_vals = tl.load(v_buffer_ptr + kv_buffer_off, mask=block_mask, other=0.0) - if v_vals.dtype.is_fp8(): - v_vals = (v_vals.to(tl.float32) * tl.load(v_scale)).to(output_dtype) - else: - v_vals = v_vals.to(output_dtype) - kv_values_off = ( - batch_token_start * E_DIM - + block_idx * BLOCK_SIZE * E_DIM - + tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM - + tl.arange(0, E_DIM)[None, :] - ) - tl.store(k_values_ptr + kv_values_off, k_vals, mask=block_mask) - tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask) - - def vllm_layout_trans( - b_query_lens_loc, - b_seq_lens_loc, - block_table, - k_cache, - v_cache, - max_seq_len, - k_scale, - v_scale, - output_dtype, - total_tokens, + for i in tl.range(0, head_size * num_heads, BLOCK_SIZE): + mask = (col_offsets + i) < head_size * num_heads + k_reg = tl.load(key_cache_ptr_offset + col_offsets + i, mask=mask) + v_reg = tl.load(value_cache_ptr_offset + col_offsets + i, mask=mask) + if DEQUANT: + k_dtype = k_reg.dtype + v_dtype = v_reg.dtype + k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype) + v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype) + tl.store(key_ptr_offset + col_offsets + i, k_reg, mask=mask) + tl.store(value_ptr_offset + col_offsets + i, v_reg, mask=mask) + + def cp_mha_gather_cache( + key_cache: torch.Tensor, + value_cache: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + block_tables: torch.Tensor, + k_scales: torch.Tensor, + v_scales: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + token_to_batch: torch.Tensor, + seq_starts: torch.Tensor, + dequant: bool, + kv_cache_layout: str, + total_tokens: int, ): - H_KV = v_cache.shape[2] - D = v_cache.shape[3] - BLOCK_SIZE = v_cache.shape[1] - - k_values = torch.empty( - (total_tokens, H_KV, D), - dtype=output_dtype, - device=k_cache.device, + assert kv_cache_layout in ["v0", "NHD", "HND"], ( + "kv_cache_layout only support v0, NHD, HND" ) - v_values = torch.empty( - (total_tokens, H_KV, D), - dtype=output_dtype, - device=v_cache.device, + head_dim = key.shape[2] + x = 0 + # assert dequant is True, "Currently, we only support "\ + # "gather cache with dequant" + # For k cache layout: [num_blocks, num_heads, page_size, head_dim] + assert kv_cache_layout == "NHD", ( + "ROCM_AITER_FA_BACKEND Only support NHD kv cache layout for now" ) - - grid = (block_table.shape[0], (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE) - - if output_dtype == torch.float16: - output_dtype = tl.float16 - elif output_dtype == torch.bfloat16: - output_dtype = tl.bfloat16 - else: - raise ValueError(f"Unsupported output dtype: {output_dtype}") - - _vllm_layout_trans_kernel[grid]( - k_cache, - v_cache, - k_values, - v_values, - b_query_lens_loc, - b_seq_lens_loc, - block_table, - block_table.stride(0), - k_scale, - v_scale, - output_dtype=output_dtype, - E_DIM=H_KV 
* D, + assert head_dim == key_cache.shape[3], ( + "We assume your kv cache layout is [num_blocks, " + "page_size, num_heads, head_dim], but got otherwise" + ) + page_size = key_cache.shape[1] + num_heads = key_cache.shape[2] + + NUM_PRGMS = num_programs(total_tokens) + BLOCK_SIZE = block_size(key_cache, head_dim) + grid = lambda meta: (NUM_PRGMS,) + cp_mha_gather_cache_kernel[grid]( + key_cache, + value_cache, + key, + value, + block_tables, + cu_seqlens_kv, + token_to_batch, + seq_starts, + k_scales, + v_scales, + num_heads, + head_dim, + x, + block_tables.size(1), + total_tokens, + DEQUANT=dequant, + PAGE_SIZE=page_size, + CACHE_FORMAT=kv_cache_layout, BLOCK_SIZE=BLOCK_SIZE, + NUM_PRGMS=NUM_PRGMS, ) - return k_values, v_values - def flash_attn_varlen_func_impl( - q: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - out: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: float, - window_size: list[int] | None, # -1 means infinite context window - alibi_slopes: list[float] | None, - block_table: torch.Tensor, - k_scale: torch.Tensor, - v_scale: torch.Tensor, - total_tokens: int = 0, - ) -> torch.Tensor: - if total_tokens == 0: - total_tokens = int(cu_seqlens_k[-1].item()) - k, v = vllm_layout_trans( - cu_seqlens_q, - cu_seqlens_k, - block_table, - k_cache, - v_cache, - max_seqlen_k, - k_scale, - v_scale, - q.dtype, - total_tokens, - ) +logger = init_logger(__name__) - output = aiter.flash_attn_varlen_func( - q=q, - k=k, - v=v, - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - min_seqlen_q=1, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_k=max_seqlen_k, - softmax_scale=softmax_scale, - causal=True, - alibi_slopes=alibi_slopes, - window_size=window_size, - out=out, - ) - return output - def flash_attn_varlen_func_fake( - q: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - out: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: float, - window_size: list[int] | None, # -1 means infinite context window - alibi_slopes: list[float] | None, - block_table: torch.Tensor, - k_scale: torch.Tensor, - v_scale: torch.Tensor, - total_tokens: int = 0, - ) -> torch.Tensor: - return torch.empty( - q.shape[0], q.shape[1], v_cache.shape[-2], dtype=q.dtype, device=q.device - ) +@dataclass +class AiterFlashAttentionDecodeMetadata: + max_query_len: int + min_query_len: int + max_seq_len: int + query_start_loc: torch.Tensor - direct_register_custom_op( - "flash_attn_varlen_func", - flash_attn_varlen_func_impl, - ["out"], - flash_attn_varlen_func_fake, - dispatch_key=current_platform.dispatch_key, - ) -logger = init_logger(__name__) +@dataclass +class AiterFlashAttentionPrefillMetadata: + max_query_len: int + min_query_len: int + max_seq_len: int + query_start_loc: torch.Tensor + + +@dataclass +class AiterChunkContextMetadata: + workspace: torch.Tensor + cu_seq_lens_chunk: torch.Tensor + chunk_starts: torch.Tensor + token_to_batch: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + seq_lens: torch.Tensor + num_chunks: int + total_token_per_batch: list[int] + + +@dataclass +class AiterFlashAttentionChunkPrefillMetadata: + max_query_len: int + min_query_len: int + max_seq_len: int + query_start_loc: torch.Tensor + chunk_context_metadata: AiterChunkContextMetadata @dataclass @@ -248,7 +229,18 @@ class AiterFlashAttentionMetadata: seq_lens: torch.Tensor slot_mapping: torch.Tensor block_table: torch.Tensor - 
cu_seq_lens: torch.Tensor | None + + # prefill and deocde split + num_decodes: int + num_decode_tokens: int + num_prefills: int + num_prefill_tokens: int + num_extends: int + num_extend_tokens: int + + decode_metadata: AiterFlashAttentionDecodeMetadata | None + prefill_metadata: AiterFlashAttentionPrefillMetadata | None + extend_metadata: AiterFlashAttentionChunkPrefillMetadata | None # For cascade attention. use_cascade: bool @@ -260,6 +252,7 @@ class AiterFlashAttentionMetadataBuilder( AttentionMetadataBuilder[AiterFlashAttentionMetadata] ): cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + reorder_batch_threshold: int = 1 def __init__( self, @@ -285,6 +278,12 @@ def __init__( self.aot_sliding_window: tuple[int, int] | None = None self.total_tokens: int = 0 + self.extend_workspace = torch.empty( + [2, _CP_TOKENS_PER_ITER_ROCM, self.num_heads_kv, self.headdim], + dtype=self.model_config.dtype, + device=device, + ) + def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata ): @@ -302,42 +301,139 @@ def build( common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False, ) -> "AiterFlashAttentionMetadata": - num_actual_tokens = common_attn_metadata.num_actual_tokens - max_query_len = common_attn_metadata.max_query_len - max_seq_len = common_attn_metadata.max_seq_len - query_start_loc = common_attn_metadata.query_start_loc - seq_lens = common_attn_metadata.seq_lens - block_table_tensor = common_attn_metadata.block_table_tensor - slot_mapping = common_attn_metadata.slot_mapping - if max_query_len > 1: - # We pre-compute cumulative seq len needed for prefill attention - # here to avoid recomputing it for every layer + split_ret = split_decodes_prefills_and_extends( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold, + ) + + ( + num_decodes, + num_extends, + num_prefills, + num_decode_tokens, + num_extend_tokens, + num_prefill_tokens, + ) = split_ret + + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + + seq_lens = common_attn_metadata.seq_lens_cpu + + query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + + decode_metadata = None + if num_decodes > 0: + decode_metadata = AiterFlashAttentionDecodeMetadata( + max_query_len=query_lens_cpu[:num_decodes].max().item(), + min_query_len=query_lens_cpu[:num_decodes].min().item(), + max_seq_len=seq_lens[:num_decodes].max().item(), + query_start_loc=common_attn_metadata.query_start_loc[: num_decodes + 1], + ) + + prefill_metadata = None + if num_prefills > 0: + query_lens_for_prefill = query_lens_cpu[num_decodes + num_extends :] + query_start_loc_device = common_attn_metadata.query_start_loc[ + num_decodes + num_extends : + ] + prefill_metadata = AiterFlashAttentionPrefillMetadata( + max_query_len=query_lens_for_prefill.max().item(), + min_query_len=query_lens_for_prefill.min().item(), + max_seq_len=seq_lens[num_decodes + num_extends :].max().item(), + query_start_loc=query_start_loc_device - query_start_loc_device[0], + ) + + extend_metadata = None + if num_extends > 0: + num_extends_slice = slice(num_decodes, num_decodes + num_extends) + query_lens_for_extend = query_lens_cpu[num_extends_slice] + seq_lens_for_extend = common_attn_metadata.seq_lens_cpu[num_extends_slice] + computed_kv_lens = seq_lens_for_extend - query_lens_for_extend + + # allocate the equal amount of workspace for + # each chunk prefill request + max_context_chunk = _CP_TOKENS_PER_ITER_ROCM // num_extends + num_chunks = cdiv(computed_kv_lens.max().item(), max_context_chunk) + 
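+            # Per-iteration bookkeeping for the cached context of each extend
+            # request: chunk_starts/chunk_ends are [num_chunks, num_extends]
+            # slices of the already-computed KV, cu_seq_lens_cpu holds the
+            # per-chunk cumulative lengths, and token_to_batch maps each
+            # gathered token index back to its request for the gather kernel.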
+ chunk_starts = ( + torch.arange(num_chunks, dtype=torch.int32) + .unsqueeze(1) + .expand(-1, num_extends) + * max_context_chunk + ) + chunk_ends = torch.min( + computed_kv_lens.unsqueeze(0), chunk_starts + max_context_chunk + ) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp( + min=0 + ) # [num_chunks, num_extends] + cu_seq_lens_cpu = torch.zeros( + [num_chunks, num_extends + 1], dtype=torch.int32, pin_memory=True + ) + torch.cumsum( + chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32 + ) + max_cum_tokens = cu_seq_lens_cpu[:, -1].max().item() + + range_idx = torch.arange(max_cum_tokens, dtype=torch.int32)[None, None, :] + idx_to_batch_tensor = range_idx == cu_seq_lens_cpu[:, 1:][:, :, None] + idx_to_batch_tensor = idx_to_batch_tensor.sum( + dim=1 + ) # [num_chunks, max_cum_tokens] + token_to_batch_tensor = torch.cumsum(idx_to_batch_tensor, dim=1) + + chunk_context_metadata = AiterChunkContextMetadata( + workspace=self.extend_workspace, + cu_seq_lens_chunk=cu_seq_lens_cpu.to(self.device, non_blocking=True), + chunk_starts=chunk_starts.to(self.device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + token_to_batch=token_to_batch_tensor.to(self.device, non_blocking=True), + num_chunks=num_chunks, + total_token_per_batch=cu_seq_lens_cpu[:, -1].tolist(), + ) + + query_start_loc_device = common_attn_metadata.query_start_loc[ + num_decodes : num_decodes + num_extends + 1 + ] + seq_lens_device = common_attn_metadata.seq_lens[num_extends_slice] cu_seq_lens = torch.zeros( - seq_lens.shape[0] + 1, dtype=torch.int32, device=seq_lens.device + num_extends + 1, dtype=torch.int32, device=seq_lens_device.device + ) + torch.cumsum( + seq_lens_device, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:] + ) + extend_metadata = AiterFlashAttentionChunkPrefillMetadata( + max_query_len=query_lens_for_extend.max().item(), + min_query_len=query_lens_for_extend.min().item(), + max_seq_len=seq_lens[num_extends_slice].max().item(), + query_start_loc=query_start_loc_device - query_start_loc_device[0], + chunk_context_metadata=chunk_context_metadata, ) - torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:]) - num_actual_kv_tokens = int(cu_seq_lens[-1].item()) - else: - cu_seq_lens = None - num_actual_kv_tokens = 0 - def schedule( - batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal - ): - return None + num_actual_kv_tokens = torch.sum(seq_lens).item() use_cascade = common_prefix_len > 0 attn_metadata = AiterFlashAttentionMetadata( - num_actual_tokens=num_actual_tokens, + num_actual_tokens=common_attn_metadata.num_actual_tokens, num_actual_kv_tokens=num_actual_kv_tokens, - max_query_len=max_query_len, - query_start_loc=query_start_loc, - max_seq_len=max_seq_len, - seq_lens=seq_lens, - block_table=block_table_tensor, - slot_mapping=slot_mapping, - cu_seq_lens=cu_seq_lens, + max_query_len=common_attn_metadata.max_query_len, + query_start_loc=common_attn_metadata.query_start_loc, + max_seq_len=common_attn_metadata.max_seq_len, + seq_lens=common_attn_metadata.seq_lens, + block_table=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_extends=num_extends, + num_extend_tokens=num_extend_tokens, + decode_metadata=decode_metadata, + prefill_metadata=prefill_metadata, + 
extend_metadata=extend_metadata, use_cascade=use_cascade, common_prefix_len=common_prefix_len, total_tokens=self.total_tokens, @@ -401,6 +497,7 @@ def get_kv_cache_shape( ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) @@ -449,6 +546,110 @@ def __init__( "FlashAttentionImpl" ) + def extend_forward( + self, + attn_metadata: AiterFlashAttentionMetadata, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + output: torch.Tensor, + cu_seqlens_q: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + min_seqlen_q: int, + block_table: torch.Tensor, + slot_mapping: torch.Tensor, + k_scale: float, + v_scale: float, + ): + out, lse = aiter.flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_q, + min_seqlen_q=min_seqlen_q, + dropout_p=0.0, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + return_lse=True, + ) + assert attn_metadata.extend_metadata is not None + chunk_context_metadata = attn_metadata.extend_metadata.chunk_context_metadata + num_chunks = chunk_context_metadata.num_chunks + workspace = chunk_context_metadata.workspace + cu_seqlens_kv = chunk_context_metadata.cu_seq_lens_chunk + max_seqlens = chunk_context_metadata.max_seq_lens + chunk_starts = chunk_context_metadata.chunk_starts + token_to_batch = chunk_context_metadata.token_to_batch + total_token_per_batch = chunk_context_metadata.total_token_per_batch + key_fetched, value_fetched = workspace[0], workspace[1] + chunked_output = None + chunked_lse = None + for chunk_idx in range(num_chunks): + cp_mha_gather_cache( + key_cache=key_cache, + value_cache=value_cache, + key=key_fetched, + value=value_fetched, + block_tables=block_table, + k_scales=k_scale, + v_scales=v_scale, + cu_seqlens_kv=cu_seqlens_kv[chunk_idx], + token_to_batch=token_to_batch[chunk_idx], + seq_starts=chunk_starts[chunk_idx], + dequant=False, + kv_cache_layout="NHD", + total_tokens=total_token_per_batch[chunk_idx], + ) + + suf_out, suf_lse = aiter.flash_attn_varlen_func( + q=query, + k=key_fetched, + v=value_fetched, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_kv[chunk_idx], + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlens[chunk_idx], + min_seqlen_q=min_seqlen_q, + dropout_p=0.0, + softmax_scale=self.scale, + causal=False, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + return_lse=True, + ) + if chunked_output is None: + chunked_output = suf_out + chunked_lse = suf_lse + else: + tmp_output = torch.empty_like(out) + tmp_lse = torch.empty_like(lse) + merge_attn_states( + output=tmp_output, + output_lse=tmp_lse, + prefix_output=chunked_output, + prefix_lse=chunked_lse, + suffix_output=suf_out, + suffix_lse=suf_lse, + ) + chunked_output = tmp_output + chunked_lse = tmp_lse + + merge_attn_states( + output=output, + prefix_output=chunked_output, + prefix_lse=chunked_lse, + suffix_output=out, + suffix_lse=lse, + ) + def forward( self, layer: torch.nn.Module, @@ -488,24 +689,25 @@ def forward( return output.fill_(0) # IMPORTANT! - # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in - # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead - # in this method. 
For example, `view` and `slice` (or `[:n]`) operations - # are surprisingly slow even in the case they do not invoke any GPU ops. + # NOTE(woosuk): With piece-wise CUDA graphs, this method is + # executed in eager-mode PyTorch. Thus, we need to be careful + # about any CPU overhead in this method. For example, `view` + # and `slice` (or `[:n]`) operations are surprisingly slow even + # in the case they do not invoke any GPU ops. # Minimize the PyTorch ops in this method as much as possible. # Whenever making a change in this method, please benchmark the # performance to make sure it does not introduce any overhead. - num_actual_tokens = attn_metadata.num_actual_tokens key_cache, value_cache = kv_cache.unbind(0) if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. - # NOTE(woosuk): Here, key and value are padded while slot_mapping is - # not padded. However, we don't need to do key[:num_actual_tokens] - # and value[:num_actual_tokens] because the reshape_and_cache_flash - # op uses the slot_mapping's shape to determine the number of - # actual tokens. + # NOTE(woosuk): Here, key and value are padded while slot_mapping + # is not padded. However, we don't need to do + # key[:num_actual_tokens] and value[:num_actual_tokens] because + # the reshape_and_cache_flash op uses the slot_mapping's shape + # to determine the number of actual tokens. + torch.ops._C_cache_ops.reshape_and_cache_flash( key, value, @@ -521,69 +723,118 @@ def forward( key_cache = key_cache.view(current_platform.fp8_dtype()) value_cache = value_cache.view(current_platform.fp8_dtype()) + # decode:extend:prefill + query = query[:num_actual_tokens] + key = key[:num_actual_tokens] + value = value[:num_actual_tokens] + + output_actual_tokens = output[:num_actual_tokens] + + num_decodes = attn_metadata.num_decodes + num_prefills = attn_metadata.num_prefills + num_extends = attn_metadata.num_extends + + num_decode_tokens = attn_metadata.num_decode_tokens + num_extend_tokens = attn_metadata.num_extend_tokens if not attn_metadata.use_cascade: - cu_seqlens_q = attn_metadata.query_start_loc - seqused_k = attn_metadata.seq_lens - max_seqlen_q = attn_metadata.max_query_len - max_seqlen_k = attn_metadata.max_seq_len - block_table = attn_metadata.block_table - - if max_seqlen_q > 1: - torch.ops.vllm.flash_attn_varlen_func( - query[:num_actual_tokens], - key_cache, - value_cache, - out=output[:num_actual_tokens], - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, + # calculate for pure prefills + if num_prefills > 0: + assert attn_metadata.prefill_metadata is not None + + prefill_query = query[num_decode_tokens + num_extend_tokens :] + prefill_key = key[num_decode_tokens + num_extend_tokens :] + prefill_value = value[num_decode_tokens + num_extend_tokens :] + + aiter.flash_attn_varlen_func( + q=prefill_query, + k=prefill_key, + v=prefill_value, + cu_seqlens_q=attn_metadata.prefill_metadata.query_start_loc, + cu_seqlens_k=attn_metadata.prefill_metadata.query_start_loc, + max_seqlen_q=attn_metadata.prefill_metadata.max_query_len, + max_seqlen_k=attn_metadata.prefill_metadata.max_seq_len, + min_seqlen_q=attn_metadata.prefill_metadata.min_query_len, + dropout_p=0.0, softmax_scale=self.scale, - alibi_slopes=self.alibi_slopes, + causal=True, window_size=self.sliding_window, - block_table=block_table, - cu_seqlens_k=attn_metadata.cu_seq_lens, + alibi_slopes=self.alibi_slopes, + 
out=output_actual_tokens[num_decode_tokens + num_extend_tokens :], + ) + + # calculate for extends + if num_extends > 0: + assert attn_metadata.extend_metadata is not None + extend_tokens_slice = slice( + num_decode_tokens, num_decode_tokens + num_extend_tokens + ) + extend_querys = query[extend_tokens_slice] + extend_keys = key[extend_tokens_slice] + extend_values = value[extend_tokens_slice] + extend_outputs = output[extend_tokens_slice] + self.extend_forward( + attn_metadata=attn_metadata, + query=extend_querys, + key=extend_keys, + value=extend_values, + key_cache=key_cache, + value_cache=value_cache, + output=extend_outputs, + cu_seqlens_q=attn_metadata.extend_metadata.query_start_loc, + max_seqlen_q=attn_metadata.extend_metadata.max_query_len, + max_seqlen_k=attn_metadata.extend_metadata.max_seq_len, + min_seqlen_q=attn_metadata.extend_metadata.min_query_len, + block_table=attn_metadata.block_table[ + num_decodes : num_decodes + num_extends + ], + slot_mapping=attn_metadata.slot_mapping[ + num_decodes : num_decodes + num_extends + ], k_scale=layer._k_scale, v_scale=layer._v_scale, - total_tokens=attn_metadata.num_actual_kv_tokens, ) - _, num_heads, head_size = query.shape - nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8 - num_seqs = seqused_k.shape[0] - max_num_partitions = ( - max_seqlen_k + _PARTITION_SIZE_ROCM - 1 - ) // _PARTITION_SIZE_ROCM - - workspace_buffer = torch.empty( - (num_seqs * num_heads * max_num_partitions * head_size) - * nbytes_per_qo_elem - + 2 * (num_seqs * num_heads * max_num_partitions) * 4, - dtype=torch.uint8, - device=output.device, - ) + # calculate for decodes + if num_decodes > 0: + assert attn_metadata.decode_metadata is not None + _, num_heads, head_size = query.shape + nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8 + num_seqs = attn_metadata.seq_lens.shape[0] + max_num_partitions = ( + attn_metadata.max_seq_len + _PARTITION_SIZE_ROCM - 1 + ) // _PARTITION_SIZE_ROCM + + workspace_buffer = torch.empty( + (num_seqs * num_heads * max_num_partitions * head_size) + * nbytes_per_qo_elem + + 2 * (num_seqs * num_heads * max_num_partitions) * 4, + dtype=torch.uint8, + device=output.device, + ) - torch.ops.aiter.paged_attention_v1( - output[:num_actual_tokens], - workspace_buffer, - query[:num_actual_tokens], - key_cache, - value_cache, - self.scale, - block_table, - cu_seqlens_q, - seqused_k, - max_seqlen_k, - self.alibi_slopes, - self.kv_cache_dtype, - "NHD", - self.logits_soft_cap, - layer._k_scale, - layer._v_scale, - None, - _PARTITION_SIZE_ROCM, - ) - return output + torch.ops.aiter.paged_attention_v1( + output[:num_decode_tokens], + workspace_buffer, + query[:num_decode_tokens], + key_cache, + value_cache, + self.scale, + attn_metadata.block_table[:num_decodes], + attn_metadata.query_start_loc[:num_decodes], + attn_metadata.seq_lens[:num_decodes], + attn_metadata.max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + "NHD", + self.logits_soft_cap, + layer._k_scale, + layer._v_scale, + None, + _PARTITION_SIZE_ROCM, + ) else: raise NotImplementedError( "Cascade attention is not implemented for ROCM AITER" ) + + return output diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 07d62e9849e0..6c750d3448c4 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -728,6 +728,73 @@ def subclass_attention_backend( ) +def split_decodes_prefills_and_extends( + common_attn_metadata: CommonAttentionMetadata, + decode_threshold: int = 1, +) -> tuple[int, int, int, int, 
int, int]:
+    """
+    Assuming a reordered batch (decodes first, then extends, then pure
+    prefills), finds the boundaries between the decode, extend (prefill with
+    an already-computed prefix), and pure prefill requests.
+
+    Args:
+        common_attn_metadata: CommonAttentionMetadata object containing the
+            batch metadata.
+        decode_threshold: The maximum query length to be considered a decode.
+
+    Returns:
+        num_decodes: The number of decode requests.
+        num_extends: The number of extend requests.
+        num_prefills: The number of prefill requests.
+        num_decode_tokens: The number of tokens in the decode requests.
+        num_extend_tokens: The number of tokens in the extend requests.
+        num_prefill_tokens: The number of tokens in the prefill requests.
+    """
+    max_query_len = common_attn_metadata.max_query_len
+    num_reqs = common_attn_metadata.num_reqs
+    num_tokens = common_attn_metadata.num_actual_tokens
+    query_start_loc = common_attn_metadata.query_start_loc_cpu
+    seq_lens = common_attn_metadata.seq_lens_cpu
+
+    if max_query_len <= decode_threshold:
+        return num_reqs, 0, 0, num_tokens, 0, 0
+
+    query_lens = query_start_loc[1:] - query_start_loc[:-1]
+    is_prefill_or_extend = query_lens > decode_threshold
+    is_prefill = (seq_lens == query_lens) & is_prefill_or_extend
+    first_extend = is_prefill_or_extend.int().argmax(dim=-1).item()
+    first_prefill = is_prefill.int().argmax(dim=-1).item()
+    num_decodes = first_extend
+    num_decode_tokens = query_start_loc[first_extend].item()
+    if not torch.any(is_prefill_or_extend):
+        return (num_decodes, 0, 0, num_decode_tokens, 0, 0)
+
+    num_prefills_or_extends = num_reqs - num_decodes
+    num_prefill_or_extend_tokens = num_tokens - num_decode_tokens
+    if not torch.any(is_prefill):
+        return (
+            num_decodes,
+            num_prefills_or_extends,
+            0,
+            num_decode_tokens,
+            num_prefill_or_extend_tokens,
+            0,
+        )
+
+    num_extends = first_prefill - num_decodes
+    num_prefills = num_reqs - first_prefill
+
+    num_prefill_tokens = num_tokens - query_start_loc[first_prefill]
+    num_extend_tokens = num_prefill_or_extend_tokens - num_prefill_tokens
+    return (
+        num_decodes,
+        num_extends,
+        num_prefills,
+        num_decode_tokens,
+        num_extend_tokens,
+        num_prefill_tokens,
+    )
+
+
 def split_decodes_and_prefills(
     common_attn_metadata: CommonAttentionMetadata,
     decode_threshold: int = 1,
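
A minimal standalone sketch (not part of the diff; the tensors and numbers below are made up) that mirrors the classification rule behind split_decodes_prefills_and_extends: requests whose query length is at most decode_threshold are decodes, requests whose query length equals their sequence length are pure prefills, and the rest (new tokens on top of an already-computed prefix) are extends.

    import torch

    decode_threshold = 1
    # Reordered toy batch: 2 decodes, 2 extends, 1 pure prefill.
    query_start_loc = torch.tensor([0, 1, 2, 10, 30, 80])   # cumulative new-token offsets
    seq_lens = torch.tensor([7, 9, 24, 52, 50])              # total (cached + new) lengths

    query_lens = query_start_loc[1:] - query_start_loc[:-1]              # [1, 1, 8, 20, 50]
    is_prefill_or_extend = query_lens > decode_threshold                 # [F, F, T, T, T]
    is_pure_prefill = (seq_lens == query_lens) & is_prefill_or_extend    # [F, F, F, F, T]

    first_extend = int(is_prefill_or_extend.int().argmax())              # 2
    first_prefill = int(is_pure_prefill.int().argmax())                  # 4
    num_decodes = first_extend                                           # 2
    num_extends = first_prefill - first_extend                           # 2
    num_prefills = seq_lens.numel() - first_prefill                      # 1
    num_decode_tokens = int(query_start_loc[first_extend])               # 2
    num_prefill_tokens = int(query_start_loc[-1] - query_start_loc[first_prefill])         # 50
    num_extend_tokens = int(query_start_loc[-1]) - num_decode_tokens - num_prefill_tokens  # 28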
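
The chunked-context (extend) path bounds per-iteration KV traffic by _CP_TOKENS_PER_ITER_ROCM (32K tokens), split evenly across the extend requests. A rough sizing sketch with illustrative numbers (not taken from the diff):

    from math import ceil

    _CP_TOKENS_PER_ITER_ROCM = 32 * 1024
    num_extends = 4                # extend requests in the batch
    max_computed_kv = 40_000       # longest cached prefix among them

    max_context_chunk = _CP_TOKENS_PER_ITER_ROCM // num_extends   # 8192 cached tokens per request per iteration
    num_chunks = ceil(max_computed_kv / max_context_chunk)        # 5 gather-and-attend iterations

Each iteration gathers at most 32K cached K/V tokens into the preallocated [2, _CP_TOKENS_PER_ITER_ROCM, num_kv_heads, head_dim] workspace via cp_mha_gather_cache, runs non-causal flash attention of the extend queries against that slice, and folds the partial outputs together (and finally with the causal attention over the new tokens) through merge_attn_states using the returned log-sum-exp values.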