import torch
import torch.nn.functional as F
-import triton
-from vllm.model_executor.layers.fla.ops.layernorm_guard import \
-    layer_norm_fwd_kernel
+from vllm.triton_utils import tl, triton
+
+
+@triton.heuristics({
+    "HAS_BIAS": lambda args: args["B"] is not None,
+    "HAS_Z": lambda args: args["Z"] is not None,
+})
+@triton.jit
+def layer_norm_fwd_kernel(
+    X,  # pointer to the input
+    Y,  # pointer to the output
+    W,  # pointer to the weights
+    B,  # pointer to the biases
+    Z,  # pointer to the other branch
+    Mean,  # pointer to the mean
+    Rstd,  # pointer to the 1/std
+    stride_x_row,  # how much to increase the pointer when moving by 1 row
+    stride_y_row,
+    stride_z_row,
+    M,  # number of rows in X
+    N,  # number of columns in X
+    eps,  # epsilon to avoid division by zero
+    BLOCK_N: tl.constexpr,
+    HAS_BIAS: tl.constexpr,
+    HAS_Z: tl.constexpr,
+    NORM_BEFORE_GATE: tl.constexpr,
+    IS_RMS_NORM: tl.constexpr,
+):
+    # Map the program id to the row of X and Y it should compute.
+    row = tl.program_id(0)
+    group = tl.program_id(1)
+    X += row * stride_x_row + group * N
+    Y += row * stride_y_row + group * N
+    if HAS_Z:
+        Z += row * stride_z_row + group * N
+    if not IS_RMS_NORM:
+        Mean += group * M
+    Rstd += group * M
+    W += group * N
+    if HAS_BIAS:
+        B += group * N
+    # Compute mean and variance
+    cols = tl.arange(0, BLOCK_N)
+    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+    if HAS_Z and not NORM_BEFORE_GATE:
+        z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
+        x *= z * tl.sigmoid(z)
+    if not IS_RMS_NORM:
+        mean = tl.sum(x, axis=0) / N
+        tl.store(Mean + row, mean)
+        xbar = tl.where(cols < N, x - mean, 0.0)
+        var = tl.sum(xbar * xbar, axis=0) / N
+    else:
+        xbar = tl.where(cols < N, x, 0.0)
+        var = tl.sum(xbar * xbar, axis=0) / N
+    rstd = 1 / tl.sqrt(var + eps)
+    tl.store(Rstd + row, rstd)
+    # Normalize and apply linear transformation
+    mask = cols < N
+    w = tl.load(W + cols, mask=mask).to(tl.float32)
+    if HAS_BIAS:
+        b = tl.load(B + cols, mask=mask).to(tl.float32)
+    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
+    y = x_hat * w + b if HAS_BIAS else x_hat * w
+    if HAS_Z and NORM_BEFORE_GATE:
+        z = tl.load(Z + cols, mask=mask).to(tl.float32)
+        y *= z * tl.sigmoid(z)
+    # Write output
+    tl.store(Y + cols, y, mask=mask)


def _layer_norm_fwd(