
Conversation

@learning-chip (Contributor) commented on Aug 18, 2024

Follow-up to #39 (comment).

Eventually this will allow the e2e mamba2 example (#39) to run without depending on the original mamba_ssm repo.

This PR adds unit tests to ensure equivalence between chunk_simple_gla / torch_simple_gla / torch_simple_gla_recurrent (under fla.ops.simple_gla in this repository) and mamba_chunk_scan_combined / ssd_minimal_discrete (in the mamba_ssm repository).

Unit test output from this PR:

$ pytest -v ./test_simple_gla_for_mamba2.py
====================================================== test session starts ======================================================
collected 6 items                                                                                                               

test_simple_gla_for_mamba2.py::test_gla_to_mamba2[float32-True] PASSED                                                    [ 16%]
test_simple_gla_for_mamba2.py::test_gla_to_mamba2[float32-False] PASSED                                                   [ 33%]
test_simple_gla_for_mamba2.py::test_gla_to_mamba2[float16-True] PASSED                                                    [ 50%]
test_simple_gla_for_mamba2.py::test_gla_to_mamba2[float16-False] PASSED                                                   [ 66%]
test_simple_gla_for_mamba2.py::test_gla_to_mamba2[bfloat16-True] PASSED                                                   [ 83%]
test_simple_gla_for_mamba2.py::test_gla_to_mamba2[bfloat16-False] PASSED                                                  [100%]

Differences between the simple_gla kernel and the "mamba2_ssd" kernels:

  • mamba2_ssd uses the input/output layout [batch, seq, head, hidden], while simple_gla uses [batch, head, seq, hidden].
  • mamba2_ssd does not apply the attention-inspired scaling q * (DK ** -0.5).
  • mamba2_ssd takes an extra dt input for discretization, but this can easily be absorbed into the gating matrix A (and the input x), as is done in the mamba2 example.
  • mamba2_ssd's fused kernel does not take a time-varying A (though the minimal torch version does), probably because the time dependence is expressed through dt rather than A_t. simple_gla supports a time-varying gate g directly.
  • mamba2_ssd uses grouped-query attention, whereas simple_gla (and possibly other kernels in this repo) always uses the same number of heads for Q, K, and V. For now, the tests force the same number of heads (see the sketch below for how the inputs map onto each other).

Ref. Section 7.2 of the Mamba-2 paper:
[figure: group_query]
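
As a concrete illustration of the mapping above, here is a minimal sketch (not the PR's actual test) of how simple_gla inputs could be built from mamba2-style inputs and checked for equivalence. It assumes chunk_simple_gla takes (q, k, v, g) in [batch, head, seq, dim] layout with a scale argument and returns (output, final_state), and that ssd_minimal_discrete follows the [batch, seq, head, dim] layout of the mamba2 minimal example; exact signatures may differ across versions.

```python
# Minimal sketch under the assumptions stated above; not the PR's test code.
import torch
from fla.ops.simple_gla import chunk_simple_gla
# from mamba_ssm.modules.ssd_minimal import ssd_minimal_discrete  # reference impl

torch.manual_seed(0)
B, L, H, D, N = 2, 64, 4, 64, 32   # batch, seq_len, heads, head_dim, state_dim
device = 'cuda'

x  = torch.randn(B, L, H, D, device=device)                    # mamba2 "x" -> simple_gla v
dt = torch.nn.functional.softplus(torch.randn(B, L, H, device=device))
A  = -torch.exp(torch.rand(H, device=device))                  # per-head decay (negative)
Bm = torch.randn(B, L, H, N, device=device)                    # mamba2 "B" -> simple_gla k
Cm = torch.randn(B, L, H, N, device=device)                    # mamba2 "C" -> simple_gla q

# Absorb dt into the gate and the value, as in the mamba2 example:
#   h_t = exp(dt_t * A) * h_{t-1} + B_t (dt_t * x_t)
g = A * dt                         # log-space decay per step, shape [B, L, H]
v = x * dt.unsqueeze(-1)

# simple_gla expects [batch, head, seq, dim]; scale=1.0 disables the 1/sqrt(d) scaling
y_gla, _ = chunk_simple_gla(
    Cm.transpose(1, 2).contiguous(), Bm.transpose(1, 2).contiguous(),
    v.transpose(1, 2).contiguous(), g.transpose(1, 2).contiguous(), scale=1.0,
)
y_gla = y_gla.transpose(1, 2)      # back to [batch, seq, head, dim]

# Compare against the mamba2 minimal reference (same layout as its example):
# y_ref, _ = ssd_minimal_discrete(v, g, Bm, Cm, block_len=64)
# torch.testing.assert_close(y_gla, y_ref, rtol=1e-3, atol=1e-3)
```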

Todo:

FYI @DanFosing @yzhangcs @sustcsonglin

@yzhangcs (Member) commented:

@learning-chip very cool contribution! I think it would be great if you could add some benchmarks comparing the simple_gla and mamba2 kernels, like in https://github.com/sustcsonglin/flash-linear-attention/blob/main/benchmarks/ops/benchmark_gla.py.
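
Not from the PR, but a rough sketch of what such a benchmark could look like, using triton.testing.do_bench in the spirit of benchmark_gla.py. The mamba_chunk_scan_combined argument shapes and the chunk_simple_gla signature are assumptions and may differ across versions:

```python
# Rough benchmark sketch; shapes and signatures are assumptions, not the PR's code.
import torch
import triton
from fla.ops.simple_gla import chunk_simple_gla
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined

B, H, L, D, N = 8, 8, 2048, 64, 64
device, dtype = 'cuda', torch.bfloat16

# simple_gla inputs: [batch, head, seq, dim]
q = torch.randn(B, H, L, N, device=device, dtype=dtype)
k = torch.randn(B, H, L, N, device=device, dtype=dtype)
v = torch.randn(B, H, L, D, device=device, dtype=dtype)
g = torch.randn(B, H, L, device=device, dtype=torch.float32).sigmoid().log()

# mamba2 inputs: [batch, seq, head, dim], plus per-step dt and per-head A
x  = v.transpose(1, 2).contiguous()
dt = torch.rand(B, L, H, device=device, dtype=torch.float32) + 1e-3
A  = -torch.rand(H, device=device, dtype=torch.float32)
Bm = k.transpose(1, 2).contiguous()
Cm = q.transpose(1, 2).contiguous()

ms_gla = triton.testing.do_bench(lambda: chunk_simple_gla(q, k, v, g, scale=1.0))
ms_ssd = triton.testing.do_bench(
    lambda: mamba_chunk_scan_combined(x, dt, A, Bm, Cm, chunk_size=64))
print(f"chunk_simple_gla: {ms_gla:.3f} ms  mamba_chunk_scan_combined: {ms_ssd:.3f} ms")
```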

@yzhangcs (Member) commented:

I will be working on GQA soon.

@yzhangcs yzhangcs marked this pull request as ready for review August 18, 2024 17:40
@yzhangcs yzhangcs merged commit 9aa2480 into fla-org:main Aug 18, 2024
@learning-chip (Contributor, Author) commented:

> add some benchmarks regarding simple_gla and mamba2 kernels like in https://github.com/sustcsonglin/flash-linear-attention/blob/main/benchmarks/ops/benchmark_gla.py.

Some quick results: #50
