Skip to content

Commit ecc7122

Browse files
committed
[BugFix] Fix Qwen3-next because of vllm #26207
Signed-off-by: Icey <[email protected]>
1 parent 16cb3cc commit ecc7122

File tree

1 file changed

+69
-3
lines changed

1 file changed

+69
-3
lines changed

vllm_ascend/ops/fla.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,75 @@
88

99
import torch
1010
import torch.nn.functional as F
11-
import triton
12-
from vllm.model_executor.layers.fla.ops.layernorm_guard import \
13-
layer_norm_fwd_kernel
11+
from vllm.triton_utils import tl, triton
12+
13+
14+
@triton.heuristics({
15+
"HAS_BIAS": lambda args: args["B"] is not None,
16+
"HAS_Z": lambda args: args["Z"] is not None,
17+
})
18+
@triton.jit
19+
def layer_norm_fwd_kernel(
20+
X, # pointer to the input
21+
Y, # pointer to the output
22+
W, # pointer to the weights
23+
B, # pointer to the biases
24+
Z, # pointer to the other branch
25+
Mean, # pointer to the mean
26+
Rstd, # pointer to the 1/std
27+
stride_x_row, # how much to increase the pointer when moving by 1 row
28+
stride_y_row,
29+
stride_z_row,
30+
M, # number of rows in X
31+
N, # number of columns in X
32+
eps, # epsilon to avoid division by zero
33+
BLOCK_N: tl.constexpr,
34+
HAS_BIAS: tl.constexpr,
35+
HAS_Z: tl.constexpr,
36+
NORM_BEFORE_GATE: tl.constexpr,
37+
IS_RMS_NORM: tl.constexpr,
38+
):
39+
# Map the program id to the row of X and Y it should compute.
40+
row = tl.program_id(0)
41+
group = tl.program_id(1)
42+
X += row * stride_x_row + group * N
43+
Y += row * stride_y_row + group * N
44+
if HAS_Z:
45+
Z += row * stride_z_row + group * N
46+
if not IS_RMS_NORM:
47+
Mean += group * M
48+
Rstd += group * M
49+
W += group * N
50+
if HAS_BIAS:
51+
B += group * N
52+
# Compute mean and variance
53+
cols = tl.arange(0, BLOCK_N)
54+
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
55+
if HAS_Z and not NORM_BEFORE_GATE:
56+
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
57+
x *= z * tl.sigmoid(z)
58+
if not IS_RMS_NORM:
59+
mean = tl.sum(x, axis=0) / N
60+
tl.store(Mean + row, mean)
61+
xbar = tl.where(cols < N, x - mean, 0.0)
62+
var = tl.sum(xbar * xbar, axis=0) / N
63+
else:
64+
xbar = tl.where(cols < N, x, 0.0)
65+
var = tl.sum(xbar * xbar, axis=0) / N
66+
rstd = 1 / tl.sqrt(var + eps)
67+
tl.store(Rstd + row, rstd)
68+
# Normalize and apply linear transformation
69+
mask = cols < N
70+
w = tl.load(W + cols, mask=mask).to(tl.float32)
71+
if HAS_BIAS:
72+
b = tl.load(B + cols, mask=mask).to(tl.float32)
73+
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
74+
y = x_hat * w + b if HAS_BIAS else x_hat * w
75+
if HAS_Z and NORM_BEFORE_GATE:
76+
z = tl.load(Z + cols, mask=mask).to(tl.float32)
77+
y *= z * tl.sigmoid(z)
78+
# Write output
79+
tl.store(Y + cols, y, mask=mask)
1480

1581

1682
def _layer_norm_fwd(

0 commit comments

Comments
 (0)