
Conversation

@learning-chip (Contributor) commented on Aug 18, 2024

Follow-up to #49

Amazingly, it seems like chunk_simple_gla is much faster than mamba_chunk_scan_combined:

$ python ./benchmark_simple_gla_vs_mamba2.py

Performance:
         T  chunk_simple_gla  mamba2_ssd
0     64.0          0.084992    0.840208
1    128.0          0.100352    0.847920
2    256.0          0.100368    0.848896
3    512.0          0.174080    0.873472
4   1024.0          0.399360    0.880208
5   2048.0          0.776352    1.596416
6   4096.0          1.526784    3.160064
7   8192.0          3.067904    6.251520
8  16384.0          6.220800   12.452864
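
For context, a minimal sketch of how such a comparison could be set up (this is not the actual benchmark_simple_gla_vs_mamba2.py script; the shapes, dtypes, and the exact chunk_simple_gla / mamba_chunk_scan_combined call signatures below are assumptions):

import torch
import triton

from fla.ops.simple_gla import chunk_simple_gla
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined

B, H, T, D = 4, 8, 2048, 64  # hypothetical batch, heads, sequence length, head dim

# simple-GLA inputs (assumed layout): q/k/v of shape (B, H, T, D) plus a per-head log decay g of shape (B, H, T)
q = torch.randn(B, H, T, D, dtype=torch.bfloat16, device='cuda')
k = torch.randn_like(q)
v = torch.randn_like(q)
g = torch.rand(B, H, T, dtype=torch.float32, device='cuda').log()

# Mamba-2 SSD inputs (assumed layout): x is (B, T, H, D); dt, A, B_mat, C follow the SSD parameterization
x = torch.randn(B, T, H, D, dtype=torch.bfloat16, device='cuda')
dt = torch.rand(B, T, H, dtype=torch.float32, device='cuda')
A = -torch.rand(H, dtype=torch.float32, device='cuda')
B_mat = torch.randn(B, T, 1, D, dtype=torch.bfloat16, device='cuda')
C = torch.randn_like(B_mat)

# median forward latency in ms for each kernel
ms_gla = triton.testing.do_bench(lambda: chunk_simple_gla(q, k, v, g))
ms_ssd = triton.testing.do_bench(lambda: mamba_chunk_scan_combined(x, dt, A, B_mat, C, chunk_size=64))
print(f'T={T}: chunk_simple_gla {ms_gla:.3f} ms, mamba2_ssd {ms_ssd:.3f} ms')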


I left many TODO and NOTE comments in the benchmark scripts, including:

  • testing more input shapes
  • tuning the block size
  • analyzing the impact of input memory layout

More importantly:

  • more detailed profiling to understand exactly why it is faster.

Maybe the Mamba-2 kernel incurs more memory I/O (i.e. it is less "fused")? And why does the short-sequence performance (T < 256) differ by so much?

@yzhangcs (Member) commented

@learning-chip Great job! Appreciate your quick actions.

@yzhangcs merged commit c60ada3 into fla-org:main on Aug 18, 2024
@sustcsonglin (Collaborator) commented

@learning-chip Mamba2’s official kernel involves three main steps: 1) computation of each chunk’s last hidden state, 2) recurrence at the chunk level, and 3) output computation.

For steps 1) and 2), it stores/loads the hidden state in FP32, which incurs significant I/O costs.

FLA’s implementation fuses steps 1) and 2): it avoids materializing the FP32 hidden state after step 1) and stores only the BF16 hidden state after step 2), thus reducing I/O cost.
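
To make the three steps concrete, here is a purely illustrative (non-fused) PyTorch reference of such a chunkwise algorithm for a simple-GLA-style recurrence S_t = exp(g_t) * S_{t-1} + k_t^T v_t with o_t = q_t S_t. The shapes, the scalar per-head decay, and the function name chunkwise_reference are assumptions for illustration, not the actual Triton kernels; the comments mark where a non-fused kernel would round-trip the FP32 chunk states through HBM, which is the traffic the fused kernel avoids.

import torch

def chunkwise_reference(q, k, v, g, chunk_size=64):
    # q, k, v: (B, H, T, D); g: (B, H, T) log decay; T divisible by chunk_size
    B, H, T, D = q.shape
    C = chunk_size
    q, k, v = (x.float().view(B, H, T // C, C, D) for x in (q, k, v))
    g = g.float().view(B, H, T // C, C)
    b = g.cumsum(-1)  # within-chunk cumulative log decay

    # Step 1: each chunk's local contribution to its last hidden state,
    #   S_local = sum_t exp(b_C - b_t) k_t^T v_t   (a (D, D) matrix per chunk).
    decay_k = (b[..., -1:] - b).exp().unsqueeze(-1)
    S_local = torch.einsum('bhntd,bhnte->bhnde', k * decay_k, v)

    # Step 2: chunk-level recurrence over the per-chunk initial states S_{i-1},
    #   S_i = exp(b_C) * S_{i-1} + S_local.
    # A non-fused kernel stores/loads every S_{i-1} in FP32 through HBM here;
    # fusing steps 1) and 2) keeps them on-chip and writes out only BF16 states.
    S = torch.zeros(B, H, D, D, dtype=torch.float32, device=q.device)
    states = []
    for i in range(T // C):
        states.append(S)
        S = b[:, :, i, -1].exp()[..., None, None] * S + S_local[:, :, i]
    states = torch.stack(states, dim=2)  # (B, H, N, D, D): state entering each chunk

    # Step 3: outputs = inter-chunk part (queries against S_{i-1}) + intra-chunk part.
    o_inter = torch.einsum('bhntd,bhnde->bhnte', q * b.exp().unsqueeze(-1), states)
    attn = torch.einsum('bhntd,bhnsd->bhnts', q, k) * (b.unsqueeze(-1) - b.unsqueeze(-2)).exp()
    attn = attn.masked_fill(torch.ones(C, C, dtype=torch.bool, device=q.device).triu(1), 0)
    o_intra = torch.einsum('bhnts,bhnse->bhnte', attn, v)
    return (o_inter + o_intra).view(B, H, T, D)

In this reference the step-2 states amount to B * H * (T/C) * D * D values per pass, so storing/loading them in FP32 versus keeping them on-chip (and writing only BF16) is exactly the I/O difference described above.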
