Skip to content

Conversation

@bohnstingl
Copy link
Contributor

@bohnstingl bohnstingl commented Oct 4, 2025

Purpose

This PR reduces overhead from the pure PyTorch implementation of the SSD initial state extraction and state caching by using two additional Triton kernels.
It is based on #26222

cc @tdoublep @s3woz

Test Plan

The behavior is transparent to the user and thus the existing tests can be reused.

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

tdoublep and others added 17 commits October 4, 2025 01:29
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Integrated initial state and cached state extraction
Removed pure PyTorch overhead

Signed-off-by: Thomas Ortner <[email protected]>
@github-actions
Copy link

github-actions bot commented Oct 4, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors.

You ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the v1 label Oct 4, 2025
@bohnstingl bohnstingl changed the title [V1] Mamba2 kernel integration [V1] Mamba2 SSD kernel integration Oct 4, 2025
@bohnstingl bohnstingl marked this pull request as ready for review October 7, 2025 09:46
@bohnstingl bohnstingl requested a review from tdoublep as a code owner October 7, 2025 09:46
Copy link

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Comment on lines +1021 to +1032
n_blocks_to_fill = (
block_idx_last_scheduled_token - block_idx_first_scheduled_token
)

grid = lambda META: (
nseq
* (
n_blocks_to_fill.max() + 1
), # The +1 is for the last state that is always stored
triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
triton.cdiv(headdim, META["BLOCK_SIZE_M"])
* triton.cdiv(dstate, META["BLOCK_SIZE_N"]),

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Convert block counts to Python ints before Triton launch

When prefix caching is enabled, _state_cache_fwd computes the launch grid using nseq * (n_blocks_to_fill.max() + 1) without converting the tensor to a Python scalar. block_idx_last_scheduled_token and block_idx_first_scheduled_token live on CUDA, so n_blocks_to_fill.max() is a CUDA tensor; Triton expects grid entries to be plain integers. At runtime this will raise TypeError: only one element tensors can be converted to Python scalars before the kernel ever launches, breaking all prefix-cached prefills. Cast the max value to a Python int (e.g. .max().item()) before constructing the grid.

Useful? React with 👍 / 👎.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants