* add seq_idx and fa kwargs
* update tests
* docs and grad ckpt support
* fmt
* better names
* test_raise_missing_padding_free_kwarg_errs
* + seq_idx in doc strings
* padding free training docs
* add link to pr plots
* raise err on attn_mask with padding free
* rm raising missing padding free err test
* BambaFlashAttentionKwargs
* run modular util for modular_granitemoehybrid.py
Using padding-free training with Bamba requires the `flash-attn`, `mamba-ssm`, and `causal-conv1d`
packages, and the following arguments must be passed to the model in addition to `input_ids` and
`labels`:

* `position_ids: torch.LongTensor`: the position index of each token in each sequence.
* `seq_idx: torch.IntTensor`: the index of each sequence in the batch.
* Each of the [`FlashAttentionKwargs`]:
    * `cu_seq_lens_q: torch.LongTensor`: the cumulative sequence lengths of all queries.
    * `cu_seq_lens_k: torch.LongTensor`: the cumulative sequence lengths of all keys.
    * `max_length_q: int`: the longest query length in the batch.
    * `max_length_k: int`: the longest key length in the batch.
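For concreteness, here is a minimal sketch of a padding-free forward pass with these arguments for two sequences of lengths 3 and 2 packed into a single row. The checkpoint name, token values, and the int32 dtypes chosen for `seq_idx` and the cumulative lengths are illustrative assumptions rather than requirements stated here:

```python
import torch
from transformers import AutoModelForCausalLM

# Assumed checkpoint and device; any Bamba causal-LM checkpoint should work the same way.
model = AutoModelForCausalLM.from_pretrained(
    "ibm-ai-platform/Bamba-9B",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Two sequences of lengths 3 and 2, flattened into one row with no padding tokens.
input_ids = torch.tensor([[101, 102, 103, 201, 202]], device="cuda")
labels = input_ids.clone()
# Positions restart at 0 for each packed sequence.
position_ids = torch.tensor([[0, 1, 2, 0, 1]], device="cuda")
# seq_idx marks which sequence each token belongs to.
seq_idx = torch.tensor([[0, 0, 0, 1, 1]], dtype=torch.int32, device="cuda")
# Cumulative sequence lengths over the flattened row (int32 assumed, as flash-attn expects).
cu_seq_lens = torch.tensor([0, 3, 5], dtype=torch.int32, device="cuda")

out = model(
    input_ids=input_ids,
    labels=labels,
    position_ids=position_ids,
    seq_idx=seq_idx,
    cu_seq_lens_q=cu_seq_lens,
    cu_seq_lens_k=cu_seq_lens,
    max_length_q=3,
    max_length_k=3,
)
loss = out.loss  # note: no attention_mask is passed
```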
The `attention_mask` inputs should not be provided. The [`DataCollatorWithFlattening`] can be used
to programmatically generate the above set of additional arguments using `return_seq_idx=True` and
`return_flash_attn_kwargs=True`. See [this blog post](https://huggingface.co/blog/packing-with-FA2)
for additional information.
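As a minimal sketch of that collator path (the pre-tokenized example features are placeholders; only the two flags come from this section):

```python
from transformers import DataCollatorWithFlattening

# Collator that flattens examples into one row and emits the padding-free kwargs listed above.
collator = DataCollatorWithFlattening(return_seq_idx=True, return_flash_attn_kwargs=True)

# Illustrative pre-tokenized examples; in practice these come from a tokenized dataset.
features = [
    {"input_ids": [101, 102, 103]},
    {"input_ids": [201, 202]},
]

batch = collator(features)
print(batch.keys())
# Expected keys include: input_ids, labels, position_ids, seq_idx,
# cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k
```

The resulting batch can be passed directly to the model's forward, or the collator can be supplied as the `data_collator` of [`Trainer`] so batches are produced in this format during training.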
[[autodoc]] BambaForCausalLM
    - forward

This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).