diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 929c3b6a4906..09fe36beab0a 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -364,11 +364,21 @@ def forward(
                     attn_metadata = attn_metadata[self.layer_name]
                 self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                 self.impl.forward(
-                    self, query, key, value, self_kv_cache, attn_metadata, output=output
+                    self,
+                    query,
+                    key,
+                    value,
+                    self_kv_cache,
+                    attn_metadata,
+                    output=output,
                 )
             else:
                 torch.ops.vllm.unified_attention_with_output(
-                    query, key, value, output, self.layer_name
+                    query,
+                    key,
+                    value,
+                    output,
+                    self.layer_name,
                 )
             return output.view(-1, hidden_size)
         else:
@@ -689,6 +699,7 @@ def forward(
         q: torch.Tensor,
         kv_c_normed: torch.Tensor,
         k_pe: torch.Tensor,
+        positions: torch.Tensor | None = None,
         output_shape: torch.Size | None = None,
     ) -> torch.Tensor:
         if self.use_direct_call:
@@ -713,12 +724,19 @@ def forward(
                     k_pe,
                     self_kv_cache,
                     attn_metadata,
+                    positions=positions,
                     output=output,
                 )
                 return output
             else:
                 return self.impl.forward(
-                    self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
+                    self,
+                    q,
+                    kv_c_normed,
+                    k_pe,
+                    self_kv_cache,
+                    attn_metadata,
+                    positions=positions,
                 )
         else:
             if self.attn_backend.accept_output_buffer:
@@ -729,6 +747,7 @@ def forward(
                     k_pe,
                     output,
                     self.layer_name,
+                    positions=positions,
                 )
                 return output
             else:
@@ -889,6 +908,7 @@ def unified_attention_with_output(
     value: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
+    positions: torch.Tensor | None = None,
     output_scale: torch.Tensor | None = None,
     output_block_scale: torch.Tensor | None = None,
 ) -> None:
@@ -906,6 +926,7 @@ def unified_attention_with_output(
         value,
         kv_cache,
         attn_metadata,
+        positions=positions,
         output=output,
         output_scale=output_scale,
         output_block_scale=output_block_scale,
@@ -920,6 +941,7 @@ def unified_attention_with_output_fake(
     value: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
+    positions: torch.Tensor | None = None,
     output_scale: torch.Tensor | None = None,
     output_block_scale: torch.Tensor | None = None,
 ) -> None:
@@ -940,6 +962,7 @@ def unified_mla_attention(
     q: torch.Tensor,
     kv_c_normed: torch.Tensor,
     k_pe: torch.Tensor,
     layer_name: str,
+    positions: torch.Tensor | None = None,
 ) -> torch.Tensor:
     wait_for_kv_layer_from_connector(layer_name)
@@ -949,7 +972,9 @@ def unified_mla_attention(
         attn_metadata = attn_metadata[layer_name]
     self: MLAAttention = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
-    output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)
+    output = self.impl.forward(
+        self, q, kv_c_normed, k_pe, kv_cache, attn_metadata, positions=positions
+    )
 
     maybe_save_kv_layer_to_connector(layer_name, kv_cache)
     return output
@@ -960,6 +985,7 @@ def unified_mla_attention_fake(
     kv_c_normed: torch.Tensor,
     k_pe: torch.Tensor,
     layer_name: str,
+    positions: torch.Tensor | None = None,
 ) -> torch.Tensor:
     return torch.empty_like(q).contiguous()
 
@@ -979,6 +1005,7 @@ def unified_mla_attention_with_output(
     k_pe: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
+    positions: torch.Tensor | None = None,
     output_scale: torch.Tensor | None = None,
     output_block_scale: torch.Tensor | None = None,
 ) -> None:
@@ -996,6 +1023,7 @@ def unified_mla_attention_with_output(
         k_pe,
         kv_cache,
         attn_metadata,
+        positions=positions,
         output=output,
         output_scale=output_scale,
         output_block_scale=output_block_scale,
diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py
index 4c81162d7d2b..cec6cf44b464 100644
--- a/vllm/model_executor/layers/mla.py
+++ b/vllm/model_executor/layers/mla.py
@@ -4,10 +4,20 @@
 
 import torch
 
+from vllm import envs
 from vllm.attention.layer import MLAAttention
 from vllm.config import CacheConfig
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.platforms import current_platform
+
+
+def is_aiter_mla_rope_flush_cache_fusion_enabled() -> bool:
+    return (
+        current_platform.is_rocm()
+        and envs.VLLM_ROCM_USE_AITER
+        and envs.VLLM_ROCM_USE_AITER_MLA
+    )
 
 
 @dataclass
@@ -104,6 +114,7 @@ def __init__(
         )
 
         self.prefix = prefix
+        self.mla_attn.impl.rotary_emb = self.rotary_emb
 
     def forward_native(
         self,
@@ -147,9 +158,10 @@ def forward_native(
 
         # Add head dim of 1 to k_pe
         k_pe = k_pe.unsqueeze(1)
-        q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb(
-            positions, q[..., self.qk_nope_head_dim :], k_pe
-        )
+        if not is_aiter_mla_rope_flush_cache_fusion_enabled():
+            q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb(
+                positions, q[..., self.qk_nope_head_dim :], k_pe
+            )
 
         if self.indexer and self.is_sparse:
             _topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb)
@@ -158,6 +170,9 @@ def forward_native(
             q,
             kv_c_normed,
             k_pe,
+            positions=(
+                positions if is_aiter_mla_rope_flush_cache_fusion_enabled() else None
+            ),
             output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim),
         )
         return self.o_proj(attn_out)[0]
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index da56b5c9d3d2..17e814f5930c 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -193,7 +193,6 @@
 from typing import ClassVar, Generic, TypeVar
 
 import torch
-from tqdm import tqdm
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
@@ -208,7 +207,7 @@
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.config import VllmConfig, get_current_vllm_config
-from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
+from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -248,30 +247,6 @@
     flashinfer_available = False
 
 
-def is_rocm_aiter_fp8bmm_enabled() -> bool:
-    return (
-        current_platform.is_rocm()
-        and envs.VLLM_ROCM_USE_AITER_FP8BMM
-        and envs.VLLM_ROCM_USE_AITER
-    )
-
-
-if is_rocm_aiter_fp8bmm_enabled():
-    from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (  # noqa: E501
-        batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm,  # noqa: E501
-    )
-
-    def dynamic_per_batched_tensor_quant(
-        x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
-    ):
-        DTYPE_MAX = torch.finfo(dtype).max
-        min_val, max_val = x.aminmax()
-        amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
-        scale = DTYPE_MAX / amax
-        x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
-        return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
-
-
 logger = init_logger(__name__)
 
 CUDNN_WORKSPACE_SIZE = 12800
@@ -1064,7 +1039,7 @@ def __init__(
         self.indexer = indexer
         self.q_pad_num_heads = q_pad_num_heads
 
-    def process_weights_after_loading(self, act_dtype: torch.dtype):
+    def get_and_maybe_dequant_weights(self, layer: LinearBase, act_dtype: torch.dtype):
         def get_layer_weight(layer):
             WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
             for attr in WEIGHT_NAMES:
@@ -1074,24 +1049,26 @@ def get_layer_weight(layer):
                 f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
             )
 
-        def get_and_maybe_dequant_weights(layer: LinearBase):
-            if not isinstance(layer.quant_method, UnquantizedLinearMethod):
-                # NOTE: This should only be used offline, since it's O(N^3)
-                eye = torch.eye(
-                    layer.input_size_per_partition,
-                    dtype=act_dtype,
-                    device=get_layer_weight(layer).device,
-                )
-                dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
-                del eye
-                # standardize to (output, input)
-                return dequant_weights.T
-            return layer.weight
+        if not isinstance(layer.quant_method, UnquantizedLinearMethod):
+            # NOTE: This should only be used offline, since it's O(N^3)
+            eye = torch.eye(
+                layer.input_size_per_partition,
+                dtype=act_dtype,
+                device=get_layer_weight(layer).device,
+            )
+            dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
+            del eye
+            # standardize to (output, input)
+            return dequant_weights.T
+        return layer.weight
 
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
         # we currently do not have quantized bmm's which are needed for
         # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
         # the bmm's in 16-bit, the extra memory overhead of this is fairly low
-        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
+        kv_b_proj_weight = self.get_and_maybe_dequant_weights(
+            self.kv_b_proj, act_dtype
+        ).T
         assert kv_b_proj_weight.shape == (
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -1112,79 +1089,27 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
             [self.qk_nope_head_dim, self.v_head_dim], dim=-1
         )
 
-        if is_rocm_aiter_fp8bmm_enabled():
-            W_K = W_UK.transpose(0, 1)  # 16 512 128
-            W_V = W_UV.permute(1, 2, 0)  # 16 128 512
-            self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
-                W_K, dtype=current_platform.fp8_dtype()
-            )
-            self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
-                W_V, dtype=current_platform.fp8_dtype()
-            )
-
-            # The kernel operates on non-padded inputs. Hence, pre-compiling
-            # triton kernel to avoid runtime compilation for unseen batch sizes
-            # Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
-            # On DS-R1, this step adds roughly 50s to the model loading time.
-            max_batch_size = 1024  # [ToDo] Find the optimal upper limit
-            pre_compilation_list = list(range(1, max_batch_size + 1))
-            if is_global_first_rank():
-                pre_compilation_list = tqdm(
-                    pre_compilation_list,
-                    desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
-                    total=max_batch_size,
-                )
-
-            for m in pre_compilation_list:
-                x = torch.empty(
-                    (self.W_K.shape[0], m, self.W_K.shape[2]),
-                    dtype=torch.bfloat16,
-                    device=self.W_K.device,
-                )
-                aiter_triton_fp8_bmm(
-                    x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
-                )
-
-                x = torch.empty(
-                    (self.W_V.shape[0], m, self.W_V.shape[2]),
-                    dtype=torch.bfloat16,
-                    device=self.W_V.device,
-                )
-                aiter_triton_fp8_bmm(
-                    x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
-                )
-        else:
-            # Convert from (L, N, V) to (N, L, V)
-            self.W_UV = W_UV.transpose(0, 1)
-            # Convert from (L, N, P) to (N, P, L)
-            self.W_UK_T = W_UK.permute(1, 2, 0)
+        # Convert from (L, N, V) to (N, L, V)
+        self.W_UV = W_UV.transpose(0, 1)
+        # Convert from (L, N, P) to (N, P, L)
+        self.W_UK_T = W_UK.permute(1, 2, 0)
 
     def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
-        if is_rocm_aiter_fp8bmm_enabled():
-            # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
-            x = aiter_triton_fp8_bmm(
-                x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
-            )
-            # Convert from (B, N, V) to (B, N * V)
-            x = x.reshape(-1, self.num_heads * self.v_head_dim)
-            # Copy result
-            out.copy_(x)
-        else:
-            # Convert from (B, N * V) to (N, B, V)
-            out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
+        # Convert from (B, N * V) to (N, B, V)
+        out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
 
-            # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
-            torch.bmm(x, self.W_UV, out=out)  # Reuse "out" to make it "hot"
+        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
+        torch.bmm(x, self.W_UV, out=out)  # Reuse "out" to make it "hot"
 
-            # Convert from (N, B, V) to (B, N * V)
-            out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        # Convert from (N, B, V) to (B, N * V)
+        out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
 
-            # Adjust output buffer shape back to the original (B, N * V)
-            N, B, V = out.shape
-            out.resize_((B, N * V))
-            out.copy_(out_new)  # Copy result
+        # Adjust output buffer shape back to the original (B, N * V)
+        N, B, V = out.shape
+        out.resize_((B, N * V))
+        out.copy_(out_new)  # Copy result
 
 
 class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
@@ -1435,52 +1360,10 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
             [self.qk_nope_head_dim, self.v_head_dim], dim=-1
         )
 
-        if is_rocm_aiter_fp8bmm_enabled():
-            W_K = W_UK.transpose(0, 1)  # 16 512 128
-            W_V = W_UV.permute(1, 2, 0)  # 16 128 512
-            self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
-                W_K, dtype=current_platform.fp8_dtype()
-            )
-            self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
-                W_V, dtype=current_platform.fp8_dtype()
-            )
-
-            # The kernel operates on non-padded inputs. Hence, pre-compiling
-            # triton kernel to avoid runtime compilation for unseen batch sizes
-            # Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
-            # On DS-R1, this step adds roughly 50s to the model loading time.
-            max_batch_size = 1024  # [ToDo] Find the optimal upper limit
-            pre_compilation_list = list(range(1, max_batch_size + 1))
-            if is_global_first_rank():
-                pre_compilation_list = tqdm(
-                    pre_compilation_list,
-                    desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
-                    total=max_batch_size,
-                )
-
-            for m in pre_compilation_list:
-                x = torch.empty(
-                    (self.W_K.shape[0], m, self.W_K.shape[2]),
-                    dtype=torch.bfloat16,
-                    device=self.W_K.device,
-                )
-                aiter_triton_fp8_bmm(
-                    x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
-                )
-
-                x = torch.empty(
-                    (self.W_V.shape[0], m, self.W_V.shape[2]),
-                    dtype=torch.bfloat16,
-                    device=self.W_V.device,
-                )
-                aiter_triton_fp8_bmm(
-                    x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
-                )
-        else:
-            # Convert from (L, N, V) to (N, L, V)
-            self.W_UV = W_UV.transpose(0, 1)
-            # Convert from (L, N, P) to (N, P, L)
-            self.W_UK_T = W_UK.permute(1, 2, 0)
+        # Convert from (L, N, V) to (N, L, V)
+        self.W_UV = W_UV.transpose(0, 1)
+        # Convert from (L, N, P) to (N, P, L)
+        self.W_UK_T = W_UK.permute(1, 2, 0)
 
     def _compute_prefill_context(
         self,
@@ -1730,6 +1613,7 @@ def forward(
         k_pe: torch.Tensor,  # value in unified attn
         kv_cache: torch.Tensor,
         attn_metadata: M,
+        positions: torch.Tensor | None = None,
        output: torch.Tensor | None = None,
         output_scale: torch.Tensor | None = None,
         output_block_scale: torch.Tensor | None = None,
@@ -1830,32 +1714,20 @@ def forward(
                 decode_pe_padded.copy_(decode_q_pe)
                 decode_q_pe = decode_pe_padded
 
-            if is_rocm_aiter_fp8bmm_enabled():
-                # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
-                decode_ql_nope = aiter_triton_fp8_bmm(
-                    decode_q_nope,
-                    self.W_K,
-                    self.W_K_scale,
-                    group_size=128,
-                    transpose_bm=True,
-                )
-            else:
-                # Pads the head_dim if necessary (for the underlying kernel)
-                N, B, P = decode_q_nope.shape
-                _, _, L = self.W_UK_T.shape
-                if self.q_pad_num_heads is not None:
-                    decode_ql_nope = decode_q_nope.new_empty(
-                        (self.q_pad_num_heads, B, L)
-                    )
-                    decode_ql_nope.resize_((N, B, L))
+            # Pads the head_dim if necessary (for the underlying kernel)
+            N, B, P = decode_q_nope.shape
+            _, _, L = self.W_UK_T.shape
+            if self.q_pad_num_heads is not None:
+                decode_ql_nope = decode_q_nope.new_empty((self.q_pad_num_heads, B, L))
+                decode_ql_nope.resize_((N, B, L))
 
-                else:
-                    decode_ql_nope = decode_q_nope.new_empty((N, B, L))
+            else:
+                decode_ql_nope = decode_q_nope.new_empty((N, B, L))
 
-                # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
-                torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)
-                # Convert from (N, B, L) to (B, N, L)
-                decode_ql_nope = decode_ql_nope.transpose(0, 1)
+            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+            torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)
+            # Convert from (N, B, L) to (B, N, L)
+            decode_ql_nope = decode_ql_nope.transpose(0, 1)
 
             if fp8_attention:
                 ql_nope_shape = decode_ql_nope.shape
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index d935c02243bd..61a37de14d01 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -2,14 +2,20 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from dataclasses import dataclass
-from typing import ClassVar
+from typing import ClassVar, TypeVar
 
 import torch
+from tqdm import tqdm
 
 import vllm.envs as envs
+from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import AttentionLayer
+from vllm.attention.ops.common import cp_lse_ag_out_rs
 from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
 from vllm.config import VllmConfig
+from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
+from vllm.model_executor.layers.mla import is_aiter_mla_rope_flush_cache_fusion_enabled
+from vllm.platforms import current_platform
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.mla.common import (
     MLACommonBackend,
@@ -26,6 +32,34 @@ def is_aiter_mla_enabled() -> bool:
     return envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MLA
 
 
+def is_rocm_aiter_fp8bmm_enabled() -> bool:
+    return (
+        current_platform.is_rocm()
+        and envs.VLLM_ROCM_USE_AITER_FP8BMM
+        and envs.VLLM_ROCM_USE_AITER
+    )
+
+
+if is_rocm_aiter_fp8bmm_enabled():
+    from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (  # noqa: E501
+        batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm,  # noqa: E501
+    )
+
+    def dynamic_per_batched_tensor_quant(
+        x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
+    ):
+        DTYPE_MAX = torch.finfo(dtype).max
+        min_val, max_val = x.aminmax()
+        amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
+        scale = DTYPE_MAX / amax
+        x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
+        return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+
+
+if is_aiter_mla_rope_flush_cache_fusion_enabled():
+    from aiter.ops.triton.fused_kv_cache import fused_qk_rope_cat_and_cache_mla
+
+
 class AiterMLABackend(MLACommonBackend):
     @staticmethod
     def get_name() -> str:
@@ -61,6 +95,9 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
     pass
 
 
+M = TypeVar("M", bound=AiterMLAMetadata)
+
+
 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
     # TODO(luka, lucas): audit this as part of:
     # https://github.com/vllm-project/vllm/issues/22945
@@ -222,6 +259,7 @@ def __init__(
                 "alibi_slopes, sliding_window, logits_soft_cap"
             )
 
+        self.dcp_world_size: int | None
         from aiter import flash_attn_varlen_func
 
         self.flash_attn_varlen_func = flash_attn_varlen_func
@@ -240,6 +278,107 @@ def _flash_attn_varlen_diff_headdims(
 
         return output
 
+    def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
+        # Convert from (B, N, L) to (N, B, L)
+        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
+        if is_rocm_aiter_fp8bmm_enabled():
+            # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
+            x = aiter_triton_fp8_bmm(
+                x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
+            )
+            # Convert from (B, N, V) to (B, N * V)
+            x = x.reshape(-1, self.num_heads * self.v_head_dim)
+            # Copy result
+            out.copy_(x)
+        else:
+            # Convert from (B, N * V) to (N, B, V)
+            out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
+
+            # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
+            torch.bmm(x, self.W_UV, out=out)  # Reuse "out" to make it "hot"
+
+            # Convert from (N, B, V) to (B, N * V)
+            out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+
+            # Adjust output buffer shape back to the original (B, N * V)
+            N, B, V = out.shape
+            out.resize_((B, N * V))
+            out.copy_(out_new)  # Copy result
+
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+        # we currently do not have quantized bmm's which are needed for
+        # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
+        # the bmm's in 16-bit, the extra memory overhead of this is fairly low
+        kv_b_proj_weight = self.get_and_maybe_dequant_weights(
+            self.kv_b_proj, act_dtype
+        ).T
+        assert kv_b_proj_weight.shape == (
+            self.kv_lora_rank,
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+        ), (
+            f"{kv_b_proj_weight.shape=}, "
+            f"{self.kv_lora_rank=}, "
+            f"{self.num_heads=}, "
+            f"{self.qk_nope_head_dim=}, "
+            f"{self.v_head_dim=}"
+        )
+        kv_b_proj_weight = kv_b_proj_weight.view(
+            self.kv_lora_rank,
+            self.num_heads,
+            self.qk_nope_head_dim + self.v_head_dim,
+        )
+
+        W_UK, W_UV = kv_b_proj_weight.split(
+            [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+        )
+
+        if is_rocm_aiter_fp8bmm_enabled():
+            W_K = W_UK.transpose(0, 1)  # 16 512 128
+            W_V = W_UV.permute(1, 2, 0)  # 16 128 512
+            self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
+                W_K, dtype=current_platform.fp8_dtype()
+            )
+            self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
+                W_V, dtype=current_platform.fp8_dtype()
+            )
+
+            # The kernel operates on non-padded inputs. Hence, pre-compiling
+            # triton kernel to avoid runtime compilation for unseen batch sizes
+            # Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
+            # On DS-R1, this step adds roughly 50s to the model loading time.
+            max_batch_size = 1024  # [ToDo] Find the optimal upper limit
+            pre_compilation_list = list(range(1, max_batch_size + 1))
+            if is_global_first_rank():
+                pre_compilation_list = tqdm(
+                    pre_compilation_list,
+                    desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
+                    total=max_batch_size,
+                )
+
+            for m in pre_compilation_list:
+                x = torch.empty(
+                    (self.W_K.shape[0], m, self.W_K.shape[2]),
+                    dtype=torch.bfloat16,
+                    device=self.W_K.device,
+                )
+                aiter_triton_fp8_bmm(
+                    x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
+                )
+
+                x = torch.empty(
+                    (self.W_V.shape[0], m, self.W_V.shape[2]),
+                    dtype=torch.bfloat16,
+                    device=self.W_V.device,
+                )
+                aiter_triton_fp8_bmm(
+                    x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
+                )
+        else:
+            # Convert from (L, N, V) to (N, L, V)
+            self.W_UV = W_UV.transpose(0, 1)
+            # Convert from (L, N, P) to (N, P, L)
+            self.W_UK_T = W_UK.permute(1, 2, 0)
+
     def _forward_decode(
         self,
         q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
@@ -277,3 +416,197 @@ def _forward_decode(
         )
 
         return o, None
+
+    def forward(
+        self,
+        layer: AttentionLayer,
+        q: torch.Tensor,
+        k_c_normed: torch.Tensor,  # key in unified attn
+        k_pe: torch.Tensor,  # value in unified attn
+        kv_cache: torch.Tensor,
+        attn_metadata: M,
+        positions: torch.Tensor | None = None,
+        output: torch.Tensor | None = None,
+        output_scale: torch.Tensor | None = None,
+        output_block_scale: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        assert output is not None, "Output tensor must be provided."
+
+        if output_scale is not None or output_block_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported for MLACommonImpl"
+            )
+
+        if attn_metadata is None:
+            # During the profile run try to simulate to worse case output size
+            # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
+            # since this can be large
+            _ = torch.empty(
+                (
+                    self.chunked_prefill_workspace_size,
+                    self.num_heads,
+                    self.qk_nope_head_dim + self.v_head_dim,
+                ),
+                device=k_c_normed.device,
+                dtype=k_c_normed.dtype,
+            )
+
+            # The zero fill is required when used with DP + EP
+            # to ensure all ranks within a DP group compute the
+            # same expert outputs.
+            return output.fill_(0)
+
+        if self.dcp_world_size is None:
+            self.dcp_world_size = get_dcp_group().world_size
+
+        fp8_attention = self.kv_cache_dtype.startswith("fp8")
+
+        num_actual_toks = attn_metadata.num_actual_tokens
+
+        # Inputs and outputs may be padded for CUDA graphs
+        output_padded = output
+        output = output[:num_actual_toks, ...]
+        q = q[:num_actual_toks, ...]
+        k_c_normed = k_c_normed[:num_actual_toks, ...]
+        k_pe = k_pe[:num_actual_toks, ...]
+
+        assert (
+            attn_metadata.num_decodes is not None
+            and attn_metadata.num_prefills is not None
+            and attn_metadata.num_decode_tokens is not None
+        )
+
+        has_decode = attn_metadata.num_decodes > 0
+        has_prefill = attn_metadata.num_prefills > 0
+        num_decode_tokens = attn_metadata.num_decode_tokens
+
+        decode_q = q[:num_decode_tokens]
+
+        prefill_q = q[num_decode_tokens:]
+        prefill_k_pe = k_pe[num_decode_tokens:]
+        prefill_k_c_normed = k_c_normed[num_decode_tokens:]
+
+        # write the latent and rope to kv cache
+        if kv_cache.numel() > 0:
+            if is_aiter_mla_rope_flush_cache_fusion_enabled():
+                assert positions is not None
+                cos, sin = self.rotary_emb.cos_sin_cache.chunk(2, dim=-1)
+                is_neox = self.rotary_emb.is_neox_style
+                q_nope, q_pe = q.split(
+                    [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+                )
+                q = fused_qk_rope_cat_and_cache_mla(
+                    q_nope,
+                    q_pe,
+                    k_c_normed.unsqueeze(1),
+                    k_pe,
+                    kv_cache,
+                    attn_metadata.slot_mapping.flatten(),
+                    positions,
+                    cos,
+                    sin,
+                    layer._k_scale,
+                    is_neox,
+                )
+            else:
+                ops.concat_and_cache_mla(
+                    k_c_normed,
+                    k_pe.squeeze(1),
+                    kv_cache,
+                    attn_metadata.slot_mapping.flatten(),
+                    kv_cache_dtype=self.kv_cache_dtype,
+                    scale=layer._k_scale,
+                )
+
+        if fp8_attention:
+            kv_cache = kv_cache.view(current_platform.fp8_dtype())
+
+        if has_prefill:
+            output[num_decode_tokens:] = self._forward_prefill(
+                prefill_q,
+                prefill_k_c_normed,
+                prefill_k_pe,
+                kv_cache,
+                attn_metadata,
+                layer._k_scale,
+            )
+
+        if has_decode:
+            assert attn_metadata.decode is not None
+            decode_q_nope, decode_q_pe = decode_q.split(
+                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+            )
+            # Convert from (B, N, P) to (N, B, P)
+            decode_q_nope = decode_q_nope.transpose(0, 1)
+
+            # Pads the head_dim if necessary (for the underlying kernel)
+            if self.q_pad_num_heads is not None:
+                B, N, L = decode_q_pe.shape
+                decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L))
+                decode_pe_padded.resize_((B, N, L))
+                decode_pe_padded.copy_(decode_q_pe)
+                decode_q_pe = decode_pe_padded
+
+            if is_rocm_aiter_fp8bmm_enabled():
+                # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
+                decode_ql_nope = aiter_triton_fp8_bmm(
+                    decode_q_nope,
+                    self.W_K,
+                    self.W_K_scale,
+                    group_size=128,
+                    transpose_bm=True,
+                )
+            else:
+                # Pads the head_dim if necessary (for the underlying kernel)
+                N, B, P = decode_q_nope.shape
+                _, _, L = self.W_UK_T.shape
+                if self.q_pad_num_heads is not None:
+                    decode_ql_nope = decode_q_nope.new_empty(
+                        (self.q_pad_num_heads, B, L)
+                    )
+                    decode_ql_nope.resize_((N, B, L))
+
+                else:
+                    decode_ql_nope = decode_q_nope.new_empty((N, B, L))
+
+                # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+                torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)
+                # Convert from (N, B, L) to (B, N, L)
+                decode_ql_nope = decode_ql_nope.transpose(0, 1)
+
+            if fp8_attention:
+                ql_nope_shape = decode_ql_nope.shape
+                decode_ql_nope, _ = ops.scaled_fp8_quant(
+                    decode_ql_nope.reshape(
+                        [ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]]
+                    ),
+                    layer._q_scale,
+                )
+                decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape)
+                q_pe_shape = decode_q_pe.shape
+                decode_q_pe, _ = ops.scaled_fp8_quant(
+                    decode_q_pe.reshape([q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]),
+                    layer._q_scale,
+                )
+                decode_q_pe = decode_q_pe.reshape(q_pe_shape)
+
+            decode_q = (decode_ql_nope, decode_q_pe)
+            if self.dcp_world_size > 1:
+                assert not fp8_attention, "DCP not support fp8 kvcache now."
+                # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
+                decode_q = torch.cat(decode_q, dim=-1)
+                # decode_q do allgather in head dim.
+                decode_q = get_dcp_group().all_gather(decode_q, dim=1)
+
+            # call decode attn
+            attn_out, lse = self._forward_decode(
+                decode_q, kv_cache, attn_metadata, layer
+            )
+
+            # recorect dcp attn_out with lse.
+            if self.dcp_world_size > 1:
+                attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
+
+            # v_up projection
+            self._v_up_proj(attn_out, out=output[:num_decode_tokens])
+        return output_padded
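
Reviewer note (not part of the patch): the sketch below condenses the control flow this series introduces in MultiHeadLatentAttentionWrapper.forward_native. When the AITER RoPE + cache-write fusion is enabled, the wrapper skips its own rotary_emb call and instead forwards positions so the backend can apply RoPE inside fused_qk_rope_cat_and_cache_mla; otherwise behaviour is unchanged and positions stays None. The free-function name apply_mla_attention is hypothetical; all other identifiers come from the diff above.

# Illustrative sketch only; mirrors the gating in vllm/model_executor/layers/mla.py above.
from vllm.model_executor.layers.mla import is_aiter_mla_rope_flush_cache_fusion_enabled


def apply_mla_attention(wrapper, positions, q, kv_c_normed, k_pe, hidden_states):
    # Hypothetical condensation of forward_native for review purposes.
    fused = is_aiter_mla_rope_flush_cache_fusion_enabled()

    if not fused:
        # Unfused path: apply RoPE eagerly, exactly as before this change.
        q[..., wrapper.qk_nope_head_dim :], k_pe = wrapper.rotary_emb(
            positions, q[..., wrapper.qk_nope_head_dim :], k_pe
        )

    # Fused path: hand positions to the backend, which applies RoPE and writes
    # the latent/rope KV cache in one fused_qk_rope_cat_and_cache_mla kernel.
    return wrapper.mla_attn(
        q,
        kv_c_normed,
        k_pe,
        positions=positions if fused else None,
        output_shape=(hidden_states.shape[0], wrapper.num_heads * wrapper.v_head_dim),
    )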