-
-
Notifications
You must be signed in to change notification settings - Fork 11.3k
[V1] Mamba2 SSD kernel integration #26235
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[V1] Mamba2 SSD kernel integration #26235
Conversation
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Integrated initial state and cached state extraction Removed pure PyTorch overhead Signed-off-by: Thomas Ortner <[email protected]>
|
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run You ask your reviewers to trigger select CI tests on top of Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀 |
Signed-off-by: Thomas Ortner <[email protected]>
Signed-off-by: Thomas Ortner <[email protected]>
…ng_ssdkernel Signed-off-by: Thomas Ortner <[email protected]>
…a2_prefix_caching_ssdkernel Signed-off-by: Thomas Ortner <[email protected]>
…a2_prefix_caching_ssdkernel
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
| n_blocks_to_fill = ( | ||
| block_idx_last_scheduled_token - block_idx_first_scheduled_token | ||
| ) | ||
|
|
||
| grid = lambda META: ( | ||
| nseq | ||
| * ( | ||
| n_blocks_to_fill.max() + 1 | ||
| ), # The +1 is for the last state that is always stored | ||
| triton.cdiv(nheads, META["BLOCK_SIZE_H"]), | ||
| triton.cdiv(headdim, META["BLOCK_SIZE_M"]) | ||
| * triton.cdiv(dstate, META["BLOCK_SIZE_N"]), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Convert block counts to Python ints before Triton launch
When prefix caching is enabled, _state_cache_fwd computes the launch grid using nseq * (n_blocks_to_fill.max() + 1) without converting the tensor to a Python scalar. block_idx_last_scheduled_token and block_idx_first_scheduled_token live on CUDA, so n_blocks_to_fill.max() is a CUDA tensor; Triton expects grid entries to be plain integers. At runtime this will raise TypeError: only one element tensors can be converted to Python scalars before the kernel ever launches, breaking all prefix-cached prefills. Cast the max value to a Python int (e.g. .max().item()) before constructing the grid.
Useful? React with 👍 / 👎.
…a1_prefix_caching_ssdkernel
Purpose
This PR reduces overhead from the pure PyTorch implementation of the SSD initial state extraction and state caching by using two additional Triton kernels.
It is based on #26222
cc @tdoublep @s3woz
Test Plan
The behavior is transparent to the user and thus the existing tests can be reused.
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.