
Commit de6c2c4

Muon with QK-Clip support (#1198)
* fix qwen2 init
* add Muon Optimizer
* add test script
* move test scripts
* improve speed
* improve speed 2
* add MLA attention
* fix error
* Qwen2 add MLA & Muon support QK-clip
* turn on shuffle
1 parent 6ec66e8 commit de6c2c4

File tree

4 files changed: +1002, -9 lines changed


mindone/trainers/muon.py

Lines changed: 309 additions & 0 deletions
@@ -0,0 +1,309 @@
"""Modified from https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py"""
import math
from typing import List, Optional, Tuple, Union

import mindspore as ms
import mindspore.mint as mint
import mindspore.ops as ops
from mindspore import Parameter, ParameterTuple, Tensor
from mindspore.experimental.optim.optimizer import Optimizer

_muon_opt = ops.MultitypeFuncGraph("muon_opt")


@_muon_opt.register(
    "Float",
    "Float",
    "Float",
    "Float",
    "Bool",
    "Int",
    "Float",
    "Tensor",
    "Tensor",
    "Tensor",
    "Tensor",
    "Tensor",
    "Tensor",
    "Float",
    "Bool",
)
def _update_run_op(
    mu: float,
    beta1: float,
    beta2: float,
    eps: float,
    nesterov: bool,
    ns_steps: int,
    weight_decay: float,
    lr: Parameter,
    denom: Parameter,
    param: Parameter,
    m: Parameter,
    v: Parameter,
    g: Tensor,
    ratio: float,
    use_muon: bool,
) -> bool:
    if weight_decay != 0:
        param.mul_(1 - lr * weight_decay)

    if use_muon:
        m.mul_(mu).add_(g)
        if nesterov:
            g = g.add(m, alpha=mu)
        else:
            g = m
        g = zeropower_via_newtonschulz5(g, steps=ns_steps)
        param.add_(lr * g, alpha=-ratio)
    else:
        m_next = mint.lerp(g, m, beta1)
        v_next = mint.lerp(mint.square(g), v, beta2)
        g = m_next / (eps + mint.sqrt(v_next))
        param.add_(-(lr / denom) * g)
        ops.assign(m, m_next)
        ops.assign(v, v_next)
    return True


_qk_clip_opt = ops.MultitypeFuncGraph("qk_clip_opt")


@_qk_clip_opt.register("Float", "Int", "Tensor", "Tensor", "Tensor")
def _update_clip_op(
    clip_value: float, qk_nope_head_dim: int, qk: Tensor, q_b_projs: Parameter, kv_b_projs: Parameter
) -> bool:
    qk = mint.transpose(qk, 0, 1).flatten(start_dim=1)
    qk_max, _ = mint.max(qk, dim=1)
    num_head = qk_max.shape[0]
    scale = mint.clip(clip_value / qk_max, max=1.0)
    scale = scale[:, None, None]
    scale_sqrt = mint.sqrt(scale)
    # clip Q projection
    outdim, _ = q_b_projs.shape
    head_dim = outdim // num_head
    scale_q_b_nope = mint.tile(scale_sqrt, (1, qk_nope_head_dim, 1))
    scale_q_b_rope = mint.tile(scale, (1, head_dim - qk_nope_head_dim, 1))
    scale_q_b = mint.cat([scale_q_b_nope, scale_q_b_rope], dim=1)
    q_b_projs.mul_(scale_q_b.view(-1, 1))
    # clip K projection
    outdim, _ = kv_b_projs.shape
    head_dim = outdim // num_head
    scale_kv_b_nope = mint.tile(scale_sqrt, (1, qk_nope_head_dim, 1))
    scale_kv_b_rope = mint.ones((num_head, head_dim - qk_nope_head_dim, 1), dtype=scale_sqrt.dtype)
    scale_kv_b = mint.cat([scale_kv_b_nope, scale_kv_b_rope], dim=1)
    kv_b_projs.mul_(scale_kv_b.view(-1, 1))
    return True

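# Editorial note on _update_clip_op above (derived from the code; MLA weight layout assumed): for each
# head h it computes gamma_h = min(1, clip_value / max|qk_h|), scales the non-RoPE ("nope") rows of both
# q_b_proj and kv_b_proj by sqrt(gamma_h) and the RoPE rows of q_b_proj by gamma_h, while the value rows
# of kv_b_proj (and the shared RoPE key path, which is not referenced here) stay untouched, so every
# attention logit of head h shrinks by exactly gamma_h.
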
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    shape = G.shape

    if len(shape) > 2:
        G = G.view(G.shape[0], -1)
    assert len(G.shape) == 2

    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.bfloat16()
    if G.shape[0] > G.shape[1]:
        X = mint.t(X)

    # Ensure spectral norm is at most 1
    X = X / (mint.norm(X) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = mint.matmul(X, X.T)
        B = mint.addmm(A, A, A, beta=b, alpha=c)  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = mint.addmm(X, B, X, beta=a)

    if G.shape[0] > G.shape[1]:
        X = mint.t(X)

    if len(shape) > 2:
        X = X.view(*shape)
    return X

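# Editorial note on zeropower_via_newtonschulz5 above: each loop iteration applies the quintic map
#     A = X @ X.T
#     X <- a * X + (b * A + c * A @ A) @ X
# with (a, b, c) = (3.4445, -4.7750, 2.0315), which drives X toward U @ V.T for G = U @ S @ V.T
# (up to the tolerated spread of singular values described in the docstring).
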
class Muon(Optimizer):
    """Following https://github.com/MoonshotAI/Moonlight"""

    def __init__(
        self,
        lr: Union[float, Tensor] = 1e-3,
        wd: float = 0.1,
        muon_params: Optional[List[Parameter]] = None,
        momentum: float = 0.95,
        nesterov: bool = True,
        ns_steps: int = 5,
        adamw_params: Optional[List[Parameter]] = None,
        adamw_betas: Tuple[float, float] = (0.9, 0.95),
        adamw_eps: float = 1e-8,
        clip_value: Optional[float] = 100.0,
        qk_nope_head_dim: int = 64,
    ) -> None:
        defaults = dict(
            lr=lr,
            wd=wd,
            momentum=momentum,
            nesterov=nesterov,
            ns_steps=ns_steps,
            adamw_betas=adamw_betas,
            adamw_eps=adamw_eps,
        )
        params = list(muon_params)
        adamw_params = list(adamw_params) if adamw_params is not None else []
        params.extend(adamw_params)
        super().__init__(params, defaults)
        self.clip_value = clip_value
        self.qk_nope_head_dim = qk_nope_head_dim
        # Sort parameters into those for which we will use Muon, and those for which we will not
        use_muon = list()
        for p in muon_params:
            # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
            assert p.ndim >= 2, p.ndim
            use_muon.append(True)

        for p in adamw_params:
            # Do not use Muon for parameters in adamw_params
            use_muon.append(False)
        self.use_muon = tuple(use_muon)

        self.exp_avg = self.parameters.clone("exp_avg", init="zeros")
        self.exp_avg_sq = ParameterTuple(
            [
                (
                    Parameter(mint.zeros(x.shape, dtype=x.dtype), name="exp_avg_sq." + x.name)
                    if not use_muon
                    else Parameter([], name="exp_avg_sq." + x.name)
                )
                for x, use_muon in zip(self.parameters, self.use_muon)
            ]
        )

        self.lr_ratio = tuple([self._cal_lr_ratio(x, use_muon) for x, use_muon in zip(self.parameters, self.use_muon)])

        self.state_step = Parameter(Tensor(0, dtype=ms.int32))
        self.increase_tensor = Tensor(1, dtype=ms.int32)
        self.denom = Parameter(Tensor(1.0, dtype=ms.float32))

        if self.clip_value is not None:
            # group the Q and KV projection first for easier updating in QK-clip
            # TODO: it should be extracted from optimizer as extra inputs
            q_b_projs = []
            kv_b_projs = []
            for x in self.parameters:
                if x.name.endswith("q_b_proj.weight"):
                    layer_idx = int(x.name.split(".")[2])
                    q_b_projs.append((layer_idx, x))
                elif x.name.endswith("kv_b_proj.weight"):
                    layer_idx = int(x.name.split(".")[2])
                    kv_b_projs.append((layer_idx, x))
            q_b_projs = sorted(q_b_projs, key=lambda x: x[0])
            kv_b_projs = sorted(kv_b_projs, key=lambda x: x[0])
            self.q_b_projs = ParameterTuple([x[1] for x in q_b_projs])
            self.kv_b_projs = ParameterTuple([x[1] for x in kv_b_projs])
            assert len(self.q_b_projs) > 0 and len(self.kv_b_projs) > 0

    def _cal_lr_ratio(self, param: Parameter, use_muon: bool, rms_scale: float = 0.2) -> float:
        if not use_muon:
            return 1.0

        A, B = param.shape[:2]
        # We adjust the learning rate and weight decay based on the size of the parameter matrix
        # as described in the paper
        adjusted_ratio = rms_scale * math.sqrt(max(A, B))
        return adjusted_ratio

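    # Editorial note on _cal_lr_ratio above: with the default rms_scale = 0.2 the Muon update of an
    # (A, B) matrix is rescaled by 0.2 * sqrt(max(A, B)), which, following the Moonlight report, keeps
    # the update RMS roughly comparable to AdamW's so that a single learning rate and weight decay can
    # be shared across the Muon and AdamW parameter groups.
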
    @ms.jit(jit_level="O1")
    def muon(
        self,
        momentum: float,
        beta1: float,
        beta2: float,
        eps: float,
        nesterov: bool,
        ns_steps: int,
        weight_decay: float,
        lr: Parameter,
        gradients: Tuple[Tensor, ...],
        ratio: Tuple[float, ...],
        use_muon: Tuple[bool, ...],
        start_id: int,
        end_id: int,
    ) -> bool:
        bias_correction1 = 1 - beta1**self.state_step
        bias_correction2 = 1 - beta2**self.state_step
        ops.assign(self.denom, bias_correction1 / bias_correction2**0.5)

        optim_result = self.hyper_map(
            ops.partial(
                _muon_opt,
                momentum,
                beta1,
                beta2,
                eps,
                nesterov,
                ns_steps,
                weight_decay,
                lr,
                self.denom,
            ),
            self.parameters[start_id:end_id],
            self.exp_avg[start_id:end_id],
            self.exp_avg_sq[start_id:end_id],
            gradients[start_id:end_id],
            ratio[start_id:end_id],
            use_muon[start_id:end_id],
        )
        return optim_result

    @ms.jit(jit_level="O1")
    def qk_clip(self, qk_products: Tuple[Tensor, ...]) -> bool:
        optim_result = self.hyper_map(
            ops.partial(_qk_clip_opt, self.clip_value, self.qk_nope_head_dim),
            qk_products,
            self.q_b_projs,
            self.kv_b_projs,
        )
        return optim_result

    def construct(self, gradients: Tuple[Tensor, ...], qk_products: Optional[Tuple[Tensor, ...]] = None) -> bool:
        if self.clip_value is not None:
            assert qk_products is not None

        self.state_step.add_(self.increase_tensor)
        for group_id, group in enumerate(self.param_groups):
            beta1, beta2 = group["adamw_betas"]
            start_id = self.group_start_id[group_id]
            end_id = self.group_start_id[group_id + 1]

            self.muon(
                group["momentum"],
                beta1,
                beta2,
                group["adamw_eps"],
                group["nesterov"],
                group["ns_steps"],
                group["wd"],
                group["lr"],
                gradients,
                self.lr_ratio,
                self.use_muon,
                start_id,
                end_id,
            )

        if self.clip_value is None:
            return True
        else:
            optim_result = self.qk_clip(qk_products)
            return optim_result
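
For orientation, a minimal usage sketch (not part of this commit) is shown below. It assumes a hypothetical `network` whose forward pass also returns the per-layer QK logits consumed by QK-clip, and it splits the trainable parameters into Muon and AdamW groups with illustrative name filters; adapt both assumptions to your model.

# Hypothetical wiring of the Muon optimizer into a MindSpore training step; `network` and the
# way qk_products is surfaced from the attention layers are assumptions, not part of this commit.
import mindspore as ms

from mindone.trainers.muon import Muon

all_params = network.trainable_params()
# Hidden 2-D weights go to Muon; embeddings, norms and the LM head stay on AdamW.
muon_params = [p for p in all_params if p.ndim >= 2 and "embed" not in p.name and "lm_head" not in p.name]
muon_names = {p.name for p in muon_params}
adamw_params = [p for p in all_params if p.name not in muon_names]

optimizer = Muon(
    lr=1e-3,
    wd=0.1,
    muon_params=muon_params,
    adamw_params=adamw_params,
    clip_value=100.0,  # QK-clip threshold; pass None to disable QK-clip
    qk_nope_head_dim=64,
)


def forward_fn(inputs, labels):
    # Assumed: the model returns the loss together with one QK-logit tensor per layer,
    # ordered consistently with q_b_proj / kv_b_proj inside the optimizer.
    loss, qk_products = network(inputs, labels)
    return loss, qk_products


grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)


def train_step(inputs, labels):
    (loss, qk_products), grads = grad_fn(inputs, labels)
    optimizer(grads, qk_products)  # Muon.construct(gradients, qk_products)
    return loss

When clip_value is None, the qk_products argument can be omitted and the step reduces to plain Muon/AdamW updates.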
