Replace mamba2 mamba_chunk_scan_combined triton kernel by simple_gla triton kernel
#49
Follow-up to #39 (comment).
Eventually, this will allow the e2e mamba2 example (#39) to run without depending on the original `mamba_ssm` repo.
This PR adds unit tests to ensure equivalence between {`chunk_simple_gla` / `torch_simple_gla` / `torch_simple_gla_recurrent` under `fla.ops.simple_gla` of this repository} and {`mamba_chunk_scan_combined` / `ssd_minimal_discrete` inside the `mamba_ssm` repository}.

Unit test output from this PR:
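Both sides compute the same gated linear-attention recurrence. Below is a minimal PyTorch version in the spirit of the `torch_simple_gla_recurrent` reference named above; shapes follow the `simple_gla` convention `[batch, head, seq, hidden]` described in the next section. This is an illustration under assumed shapes, not the repository's implementation:

```python
import torch

def simple_gla_recurrent_ref(q, k, v, g, scale=None):
    # q, k: [batch, head, seq, DK]; v: [batch, head, seq, DV]; g: [batch, head, seq]
    B, H, T, DK = q.shape
    if scale is None:
        scale = DK ** -0.5  # simple_gla's default query scaling (see list below)
    S = q.new_zeros(B, H, DK, v.shape[-1])  # running state S_t
    o = torch.empty_like(v)
    for t in range(T):
        # S_t = exp(g_t) * S_{t-1} + k_t v_t^T  (time-varying scalar decay per head)
        S = torch.exp(g[:, :, t, None, None]) * S \
            + k[:, :, t, :, None] * v[:, :, t, None, :]
        # o_t = (scale * q_t) @ S_t
        o[:, :, t] = torch.einsum('bhk,bhkv->bhv', scale * q[:, :, t], S)
    return o
```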
Differences between the `simple_gla` kernel and the "mamba2_ssd" kernel:

- Tensor layout: mamba2 uses `[batch, seq, head, hidden]`, while `simple_gla` uses `[batch, head, seq, hidden]`.
- Scaling: `simple_gla` applies the attention-style scaling `q * (DK ** -0.5)` by default; mamba2 does not (see the adaptation sketch after this list).
- mamba2 takes a `dt` input for discretization, but this can be easily absorbed into the gating matrix `A`, as done in the mamba2 example. The `mamba_chunk_scan_combined` Triton kernel does not support a time-varying `A` (though the minimal torch version does), probably because the time-dependence is expressed by `dt`, not `A_t`? `simple_gla` supports time-varying `g` directly.
- `simple_gla` (also other kernels in this repo?) always uses the same number of heads for Q, K & V. For now, the tests force the same number of heads. Ref Section 7.2 of the Mamba-2 paper.
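Putting these differences together, here is a hedged sketch of how mamba2-style inputs can be mapped onto `simple_gla`. The import path, the exact kernel signature, and the `ngroups == nheads` layout of `B`/`C` are assumptions for illustration:

```python
import torch
from fla.ops.simple_gla import chunk_simple_gla  # import path assumed

def mamba2_ssd_via_simple_gla(x, dt, A, B, C):
    """Sketch: call simple_gla with mamba2-style inputs.

    x: [batch, seq, head, headdim], dt: [batch, seq, head], A: [head],
    B/C: [batch, seq, head, dstate] (ngroups == nheads, matching the
    equal-head-count restriction noted above).
    """
    # Absorb the discretization step into the log-decay gate: g_t = dt_t * A.
    g = (dt * A[None, None, :]).transpose(1, 2)   # -> [batch, head, seq]
    # Transpose [batch, seq, head, d] -> [batch, head, seq, d].
    q = C.transpose(1, 2)
    k = B.transpose(1, 2)
    v = (x * dt[..., None]).transpose(1, 2)       # dt also scales the input term
    # Disable simple_gla's default q * (DK ** -0.5) scaling to match mamba2.
    out = chunk_simple_gla(q, k, v, g, scale=1.0)
    if isinstance(out, tuple):                    # some versions also return a state
        out = out[0]
    return out.transpose(1, 2)                    # back to [batch, seq, head, d]
```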

Todo:

- Support a different number of heads (`n_groups`) for the `simple_gla` kernel (Mamba-Codestral uses `n_groups=8`); an interim workaround is sketched below.

FYI @DanFosing @yzhangcs @sustcsonglin
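Until the kernel supports grouping natively, one hypothetical workaround (not part of this PR) is to replicate the grouped `B`/`C` projections once per head before calling the kernel, trading memory for compatibility:

```python
import torch

# Hypothetical workaround: expand grouped B/C ([batch, seq, ngroups, dstate])
# to one copy per head so the equal-head-count kernel can consume them.
batch, seq, nheads, ngroups, dstate = 2, 64, 32, 8, 16
Bg = torch.randn(batch, seq, ngroups, dstate)
Cg = torch.randn(batch, seq, ngroups, dstate)
# Consecutive heads share a group, so repeat each group nheads // ngroups times.
B_per_head = Bg.repeat_interleave(nheads // ngroups, dim=2)  # [batch, seq, nheads, dstate]
C_per_head = Cg.repeat_interleave(nheads // ngroups, dim=2)
```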