import math
from typing import Optional

import mindspore as ms
from mindspore import mint, ops


def dispatch_attention_fn(
    query: ms.Tensor,
    key: ms.Tensor,
    value: ms.Tensor,
    attn_mask: Optional[ms.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
):
    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
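    # Added note (assuming the usual diffusers (B, S, N, D) convention for q/k/v):
    # permute(0, 2, 1, 3) moves heads next to batch, i.e. (B, S, N, D) -> (B, N, S, D),
    # which is the BNSD layout expected by the attention kernels below.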
    # Note: PyTorch's SDPA and MindSpore's FlashAttention (FA) handle `attn_mask` differently.
    # In PyTorch, a non-boolean mask (e.g., float32 with 0/1 values) is interpreted as an
    # additive bias: `attn_bias = attn_mask + attn_bias`. This implicit branch can cause
    # problems when a pipeline mistakenly provides a 0/1 float mask instead of a boolean mask.
    # Although this behavior currently matches HF Diffusers, it remains a potential source of
    # bugs worth validating, which is why such masks are routed to the eager path below.
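    # Illustrative sketch (not part of the original code) of the mask forms involved:
    #   boolean mask:   [[True, True, False]]  -> the last position is masked out
    #   0/1 float mask: [[1.0,  1.0,  0.0 ]]   -> PyTorch adds it to the logits, so nothing is
    #                                             actually masked; kept positions are shifted by +1
    #   additive mask:  [[0.0,  0.0, -inf ]]   -> the intended float form
    # The `1.0 in attn_mask` check below routes suspicious 0/1 float masks through the eager
    # path so they keep PyTorch's additive semantics rather than being converted into a
    # discard mask by `flash_attention_op`.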
    if attn_mask is not None and attn_mask.dtype != ms.bool_ and 1.0 in attn_mask:
        L, S = query.shape[-2], key.shape[-2]
        scale_factor = 1 / math.sqrt(query.shape[-1]) if scale is None else scale
        attn_bias = mint.zeros((L, S), dtype=query.dtype)
        if is_causal:
            if attn_mask is not None:
                if attn_mask.dtype == ms.bool_:
                    attn_mask = mint.logical_and(attn_mask, mint.ones((L, S), dtype=ms.bool_).tril(diagonal=0))
                else:
                    attn_mask = attn_mask + mint.triu(
                        mint.full((L, S), float("-inf"), dtype=attn_mask.dtype), diagonal=1
                    )
            else:
                temp_mask = mint.ones((L, S), dtype=ms.bool_).tril(diagonal=0)
                attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
                attn_bias = attn_bias.to(query.dtype)

        if attn_mask is not None:
            if attn_mask.dtype == ms.bool_:
                attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias = attn_mask + attn_bias

        attn_weight = mint.matmul(query, key.swapaxes(-2, -1)) * scale_factor
        attn_weight += attn_bias
        attn_weight = mint.softmax(attn_weight, dim=-1)
        attn_weight = ops.dropout(attn_weight, dropout_p, training=True)
        return mint.matmul(attn_weight, value).permute(0, 2, 1, 3)
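    # Added note: the FlashAttentionScore kernel behind `flash_attention_op` only accepts
    # float16/bfloat16 inputs, so other dtypes are computed in float16 and cast back to the
    # original dtype, trading some precision for the fused kernel.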
    if query.dtype in (ms.float16, ms.bfloat16):
        out = flash_attention_op(query, key, value, attn_mask, keep_prob=1 - dropout_p, scale=scale)
    else:
        out = flash_attention_op(
            query.to(ms.float16),
            key.to(ms.float16),
            value.to(ms.float16),
            attn_mask,
            keep_prob=1 - dropout_p,
            scale=scale,
        ).to(query.dtype)
    return out.permute(0, 2, 1, 3)


def flash_attention_op(
    query: ms.Tensor,
    key: ms.Tensor,
    value: ms.Tensor,
    attn_mask: Optional[ms.Tensor] = None,
    keep_prob: float = 1.0,
    scale: Optional[float] = None,
):
    # In most scenarios, q/k/v have already been transposed into BNSD layout before
    # scaled dot-product attention.
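    # BNSD = (batch, num_heads, seq_len, head_dim); BSH = (batch, seq_len, num_heads * head_dim).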
    input_layout = "BNSD"
    head_num = query.shape[1]
    if scale is None:
        scale = query.shape[-1] ** (-0.5)

    # Handle the case where q/k/v are 3-dimensional (BSH) after `head_to_batch_dim`
    if query.ndim == 3:
        input_layout = "BSH"
        head_num = 1

    # Process `attn_mask`, since the mask semantics differ between PyTorch and MindSpore:
    # in MindSpore, True marks positions to discard and False marks positions to retain;
    # in PyTorch it is the opposite.
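    # For example (illustration, not from the original code): a PyTorch-style boolean mask
    # [[True, True, False]] (True = keep) becomes [[False, False, True]] (True = discard),
    # and a float additive mask [[0., 0., -inf]] becomes the same discard mask via `.bool()`,
    # since only its -inf entries are non-zero. The mask is then broadcast to the full
    # (B, N, S_q, S_k) shape and sliced back to a single head, as FlashAttentionScore accepts
    # a (B, 1, S_q, S_k) mask shared across all heads.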
    if attn_mask is not None:
        attn_mask = mint.logical_not(attn_mask) if attn_mask.dtype == ms.bool_ else attn_mask.bool()
        attn_mask = mint.broadcast_to(
            attn_mask, (attn_mask.shape[0], attn_mask.shape[1], query.shape[-2], key.shape[-2])
        )[:, :1, :, :]

    return ops.operations.nn_ops.FlashAttentionScore(
        head_num=head_num, keep_prob=keep_prob, scale_value=scale, input_layout=input_layout
    )(query, key, value, None, None, None, attn_mask)[3]
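

if __name__ == "__main__":
    # Illustrative usage sketch (not part of the original change); assumes an Ascend/NPU
    # context where FlashAttentionScore is available. Inputs and outputs of
    # `dispatch_attention_fn` are laid out as (batch, seq_len, num_heads, head_dim).
    q = ops.randn(2, 1024, 8, 64, dtype=ms.float16)  # (B, S, N, D)
    k = ops.randn(2, 1024, 8, 64, dtype=ms.float16)
    v = ops.randn(2, 1024, 8, 64, dtype=ms.float16)
    out = dispatch_attention_fn(q, k, v)
    print(out.shape)  # expected: (2, 1024, 8, 64)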