[WIP][Rocm] Add rope and flush kvscache fusion for aiter mla. #26383
base: main
Conversation
Signed-off-by: charlifu <[email protected]>
Signed-off-by: charlifu <[email protected]>
Code Review
This pull request introduces a kernel fusion for rotary embedding and KV cache flushing for AITer MLA on ROCm. It also refactors AITer-specific code from common.py into rocm_aiter_mla.py for better code organization. While the refactoring is a good improvement, I've found a critical issue in the new forward implementation within rocm_aiter_mla.py. The query tensor q is not correctly updated after applying the fused rotary embedding, which will lead to incorrect attention outputs. My review includes a detailed comment and a code suggestion to fix this bug.
decode_q = q[:num_decode_tokens]

prefill_q = q[num_decode_tokens:]
prefill_k_pe = k_pe[num_decode_tokens:]
prefill_k_c_normed = k_c_normed[num_decode_tokens:]

# write the latent and rope to kv cache
if kv_cache.numel() > 0:
    if is_aiter_mla_rope_flush_cache_fusion_enabled():
        assert positions is not None
        cos, sin = self.rotary_emb.cos_sin_cache.chunk(2, dim=-1)
        is_neox = self.rotary_emb.is_neox_style
        q_nope, q_pe = q.split(
            [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )
        q = fused_qk_rope_cat_and_cache_mla(
            q_nope,
            q_pe,
            k_c_normed.unsqueeze(1),
            k_pe,
            kv_cache,
            attn_metadata.slot_mapping.flatten(),
            positions,
            cos,
            sin,
            layer._k_scale,
            is_neox,
        )
    else:
        ops.concat_and_cache_mla(
            k_c_normed,
            k_pe.squeeze(1),
            kv_cache,
            attn_metadata.slot_mapping.flatten(),
            kv_cache_dtype=self.kv_cache_dtype,
            scale=layer._k_scale,
        )
The q tensor is updated by fused_qk_rope_cat_and_cache_mla which applies rotary embeddings when fusion is enabled. However, decode_q and prefill_q are sliced from the original q tensor before this fusion function is called. The updated q is then never used for slicing, which means the rotary embeddings are not applied to the query tensor. This will lead to incorrect model outputs.
The slicing of decode_q and prefill_q should happen after the q tensor has been processed by the fusion kernel.
prefill_k_pe = k_pe[num_decode_tokens:]
prefill_k_c_normed = k_c_normed[num_decode_tokens:]

# write the latent and rope to kv cache
if kv_cache.numel() > 0:
    if is_aiter_mla_rope_flush_cache_fusion_enabled():
        assert positions is not None
        cos, sin = self.rotary_emb.cos_sin_cache.chunk(2, dim=-1)
        is_neox = self.rotary_emb.is_neox_style
        q_nope, q_pe = q.split(
            [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )
        q = fused_qk_rope_cat_and_cache_mla(
            q_nope,
            q_pe,
            k_c_normed.unsqueeze(1),
            k_pe,
            kv_cache,
            attn_metadata.slot_mapping.flatten(),
            positions,
            cos,
            sin,
            layer._k_scale,
            is_neox,
        )
    else:
        ops.concat_and_cache_mla(
            k_c_normed,
            k_pe.squeeze(1),
            kv_cache,
            attn_metadata.slot_mapping.flatten(),
            kv_cache_dtype=self.kv_cache_dtype,
            scale=layer._k_scale,
        )

decode_q = q[:num_decode_tokens]
prefill_q = q[num_decode_tokens:]
💡 Codex Review
Here are some automated review suggestions for this pull request.
  self.impl.forward(
-     self, query, key, value, self_kv_cache, attn_metadata, output=output
+     self,
+     query,
+     key,
+     value,
+     self_kv_cache,
+     attn_metadata,
+     positions=positions,
+     output=output,
  )
Avoid passing positions to backends that do not accept it
The new positions kwarg is passed unconditionally to every attention backend (added in the call to self.impl.forward(...) here). Most existing implementations—including FlashAttention, Triton, ROCm, etc.—still define forward(self, layer, query, key, value, kv_cache, attn_metadata, output=None, ...) without a positions parameter, so any inference that reaches these paths will immediately raise TypeError: forward() got an unexpected keyword argument 'positions'. Either gate the kwarg to the ROCm Aiter path or update all backends to accept it.
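One way to gate the kwarg is sketched below; this is a minimal illustration, and the supports_positions attribute is hypothetical (an actual fix could instead check the backend class or extend every backend's forward signature):

forward_kwargs = {"output": output}
# Only pass `positions` to backends that opt in; `supports_positions`
# is an illustrative attribute name, not an existing vLLM API.
if getattr(self.impl, "supports_positions", False):
    forward_kwargs["positions"] = positions
self.impl.forward(
    self, query, key, value, self_kv_cache, attn_metadata, **forward_kwargs
)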
if is_global_first_rank():
    pre_compilation_list = tqdm(
        pre_compilation_list,
        desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
Fix tqdm import before using progress bar
Here process_weights_after_loading wraps the pre-compilation loop with tqdm(...), but the file imports the module via import tqdm. Since modules are not callable, this call will throw TypeError: 'module' object is not callable whenever FP8-BMM precompilation runs on the first global rank (VLLM_ROCM_USE_AITER_FP8BMM enabled), aborting model initialization. Import the function (from tqdm import tqdm) or call tqdm.tqdm(...) instead.
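A minimal sketch of the second option, keeping the module-level import and calling the class through it (names taken from the snippet above):

import tqdm

if is_global_first_rank():
    # Call tqdm.tqdm(...) rather than the module object itself.
    pre_compilation_list = tqdm.tqdm(
        pre_compilation_list,
        desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
    )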
Signed-off-by: charlifu <[email protected]>
Signed-off-by: charlifu <[email protected]>
ProExpertProg left a comment
Instead of passing rope info into attention, we're gonna be extracting kvcache from attention and then fusing there. Please have a look at #24678 and linked PRs
This PR is a draft that adds kernel fusion of rotary_embedding and the KV cache flush for AITER MLA. It also moves AITER-specific code from common.py to rocm_aiter_mla.py.
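For context, the fusion collapses what the unfused path does in two steps, applying RoPE to the rope components of q and k and then writing the latent and rope parts into the KV cache, into a single AITER kernel call. A rough sketch of the two paths follows; the unfused RoPE call is schematic, since only the cache write for that path appears in this diff:

# Unfused path (schematic): RoPE applied separately, then the latent and
# rope parts are written into the MLA KV cache.
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
ops.concat_and_cache_mla(
    k_c_normed,
    k_pe.squeeze(1),
    kv_cache,
    attn_metadata.slot_mapping.flatten(),
    kv_cache_dtype=self.kv_cache_dtype,
    scale=layer._k_scale,
)

# Fused path (this PR): one AITER kernel applies RoPE and flushes the cache,
# returning q with its rope component already rotated.
q = fused_qk_rope_cat_and_cache_mla(
    q_nope, q_pe,
    k_c_normed.unsqueeze(1), k_pe,
    kv_cache,
    attn_metadata.slot_mapping.flatten(),
    positions, cos, sin,
    layer._k_scale,
    is_neox,
)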