Skip to content

Commit 9981d67

Browse files
committed
[BugFix][mian] Fixed a triton kernel bug of layer_norm_fwd_kernel for Qwen3-next
1 parent daa4dd0 commit 9981d67

File tree

1 file changed

+85
-4
lines changed

1 file changed

+85
-4
lines changed

vllm_ascend/ops/fla.py

Lines changed: 85 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,89 @@
88

99
import torch
1010
import torch.nn.functional as F
11-
import triton
12-
from vllm.model_executor.layers.fla.ops.layernorm_guard import \
13-
layer_norm_fwd_kernel
11+
from vllm.triton_utils import tl, triton
12+
13+
MAX_CORES = 65535
14+
15+
16+
@triton.heuristics({
17+
"HAS_BIAS": lambda args: args["B"] is not None,
18+
"HAS_Z": lambda args: args["Z"] is not None,
19+
})
20+
@triton.jit
21+
def layer_norm_fwd_kernel(
22+
X, # pointer to the input
23+
Y, # pointer to the output
24+
W, # pointer to the weights
25+
B, # pointer to the biases
26+
Z, # pointer to the other branch
27+
Mean, # pointer to the mean
28+
Rstd, # pointer to the 1/std
29+
stride_x_row, # how much to increase the pointer when moving by 1 row
30+
stride_y_row,
31+
stride_z_row,
32+
M, # number of rows in X_base
33+
N, # number of columns in X_base
34+
eps, # epsilon to avoid division by zero
35+
BLOCK_N: tl.constexpr,
36+
HAS_BIAS: tl.constexpr,
37+
HAS_Z: tl.constexpr,
38+
NORM_BEFORE_GATE: tl.constexpr,
39+
IS_RMS_NORM: tl.constexpr,
40+
N_CORES: tl.constexpr,
41+
):
42+
# Map the program id to the row of X_base and Y_base it should compute.
43+
row = tl.program_id(0)
44+
group = tl.program_id(1)
45+
46+
BLOCK_ROWS = M if M < N_CORES else N_CORES
47+
n_iters = M // BLOCK_ROWS
48+
remain = M % BLOCK_ROWS
49+
if row < remain:
50+
n_iters = n_iters + 1
51+
52+
for i in tl.range(n_iters):
53+
X_base = X + (i * BLOCK_ROWS *
54+
stride_x_row) + row * stride_x_row + group * N
55+
Y_base = Y + (i * BLOCK_ROWS *
56+
stride_y_row) + row * stride_y_row + group * N
57+
if HAS_Z:
58+
Z_base = Z + (i * BLOCK_ROWS *
59+
stride_z_row) + row * stride_z_row + group * N
60+
if not IS_RMS_NORM:
61+
Mean_base = Mean + (i * BLOCK_ROWS) + group * M
62+
Rstd_base = Rstd + (i * BLOCK_ROWS) + group * M
63+
W_base = W + group * N
64+
if HAS_BIAS:
65+
B_base = B + group * N
66+
# Compute mean and variance
67+
cols = tl.arange(0, BLOCK_N)
68+
x = tl.load(X_base + cols, mask=cols < N, other=0.).to(tl.float32)
69+
if HAS_Z and not NORM_BEFORE_GATE:
70+
z = tl.load(Z_base + cols, mask=cols < N).to(tl.float32)
71+
x *= z * tl.sigmoid(z)
72+
if not IS_RMS_NORM:
73+
mean = tl.sum(x, axis=0) / N
74+
tl.store(Mean_base + row, mean)
75+
xbar = tl.where(cols < N, x - mean, 0.)
76+
var = tl.sum(xbar * xbar, axis=0) / N
77+
else:
78+
xbar = tl.where(cols < N, x, 0.)
79+
var = tl.sum(xbar * xbar, axis=0) / N
80+
rstd = 1 / tl.sqrt(var + eps)
81+
tl.store(Rstd_base + row, rstd)
82+
# Normalize and apply linear transformation
83+
mask = cols < N
84+
w = tl.load(W_base + cols, mask=mask).to(tl.float32)
85+
if HAS_BIAS:
86+
b = tl.load(B_base + cols, mask=mask).to(tl.float32)
87+
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
88+
y = x_hat * w + b if HAS_BIAS else x_hat * w
89+
if HAS_Z and NORM_BEFORE_GATE:
90+
z = tl.load(Z_base + cols, mask=mask).to(tl.float32)
91+
y *= z * tl.sigmoid(z)
92+
# Write output
93+
tl.store(Y_base + cols, y, mask=mask)
1494

1595

1696
def _layer_norm_fwd(
@@ -55,7 +135,7 @@ def _layer_norm_fwd(
55135
"This layer norm doesn't support feature dim >= 64KB.")
56136
# heuristics for number of warps
57137
num_warps = min(max(BLOCK_N // 256, 1), 8)
58-
grid = (M, ngroups)
138+
grid = (M if M < MAX_CORES else MAX_CORES, ngroups)
59139
with torch.npu.device(x.device.index):
60140
layer_norm_fwd_kernel[grid](
61141
x,
@@ -74,6 +154,7 @@ def _layer_norm_fwd(
74154
BLOCK_N=BLOCK_N,
75155
NORM_BEFORE_GATE=norm_before_gate,
76156
IS_RMS_NORM=is_rms_norm,
157+
N_CORES=MAX_CORES,
77158
num_warps=num_warps,
78159
)
79160
return out, mean, rstd

0 commit comments

Comments
 (0)