Conversation

@charlifu charlifu (Contributor) commented Oct 7, 2025

This PR is a draft that adds kernel fusion of rotary_embedding and the KV cache flush for AITER MLA.

  • add a position tensor argument to the attention layer and the attention custom op (see the sketch below).
  • move all AITER-specific code from common.py to rocm_aiter_mla.py.
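A minimal sketch of what the first bullet implies, for orientation only. This is not the PR's actual code; every name other than positions is an illustrative placeholder. The idea is that the backend forward gains an optional positions tensor so a fused RoPE + KV-cache-write kernel can consume it, while callers that never pass it keep working.

# Hedged sketch, NOT vLLM's actual interface: names other than `positions`
# are illustrative placeholders.
from typing import Any, Optional

import torch


class AttentionImplSketch:
    def forward(
        self,
        layer: Any,                 # owning attention layer (scales, config, ...)
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: Any,         # backend metadata (e.g. slot mapping)
        positions: Optional[torch.Tensor] = None,  # new optional argument
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if positions is not None:
            # Fused path: a kernel could apply RoPE to q/k and write the KV
            # cache in one launch, using `positions` to index cos/sin tables.
            pass
        if output is None:
            output = torch.empty_like(query)
        return output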

@mergify mergify bot added the rocm (Related to AMD ROCm) and v1 labels on Oct 7, 2025
@gemini-code-assist gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a kernel fusion for rotary embedding and KV cache flushing for AITer MLA on ROCm. It also refactors AITer-specific code from common.py into rocm_aiter_mla.py for better code organization. While the refactoring is a good improvement, I've found a critical issue in the new forward implementation within rocm_aiter_mla.py. The query tensor q is not correctly updated after applying the fused rotary embedding, which will lead to incorrect attention outputs. My review includes a detailed comment and a code suggestion to fix this bug.

Comment on lines +481 to +517
        decode_q = q[:num_decode_tokens]

        prefill_q = q[num_decode_tokens:]
        prefill_k_pe = k_pe[num_decode_tokens:]
        prefill_k_c_normed = k_c_normed[num_decode_tokens:]

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            if is_aiter_mla_rope_flush_cache_fusion_enabled():
                assert positions is not None
                cos, sin = self.rotary_emb.cos_sin_cache.chunk(2, dim=-1)
                is_neox = self.rotary_emb.is_neox_style
                q_nope, q_pe = q.split(
                    [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
                )
                q = fused_qk_rope_cat_and_cache_mla(
                    q_nope,
                    q_pe,
                    k_c_normed.unsqueeze(1),
                    k_pe,
                    kv_cache,
                    attn_metadata.slot_mapping.flatten(),
                    positions,
                    cos,
                    sin,
                    layer._k_scale,
                    is_neox,
                )
            else:
                ops.concat_and_cache_mla(
                    k_c_normed,
                    k_pe.squeeze(1),
                    kv_cache,
                    attn_metadata.slot_mapping.flatten(),
                    kv_cache_dtype=self.kv_cache_dtype,
                    scale=layer._k_scale,
                )
gemini-code-assist bot (Contributor)

critical

When fusion is enabled, fused_qk_rope_cat_and_cache_mla returns an updated q tensor with rotary embeddings applied. However, decode_q and prefill_q are sliced from the original q tensor before this fusion function is called and are never re-sliced from the updated q, so the queries actually used for attention never receive the rotary embeddings. This will lead to incorrect model outputs.

The slicing of decode_q and prefill_q should happen after the q tensor has been processed by the fusion kernel.

        prefill_k_pe = k_pe[num_decode_tokens:]
        prefill_k_c_normed = k_c_normed[num_decode_tokens:]

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            if is_aiter_mla_rope_flush_cache_fusion_enabled():
                assert positions is not None
                cos, sin = self.rotary_emb.cos_sin_cache.chunk(2, dim=-1)
                is_neox = self.rotary_emb.is_neox_style
                q_nope, q_pe = q.split(
                    [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
                )
                q = fused_qk_rope_cat_and_cache_mla(
                    q_nope,
                    q_pe,
                    k_c_normed.unsqueeze(1),
                    k_pe,
                    kv_cache,
                    attn_metadata.slot_mapping.flatten(),
                    positions,
                    cos,
                    sin,
                    layer._k_scale,
                    is_neox,
                )
            else:
                ops.concat_and_cache_mla(
                    k_c_normed,
                    k_pe.squeeze(1),
                    kv_cache,
                    attn_metadata.slot_mapping.flatten(),
                    kv_cache_dtype=self.kv_cache_dtype,
                    scale=layer._k_scale,
                )

        decode_q = q[:num_decode_tokens]

        prefill_q = q[num_decode_tokens:]
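
To make the aliasing point concrete, here is a tiny standalone PyTorch illustration (not vLLM code): slices taken before a name is rebound keep referring to the old tensor.

import torch

q = torch.zeros(4, 3)
decode_q = q[:2]       # view of the ORIGINAL tensor

q = torch.ones(4, 3)   # stands in for `q = fused_qk_rope_cat_and_cache_mla(...)`;
                       # the name `q` now points at a new tensor object

print(decode_q)        # still all zeros: the earlier slice never sees the update
print(q[:2])           # slicing after the rebind picks up the new values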

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Comment on lines 372 to 380
self.impl.forward(
-    self, query, key, value, self_kv_cache, attn_metadata, output=output
+    self,
+    query,
+    key,
+    value,
+    self_kv_cache,
+    attn_metadata,
+    positions=positions,
+    output=output,
)

P0: Avoid passing positions to backends that do not accept it

The new positions kwarg is passed unconditionally to every attention backend (added in the call to self.impl.forward(...) here). Most existing implementations—including FlashAttention, Triton, ROCm, etc.—still define forward(self, layer, query, key, value, kv_cache, attn_metadata, output=None, ...) without a positions parameter, so any inference that reaches these paths will immediately raise TypeError: forward() got an unexpected keyword argument 'positions'. Either gate the kwarg to the ROCm Aiter path or update all backends to accept it.
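
One way to gate the kwarg is sketched below, under the assumption that the backend object is reachable as impl and that signature inspection is acceptable at this call site; the helper name call_backend_forward is purely illustrative. Updating every backend to accept positions=None is the simpler alternative.

# Hedged sketch of the gating idea (not actual vLLM code): only pass
# `positions` to backends whose forward() actually declares it.
import inspect
from typing import Any, Optional

import torch


def call_backend_forward(
    impl: Any,
    layer: Any,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: Any,
    positions: Optional[torch.Tensor] = None,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    kwargs = {"output": output}
    # Backends that accept **kwargs would not list `positions` here; the gate
    # is deliberately conservative so older forward() signatures keep working.
    if positions is not None and "positions" in inspect.signature(impl.forward).parameters:
        kwargs["positions"] = positions
    return impl.forward(layer, query, key, value, kv_cache, attn_metadata, **kwargs)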

Comment on lines +349 to +352
if is_global_first_rank():
    pre_compilation_list = tqdm(
        pre_compilation_list,
        desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",

P1: Fix tqdm import before using progress bar

Here process_weights_after_loading wraps the pre-compilation loop with tqdm(...), but the file imports the module via import tqdm. Since modules are not callable, this call will throw TypeError: 'module' object is not callable whenever FP8-BMM precompilation runs on the first global rank (VLLM_ROCM_USE_AITER_FP8BMM enabled), aborting model initialization. Import the function (from tqdm import tqdm) or call tqdm.tqdm(...) instead.
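
A minimal sketch of the suggested fix, using a stand-in pre_compilation_list; both options below use the real tqdm API.

# Option 1: import the class directly.
from tqdm import tqdm

pre_compilation_list = list(range(8))  # stand-in for the shapes being pre-compiled

for item in tqdm(pre_compilation_list,
                 desc="[Aiter Triton] Pre-compiling fp8 BMM kernel"):
    pass  # pre-compile one kernel variant here

# Option 2: keep `import tqdm` and call the class through the module.
import tqdm as tqdm_module

for item in tqdm_module.tqdm(pre_compilation_list,
                             desc="[Aiter Triton] Pre-compiling fp8 BMM kernel"):
    pass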

@ProExpertProg ProExpertProg (Collaborator) left a comment

Instead of passing RoPE info into attention, we're going to extract the KV cache from attention and then fuse there. Please have a look at #24678 and the linked PRs.
