@@ -4,8 +4,10 @@
 import torch
 
 from vllm.platforms import current_platform
-from vllm.vllm_flash_attn import (flash_attn_varlen_func,
-                                  flash_attn_with_kvcache)
+from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
+                                  flash_attn_varlen_func,
+                                  flash_attn_with_kvcache,
+                                  is_fa_version_supported)
 
 NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 256]
@@ -95,10 +97,9 @@ def test_flash_attn_with_paged_kv(
     fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
-    if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
-                            or torch.cuda.get_device_capability() == (8, 9)):
-        pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
-                    "insufficient shared memory for some shapes")
+    if not is_fa_version_supported(fa_version):
+        pytest.skip(f"Flash attention version {fa_version} not supported due "
+                    f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
 
     current_platform.seed_everything(0)
     num_seqs = len(kv_lens)
@@ -182,11 +183,9 @@ def test_varlen_with_paged_kv(
     fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
-    if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
-                            or torch.cuda.get_device_capability() == (8, 9)):
-        pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
-                    "insufficient shared memory for some shapes")
-
+    if not is_fa_version_supported(fa_version):
+        pytest.skip(f"Flash attention version {fa_version} not supported due "
+                    f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
     current_platform.seed_everything(0)
     num_seqs = len(seq_lens)
     query_lens = [x[0] for x in seq_lens]
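For context, the support-probing pattern this diff introduces can also be exercised outside pytest. The sketch below is a minimal, hypothetical standalone check built only on the two helpers the diff imports, `is_fa_version_supported` and `fa_version_unsupported_reason`; the `FA_VERSIONS` list and the `main()` wrapper are illustrative assumptions, not part of the PR or the vllm API.

# Minimal sketch (not part of the PR): report which flash-attention versions
# the installed vllm build supports, using the helpers imported in the diff.
# FA_VERSIONS and main() are assumptions added for illustration.
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
                                  is_fa_version_supported)

# Assumed candidate versions, mirroring the tests' fa_version parameter.
FA_VERSIONS = [2, 3]


def main() -> None:
    for fa_version in FA_VERSIONS:
        if is_fa_version_supported(fa_version):
            print(f"FA{fa_version}: supported")
        else:
            # Same message shape as the pytest.skip() calls added in the diff.
            reason = fa_version_unsupported_reason(fa_version)
            print(f"FA{fa_version}: not supported due to: \"{reason}\"")


if __name__ == "__main__":
    main()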