88
99import torch
1010import torch .nn .functional as F
11- import triton
12- from vllm .model_executor .layers .fla .ops .layernorm_guard import \
13- layer_norm_fwd_kernel
11+ from vllm .triton_utils import tl , triton
12+
13+ MAX_CORES = 65535
14+
15+
16+ @triton .heuristics ({
17+ "HAS_BIAS" : lambda args : args ["B" ] is not None ,
18+ "HAS_Z" : lambda args : args ["Z" ] is not None ,
19+ })
20+ @triton .jit
21+ def layer_norm_fwd_kernel (
22+ X , # pointer to the input
23+ Y , # pointer to the output
24+ W , # pointer to the weights
25+ B , # pointer to the biases
26+ Z , # pointer to the other branch
27+ Mean , # pointer to the mean
28+ Rstd , # pointer to the 1/std
29+ stride_x_row , # how much to increase the pointer when moving by 1 row
30+ stride_y_row ,
31+ stride_z_row ,
32+ M , # number of rows in X_base
33+ N , # number of columns in X_base
34+ eps , # epsilon to avoid division by zero
35+ BLOCK_N : tl .constexpr ,
36+ HAS_BIAS : tl .constexpr ,
37+ HAS_Z : tl .constexpr ,
38+ NORM_BEFORE_GATE : tl .constexpr ,
39+ IS_RMS_NORM : tl .constexpr ,
40+ N_CORES : tl .constexpr ,
41+ ):
42+ # Map the program id to the row of X_base and Y_base it should compute.
43+ row = tl .program_id (0 )
44+ group = tl .program_id (1 )
45+
46+ BLOCK_ROWS = M if M < N_CORES else N_CORES
47+ n_iters = M // BLOCK_ROWS
48+ remain = M % BLOCK_ROWS
49+ if row < remain :
50+ n_iters = n_iters + 1
51+
52+ for i in tl .range (n_iters ):
53+ X_base = X + (i * BLOCK_ROWS *
54+ stride_x_row ) + row * stride_x_row + group * N
55+ Y_base = Y + (i * BLOCK_ROWS *
56+ stride_y_row ) + row * stride_y_row + group * N
57+ if HAS_Z :
58+ Z_base = Z + (i * BLOCK_ROWS *
59+ stride_z_row ) + row * stride_z_row + group * N
60+ if not IS_RMS_NORM :
61+ Mean_base = Mean + (i * BLOCK_ROWS ) + group * M
62+ Rstd_base = Rstd + (i * BLOCK_ROWS ) + group * M
63+ W_base = W + group * N
64+ if HAS_BIAS :
65+ B_base = B + group * N
66+ # Compute mean and variance
67+ cols = tl .arange (0 , BLOCK_N )
68+ x = tl .load (X_base + cols , mask = cols < N , other = 0. ).to (tl .float32 )
69+ if HAS_Z and not NORM_BEFORE_GATE :
70+ z = tl .load (Z_base + cols , mask = cols < N ).to (tl .float32 )
71+ x *= z * tl .sigmoid (z )
72+ if not IS_RMS_NORM :
73+ mean = tl .sum (x , axis = 0 ) / N
74+ tl .store (Mean_base + row , mean )
75+ xbar = tl .where (cols < N , x - mean , 0. )
76+ var = tl .sum (xbar * xbar , axis = 0 ) / N
77+ else :
78+ xbar = tl .where (cols < N , x , 0. )
79+ var = tl .sum (xbar * xbar , axis = 0 ) / N
80+ rstd = 1 / tl .sqrt (var + eps )
81+ tl .store (Rstd_base + row , rstd )
82+ # Normalize and apply linear transformation
83+ mask = cols < N
84+ w = tl .load (W_base + cols , mask = mask ).to (tl .float32 )
85+ if HAS_BIAS :
86+ b = tl .load (B_base + cols , mask = mask ).to (tl .float32 )
87+ x_hat = (x - mean ) * rstd if not IS_RMS_NORM else x * rstd
88+ y = x_hat * w + b if HAS_BIAS else x_hat * w
89+ if HAS_Z and NORM_BEFORE_GATE :
90+ z = tl .load (Z_base + cols , mask = mask ).to (tl .float32 )
91+ y *= z * tl .sigmoid (z )
92+ # Write output
93+ tl .store (Y_base + cols , y , mask = mask )
1494
1595
1696def _layer_norm_fwd (
@@ -55,7 +135,7 @@ def _layer_norm_fwd(
55135 "This layer norm doesn't support feature dim >= 64KB." )
56136 # heuristics for number of warps
57137 num_warps = min (max (BLOCK_N // 256 , 1 ), 8 )
58- grid = (M , ngroups )
138+ grid = (M if M < MAX_CORES else MAX_CORES , ngroups )
59139 with torch .npu .device (x .device .index ):
60140 layer_norm_fwd_kernel [grid ](
61141 x ,
@@ -74,6 +154,7 @@ def _layer_norm_fwd(
74154 BLOCK_N = BLOCK_N ,
75155 NORM_BEFORE_GATE = norm_before_gate ,
76156 IS_RMS_NORM = is_rms_norm ,
157+ N_CORES = MAX_CORES ,
77158 num_warps = num_warps ,
78159 )
79160 return out , mean , rstd
0 commit comments