|
| 1 | +# Adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/tests/test_flash_mla.py |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +import math |
| 4 | +import random |
| 5 | + |
| 6 | +import pytest |
| 7 | +import torch |
| 8 | +import triton |
| 9 | + |
| 10 | +from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, |
| 11 | + get_mla_metadata, |
| 12 | + is_flashmla_supported) |
| 13 | + |
| 14 | + |
| 15 | +def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: |
| 16 | + x, y = x.double(), y.double() |
| 17 | + cos_diff = 1 - 2 * (x * y).sum().item() / max( |
| 18 | + (x * x + y * y).sum().item(), 1e-12) |
| 19 | + assert cos_diff < 1e-5 |
| 20 | + |
| 21 | +FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ |
| 22 | + if not is_flashmla_supported()[0] else "FlashMLA is supported" |
| 23 | + |
| 24 | + |
| 25 | +@pytest.mark.skipif(not is_flashmla_supported()[0], |
| 26 | + reason=FLASH_MLA_UNSUPPORTED_REASON) |
| 27 | +@pytest.mark.parametrize("b", [128]) |
| 28 | +@pytest.mark.parametrize("s_q", [1, 2]) |
| 29 | +@pytest.mark.parametrize("mean_sk", [4096, 8192]) |
| 30 | +@pytest.mark.parametrize("h_q", [16, 32, 64, 128]) |
| 31 | +@pytest.mark.parametrize("h_kv", [1]) |
| 32 | +@pytest.mark.parametrize("d", [576]) |
| 33 | +@pytest.mark.parametrize("dv", [512]) |
| 34 | +@pytest.mark.parametrize("block_size", [64]) |
| 35 | +@pytest.mark.parametrize("causal", [True]) |
| 36 | +@pytest.mark.parametrize("varlen", [False, True]) |
| 37 | +@torch.inference_mode() |
| 38 | +def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, |
| 39 | + varlen): |
| 40 | + # TODO: parametrize using pytest |
| 41 | + dtype = torch.bfloat16 |
| 42 | + device = torch.device("cuda:0") |
| 43 | + torch.set_default_dtype(dtype) |
| 44 | + torch.set_default_device(device) |
| 45 | + torch.cuda.set_device(device) |
| 46 | + torch.manual_seed(0) |
| 47 | + random.seed(0) |
| 48 | + |
| 49 | + print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " |
| 50 | + f"{d=}, {dv=}, {causal=}, {varlen=}") |
| 51 | + |
| 52 | + cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32) |
| 53 | + if varlen: |
| 54 | + for i in range(b): |
| 55 | + cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), |
| 56 | + s_q) |
| 57 | + total_seqlens = cache_seqlens.sum().item() |
| 58 | + max_seqlen = cache_seqlens.max().item() |
| 59 | + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 |
| 60 | + |
| 61 | + q = torch.randn(b, s_q, h_q, d) |
| 62 | + block_table = torch.arange(b * max_seqlen_pad // block_size, |
| 63 | + dtype=torch.int32).view( |
| 64 | + b, max_seqlen_pad // block_size) |
| 65 | + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) |
| 66 | + for i in range(b): |
| 67 | + blocked_k.view(b, max_seqlen_pad, h_kv, |
| 68 | + d)[i, cache_seqlens[i].item():] = float("nan") |
| 69 | + blocked_v = blocked_k[..., :dv] |
| 70 | + |
| 71 | + tile_scheduler_metadata, num_splits = get_mla_metadata( |
| 72 | + cache_seqlens, s_q * h_q // h_kv, h_kv) |
| 73 | + |
| 74 | + def flash_mla(): |
| 75 | + return flash_mla_with_kvcache( |
| 76 | + q, |
| 77 | + blocked_k, |
| 78 | + block_table, |
| 79 | + cache_seqlens, |
| 80 | + dv, |
| 81 | + tile_scheduler_metadata, |
| 82 | + num_splits, |
| 83 | + causal=causal, |
| 84 | + ) |
| 85 | + |
| 86 | + def scaled_dot_product_attention(query, key, value, is_causal=False): |
| 87 | + query = query.float() |
| 88 | + key = key.float() |
| 89 | + value = value.float() |
| 90 | + key = key.repeat_interleave(h_q // h_kv, dim=0) |
| 91 | + value = value.repeat_interleave(h_q // h_kv, dim=0) |
| 92 | + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) |
| 93 | + if is_causal: |
| 94 | + s_q = query.shape[-2] |
| 95 | + s_k = key.shape[-2] |
| 96 | + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) |
| 97 | + temp_mask = torch.ones(s_q, s_k, |
| 98 | + dtype=torch.bool).tril(diagonal=s_k - s_q) |
| 99 | + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) |
| 100 | + attn_bias.to(query.dtype) |
| 101 | + attn_weight += attn_bias |
| 102 | + lse = attn_weight.logsumexp(dim=-1) |
| 103 | + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) |
| 104 | + return attn_weight @ value, lse |
| 105 | + |
| 106 | + def ref_mla(): |
| 107 | + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) |
| 108 | + lse = torch.empty(b, h_q, s_q, dtype=torch.float32) |
| 109 | + for i in range(b): |
| 110 | + begin = i * max_seqlen_pad |
| 111 | + end = begin + cache_seqlens[i] |
| 112 | + ref_O, LSE = scaled_dot_product_attention( |
| 113 | + q[i].transpose(0, 1), |
| 114 | + blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), |
| 115 | + blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), |
| 116 | + is_causal=causal, |
| 117 | + ) |
| 118 | + out[i] = ref_O.transpose(0, 1) |
| 119 | + lse[i] = LSE |
| 120 | + return out, lse |
| 121 | + |
| 122 | + out_flash, lse_flash = flash_mla() |
| 123 | + out_torch, lse_torch = ref_mla() |
| 124 | + cal_diff(out_flash, out_torch, "out") |
| 125 | + cal_diff(lse_flash, lse_torch, "lse") |
| 126 | + |
| 127 | + t = triton.testing.do_bench(flash_mla, fast_flush=False) |
| 128 | + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 |
| 129 | + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + |
| 130 | + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) |
| 131 | + print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} " |
| 132 | + f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s") |
0 commit comments