36 changes: 32 additions & 4 deletions vllm/attention/layer.py
@@ -363,11 +363,21 @@
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(
self, query, key, value, self_kv_cache, attn_metadata, output=output
self,
query,
key,
value,
self_kv_cache,
attn_metadata,
output=output,
)
else:
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name
query,
key,
value,
output,
self.layer_name,
)
return output.view(-1, hidden_size)
else:
@@ -688,6 +698,7 @@
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
positions: Optional[torch.Tensor] = None,
output_shape: Optional[torch.Size] = None,
) -> torch.Tensor:
if self.use_direct_call:
@@ -705,19 +716,26 @@

if self.attn_backend.accept_output_buffer:
output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
self.impl.forward(

Check failure on line 719 in vllm/attention/layer.py (GitHub Actions / pre-commit): Unexpected keyword argument "positions" for "forward" of "MLAAttentionImpl" [call-arg]
self,
q,
kv_c_normed,
k_pe,
self_kv_cache,
attn_metadata,
positions=positions,
output=output,
)
return output
else:
return self.impl.forward(

Check failure on line 731 in vllm/attention/layer.py (GitHub Actions / pre-commit): Unexpected keyword argument "positions" for "forward" of "MLAAttentionImpl" [call-arg]
self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
self,
q,
kv_c_normed,
k_pe,
self_kv_cache,
attn_metadata,
positions=positions,
)
else:
if self.attn_backend.accept_output_buffer:
@@ -728,6 +746,7 @@
k_pe,
output,
self.layer_name,
positions=positions,
)
return output
else:
@@ -888,6 +907,7 @@
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
positions: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> None:
@@ -905,6 +925,7 @@
value,
kv_cache,
attn_metadata,
positions=positions,
output=output,
output_scale=output_scale,
output_block_scale=output_block_scale,
@@ -919,6 +940,7 @@
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
positions: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> None:
@@ -939,6 +961,7 @@
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
positions: Optional[torch.Tensor] = None,
) -> torch.Tensor:
wait_for_kv_layer_from_connector(layer_name)

@@ -948,7 +971,9 @@
attn_metadata = attn_metadata[layer_name]
self: MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)
output = self.impl.forward(

Check failure on line 974 in vllm/attention/layer.py (GitHub Actions / pre-commit): Unexpected keyword argument "positions" for "forward" of "MLAAttentionImpl" [call-arg]
self, q, kv_c_normed, k_pe, kv_cache, attn_metadata, positions=positions
)

maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return output
@@ -959,6 +984,7 @@
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
positions: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.empty_like(q).contiguous()

@@ -978,6 +1004,7 @@
k_pe: torch.Tensor,
output: torch.Tensor,
layer_name: str,
positions: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> None:
@@ -988,13 +1015,14 @@
attn_metadata = attn_metadata[layer_name]
self: MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(

Check failure on line 1018 in vllm/attention/layer.py (GitHub Actions / pre-commit): Unexpected keyword argument "positions" for "forward" of "MLAAttentionImpl" [call-arg]
self,
q,
kv_c_normed,
k_pe,
kv_cache,
attn_metadata,
positions=positions,
output=output,
output_scale=output_scale,
output_block_scale=output_block_scale,
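The pre-commit failures flagged above all point at the same gap: the call sites in this file now pass `positions=` into `self.impl.forward`, but the `MLAAttentionImpl` interface does not declare that keyword yet in this revision. A minimal sketch of the matching interface change, assuming the abstract class lives in `vllm/attention/backends/abstract.py` and keeping only the parameters visible at these call sites (the module path and exact parameter list are inferences, not part of this diff):

```python
# Sketch only: teach the MLAAttentionImpl interface about the new keyword so
# mypy accepts positions= at the call sites above. Module path and the exact
# parameter list are assumptions inferred from the error messages.
from abc import ABC, abstractmethod
from typing import Generic, Optional, TypeVar

import torch

T = TypeVar("T")  # backend-specific attention-metadata type


class MLAAttentionImpl(ABC, Generic[T]):

    @abstractmethod
    def forward(
        self,
        layer: torch.nn.Module,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        positions: Optional[torch.Tensor] = None,  # new, threaded from mla.py
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError
```

Concrete MLA backends would then accept the keyword and either consume it (the AITER path, which fuses RoPE into the cache flush) or ignore it (backends that still expect pre-rotated q/k_pe).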
21 changes: 18 additions & 3 deletions vllm/model_executor/layers/mla.py
@@ -5,10 +5,20 @@

import torch

from vllm import envs
from vllm.attention.layer import MLAAttention
from vllm.config import CacheConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.platforms import current_platform


def is_aiter_mla_rope_flush_cache_fusion_enabled() -> bool:
return (
current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_MLA
)


@dataclass
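`is_aiter_mla_rope_flush_cache_fusion_enabled()` above is a pure predicate over the platform and two existing ROCm/AITER flags. A small usage sketch, assuming the `VLLM_ROCM_USE_AITER*` values are read from environment variables of the same name and are set before vLLM first evaluates them:

```python
import os

# Assumption: vllm.envs resolves these flags lazily, so exporting them before
# the first access is enough to enable the fused path on a ROCm build.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_MLA"] = "1"

from vllm.model_executor.layers.mla import (
    is_aiter_mla_rope_flush_cache_fusion_enabled,
)

# True only on ROCm with both flags enabled; on CUDA or with either flag unset
# the layer keeps the unfused RoPE path in forward_native.
print(is_aiter_mla_rope_flush_cache_fusion_enabled())
```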
@@ -105,6 +115,7 @@ def __init__(
)

self.prefix = prefix
self.mla_attn.impl.rotary_emb = self.rotary_emb

def forward_native(
self,
@@ -148,9 +159,10 @@ def forward_native(
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)

q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim :], k_pe
)
if not is_aiter_mla_rope_flush_cache_fusion_enabled():
q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim :], k_pe
)

if self.indexer and self.is_sparse:
_topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb)
@@ -159,6 +171,9 @@
q,
kv_c_normed,
k_pe,
positions=(
positions if is_aiter_mla_rope_flush_cache_fusion_enabled() else None
),
output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim),
)
return self.o_proj(attn_out)[0]
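Net effect of the mla.py change on the forward path: with fusion enabled, `forward_native` leaves q/k_pe unrotated and forwards `positions` to the MLA attention layer so the AITER backend can apply RoPE while writing the KV cache; otherwise it rotates in place as before and passes `positions=None`. A schematic, self-contained sketch of that dispatch (the `apply_rope` and `attn` callables and the shapes in the comments are placeholders, not vLLM APIs):

```python
from typing import Callable, Tuple

import torch


def mla_rope_dispatch(
    q: torch.Tensor,            # [num_tokens, num_heads, qk_head_dim]
    k_pe: torch.Tensor,         # [num_tokens, 1, qk_rope_head_dim]
    kv_c_normed: torch.Tensor,  # [num_tokens, kv_lora_rank]
    positions: torch.Tensor,    # [num_tokens]
    qk_nope_head_dim: int,
    fusion_enabled: bool,
    apply_rope: Callable[..., Tuple[torch.Tensor, torch.Tensor]],
    attn: Callable[..., torch.Tensor],
) -> torch.Tensor:
    """Placeholder mirror of the control flow added to forward_native."""
    if not fusion_enabled:
        # Unfused path (previous behaviour): rotate the RoPE slice of q and
        # k_pe here, then call attention without positions.
        q[..., qk_nope_head_dim:], k_pe = apply_rope(
            positions, q[..., qk_nope_head_dim:], k_pe
        )
        return attn(q, kv_c_normed, k_pe, positions=None)
    # Fused path: hand positions to the backend, which applies RoPE as part of
    # flushing the KV cache.
    return attn(q, kv_c_normed, k_pe, positions=positions)
```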