"""Modified from https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py"""
import math
from typing import List, Optional, Tuple, Union

import mindspore as ms
import mindspore.mint as mint
import mindspore.ops as ops
from mindspore import Parameter, ParameterTuple, Tensor
from mindspore.experimental.optim.optimizer import Optimizer

_muon_opt = ops.MultitypeFuncGraph("muon_opt")


@_muon_opt.register(
    "Float",
    "Float",
    "Float",
    "Float",
    "Bool",
    "Int",
    "Float",
    "Tensor",
    "Tensor",
    "Tensor",
    "Tensor",
    "Tensor",
    "Tensor",
    "Float",
    "Bool",
)
def _update_run_op(
    mu: float,
    beta1: float,
    beta2: float,
    eps: float,
    nesterov: bool,
    ns_steps: int,
    weight_decay: float,
    lr: Parameter,
    denom: Parameter,
    param: Parameter,
    m: Parameter,
    v: Parameter,
    g: Tensor,
    ratio: float,
    use_muon: bool,
) -> bool:
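    # Fused per-parameter update dispatched through hyper_map: the Muon branch is taken when
    # `use_muon` is set for the parameter, the AdamW branch otherwise. `denom` carries the
    # AdamW bias correction computed once per step in Muon.muon().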
    if weight_decay != 0:
        param.mul_(1 - lr * weight_decay)

    if use_muon:
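        # Muon: momentum accumulation (optionally with a Nesterov lookahead), followed by
        # Newton-Schulz orthogonalization of the update, scaled by the per-matrix lr ratio.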
        m.mul_(mu).add_(g)
        if nesterov:
            g = g.add(m, alpha=mu)
        else:
            g = m
        g = zeropower_via_newtonschulz5(g, steps=ns_steps)
        param.add_(lr * g, alpha=-ratio)
    else:
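        # AdamW: mint.lerp(a, b, w) = a + w * (b - a), so the two lines below compute the
        # EMAs beta1 * m + (1 - beta1) * g and beta2 * v + (1 - beta2) * g**2.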
        m_next = mint.lerp(g, m, beta1)
        v_next = mint.lerp(mint.square(g), v, beta2)
        g = m_next / (eps + mint.sqrt(v_next))
        param.add_(-(lr / denom) * g)
        ops.assign(m, m_next)
        ops.assign(v, v_next)
    return True


_qk_clip_opt = ops.MultitypeFuncGraph("qk_clip_opt")


@_qk_clip_opt.register("Float", "Int", "Tensor", "Tensor", "Tensor")
def _update_clip_op(
    clip_value: float, qk_nope_head_dim: int, qk: Tensor, q_b_projs: Parameter, kv_b_projs: Parameter
) -> bool:
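    # QK-clip: rescale the projection weights of heads whose maximum attention logit exceeds
    # `clip_value`. The non-RoPE (nope) parts of the Q and K projections each absorb
    # sqrt(scale), the RoPE part of the Q projection absorbs the full scale, and the
    # remaining (value) rows of kv_b_proj are left unscaled.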
    qk = mint.transpose(qk, 0, 1).flatten(start_dim=1)
    qk_max, _ = mint.max(qk, dim=1)
    num_head = qk_max.shape[0]
    scale = mint.clip(clip_value / qk_max, max=1.0)
    scale = scale[:, None, None]
    scale_sqrt = mint.sqrt(scale)
    # clip Q projection
    outdim, _ = q_b_projs.shape
    head_dim = outdim // num_head
    scale_q_b_nope = mint.tile(scale_sqrt, (1, qk_nope_head_dim, 1))
    scale_q_b_rope = mint.tile(scale, (1, head_dim - qk_nope_head_dim, 1))
    scale_q_b = mint.cat([scale_q_b_nope, scale_q_b_rope], dim=1)
    q_b_projs.mul_(scale_q_b.view(-1, 1))
    # clip K projection
    outdim, _ = kv_b_projs.shape
    head_dim = outdim // num_head
    scale_kv_b_nope = mint.tile(scale_sqrt, (1, qk_nope_head_dim, 1))
    scale_kv_b_rope = mint.ones((num_head, head_dim - qk_nope_head_dim, 1), dtype=scale_sqrt.dtype)
    scale_kv_b = mint.cat([scale_kv_b_nope, scale_kv_b_rope], dim=1)
    kv_b_projs.mul_(scale_kv_b.view(-1, 1))
    return True


def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    shape = G.shape

    if len(shape) > 2:
        G = G.view(G.shape[0], -1)
    assert G.ndim == 2

    a, b, c = 3.4445, -4.7750, 2.0315
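    # Each iteration applies the quintic p(X) = a*X + b*(X X^T)X + c*(X X^T)^2 X to the
    # (transposed, norm-scaled) matrix, pushing its singular values toward 1 (approximately;
    # see the docstring above).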
    X = G.bfloat16()
    if G.shape[0] > G.shape[1]:
        X = mint.t(X)

    # Ensure spectral norm is at most 1
    X = X / (mint.norm(X) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = mint.matmul(X, X.T)
        B = mint.addmm(A, A, A, beta=b, alpha=c)  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = mint.addmm(X, B, X, beta=a)

    if G.shape[0] > G.shape[1]:
        X = mint.t(X)

    if len(shape) > 2:
        X = X.view(*shape)
    return X


class Muon(Optimizer):
    """Following https://github.com/MoonshotAI/Moonlight"""

    def __init__(
        self,
        lr: Union[float, Tensor] = 1e-3,
        wd: float = 0.1,
        muon_params: Optional[List[Parameter]] = None,
        momentum: float = 0.95,
        nesterov: bool = True,
        ns_steps: int = 5,
        adamw_params: Optional[List[Parameter]] = None,
        adamw_betas: Tuple[float, float] = (0.9, 0.95),
        adamw_eps: float = 1e-8,
        clip_value: Optional[float] = 100.0,
        qk_nope_head_dim: int = 64,
    ) -> None:
        defaults = dict(
            lr=lr,
            weight_decay=wd,
            momentum=momentum,
            nesterov=nesterov,
            ns_steps=ns_steps,
            adamw_betas=adamw_betas,
            adamw_eps=adamw_eps,
        )
        muon_params = list(muon_params) if muon_params is not None else []
        adamw_params = list(adamw_params) if adamw_params is not None else []
        params = muon_params + adamw_params
        super().__init__(params, defaults)
        self.clip_value = clip_value
        self.qk_nope_head_dim = qk_nope_head_dim
        # Sort parameters into those for which we will use Muon, and those for which we will not
        use_muon = list()
        for p in muon_params:
            # Muon is used for every parameter passed through muon_params; these must be at least
            # 2-D (embedding and output-head weights should go through adamw_params instead)
            assert p.ndim >= 2, p.ndim
            use_muon.append(True)

        for p in adamw_params:
            # Do not use Muon for parameters in adamw_params
            use_muon.append(False)
        self.use_muon = tuple(use_muon)

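        # Optimizer state: exp_avg (first moment) is allocated for every parameter; exp_avg_sq
        # (second moment) is only materialized for AdamW parameters, while Muon parameters get
        # an empty placeholder so the tuples passed to hyper_map stay aligned.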
        self.exp_avg = self.parameters.clone("exp_avg", init="zeros")
        self.exp_avg_sq = ParameterTuple(
            [
                (
                    Parameter(mint.zeros(x.shape, dtype=x.dtype), name="exp_avg_sq." + x.name)
                    if not use_muon
                    else Parameter([], name="exp_avg_sq." + x.name)
                )
                for x, use_muon in zip(self.parameters, self.use_muon)
            ]
        )

        self.lr_ratio = tuple([self._cal_lr_ratio(x, use_muon) for x, use_muon in zip(self.parameters, self.use_muon)])

        self.state_step = Parameter(Tensor(0, dtype=ms.int32))
        self.increase_tensor = Tensor(1, dtype=ms.int32)
        self.denom = Parameter(Tensor(1.0, dtype=ms.float32))

        if self.clip_value is not None:
            # group the Q and KV projections first for easier updating in QK-clip
            # TODO: they should be extracted from the optimizer and passed in as extra inputs
            q_b_projs = []
            kv_b_projs = []
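            # NOTE: parameter names are assumed to follow a pattern like
            # "model.layers.<layer_idx>....q_b_proj.weight", so the third dot-separated
            # field is parsed as the layer index below.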
            for x in self.parameters:
                if x.name.endswith("q_b_proj.weight"):
                    layer_idx = int(x.name.split(".")[2])
                    q_b_projs.append((layer_idx, x))
                elif x.name.endswith("kv_b_proj.weight"):
                    layer_idx = int(x.name.split(".")[2])
                    kv_b_projs.append((layer_idx, x))
            q_b_projs = sorted(q_b_projs, key=lambda x: x[0])
            kv_b_projs = sorted(kv_b_projs, key=lambda x: x[0])
            self.q_b_projs = ParameterTuple([x[1] for x in q_b_projs])
            self.kv_b_projs = ParameterTuple([x[1] for x in kv_b_projs])
            assert len(self.q_b_projs) > 0 and len(self.kv_b_projs) > 0

    def _cal_lr_ratio(self, param: Parameter, use_muon: bool, rms_scale: float = 0.2) -> float:
        if not use_muon:
            return 1.0

        A, B = param.shape[:2]
        # Adjust the learning-rate ratio based on the shape of the parameter matrix,
        # as described in the Moonlight paper
        adjusted_ratio = rms_scale * math.sqrt(max(A, B))
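        # e.g. a 4096 x 4096 weight matrix gets adjusted_ratio = 0.2 * sqrt(4096) = 12.8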
        return adjusted_ratio

    @ms.jit(jit_level="O1")
    def muon(
        self,
        momentum: float,
        beta1: float,
        beta2: float,
        eps: float,
        nesterov: bool,
        ns_steps: int,
        weight_decay: float,
        lr: Parameter,
        gradients: Tuple[Tensor, ...],
        ratio: Tuple[float, ...],
        use_muon: Tuple[bool, ...],
        start_id: int,
        end_id: int,
    ) -> bool:
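        # Fold the AdamW bias corrections into a single scalar:
        # denom = (1 - beta1**t) / sqrt(1 - beta2**t); the AdamW branch divides lr by it.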
        bias_correction1 = 1 - beta1**self.state_step
        bias_correction2 = 1 - beta2**self.state_step
        ops.assign(self.denom, bias_correction1 / bias_correction2**0.5)

        optim_result = self.hyper_map(
            ops.partial(
                _muon_opt,
                momentum,
                beta1,
                beta2,
                eps,
                nesterov,
                ns_steps,
                weight_decay,
                lr,
                self.denom,
            ),
            self.parameters[start_id:end_id],
            self.exp_avg[start_id:end_id],
            self.exp_avg_sq[start_id:end_id],
            gradients[start_id:end_id],
            ratio[start_id:end_id],
            use_muon[start_id:end_id],
        )
        return optim_result

    @ms.jit(jit_level="O1")
    def qk_clip(self, qk_products: Tuple[Tensor, ...]) -> bool:
        optim_result = self.hyper_map(
            ops.partial(_qk_clip_opt, self.clip_value, self.qk_nope_head_dim),
            qk_products,
            self.q_b_projs,
            self.kv_b_projs,
        )
        return optim_result

    def construct(self, gradients: Tuple[Tensor, ...], qk_products: Optional[Tuple[Tensor, ...]] = None) -> bool:
        if self.clip_value is not None:
            assert qk_products is not None

        self.state_step.add_(self.increase_tensor)
        for group_id, group in enumerate(self.param_groups):
            beta1, beta2 = group["adamw_betas"]
            start_id = self.group_start_id[group_id]
            end_id = self.group_start_id[group_id + 1]

            self.muon(
                group["momentum"],
                beta1,
                beta2,
                group["adamw_eps"],
                group["nesterov"],
                group["ns_steps"],
                group["weight_decay"],
                group["lr"],
                gradients,
                self.lr_ratio,
                self.use_muon,
                start_id,
                end_id,
            )

        if self.clip_value is None:
            return True
        else:
            optim_result = self.qk_clip(qk_products)
            return optim_result
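

# Minimal usage sketch (not part of the original script). It assumes a `network` whose 2-D
# weight matrices are trained with Muon and whose embedding / head / 1-D parameters are
# trained with AdamW, plus a `qk_products` tuple of per-layer attention-logit tensors that
# the model returns when QK-clip is enabled. `network`, `forward_fn` and `batch` are
# placeholder names, not part of this file.
#
# muon_params = [p for p in network.trainable_params() if p.ndim >= 2 and "embed" not in p.name]
# adamw_params = [p for p in network.trainable_params() if p.ndim < 2 or "embed" in p.name]
# optimizer = Muon(lr=1e-3, wd=0.1, muon_params=muon_params, adamw_params=adamw_params)
# grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
# (loss, qk_products), grads = grad_fn(*batch)
# optimizer(grads, qk_products)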