Commit 36152fb

Add qwen 2.5

1 parent d99970b

4 files changed: +128 additions, -8 deletions

examples/models/llama/attention.py
Lines changed: 4 additions & 3 deletions

@@ -175,9 +175,10 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.max_batch_size = args.max_batch_size
         self.max_context_len = args.max_context_len
         self.dim = args.dim
-        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        # TODO: parametrize bias for attention and feedforward.
+        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=True)
+        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=True)
+        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=True)
         self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

         self.layer_id = layer_id
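
For reference, Qwen2/2.5 checkpoints carry bias terms on the query/key/value projections (the output projection stays bias-free), which is why the hard-coded bias=False is flipped to bias=True; the TODO notes that this should eventually be configurable rather than hard-coded. A minimal sketch of what such a parametrization could look like, using a hypothetical qkv_bias flag that is not part of this diff:

import torch.nn as nn

class AttentionProjections(nn.Module):
    """Sketch only: gate the Q/K/V bias on a flag instead of hard-coding it.
    The qkv_bias argument is an assumed name, not a ModelArgs field."""

    def __init__(self, dim: int, n_heads: int, n_kv_heads: int, head_dim: int, qkv_bias: bool = True):
        super().__init__()
        # Qwen2-style: biases on Q/K/V, none on the output projection.
        self.wq = nn.Linear(dim, n_heads * head_dim, bias=qkv_bias)
        self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias)
        self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias)
        self.wo = nn.Linear(n_heads * head_dim, dim, bias=False)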

examples/models/llama/model.py
Lines changed: 2 additions & 1 deletion

@@ -150,6 +150,7 @@ def __init__(self, **kwargs):
             input_prune_map=input_prune_map,
             output_prune_map=output_prune_map,
             enable_dynamic_shape=self.enable_dynamic_shape,
+            use_hf_rope=True,
             **params,
         )

@@ -170,7 +171,7 @@ def __init__(self, **kwargs):

         # Within the device="meta" context, tensors that are created do not carry data.
         # They possess all other metadata a tensor carries such as size, stride, requires_grad.
-        with torch.device("meta"):
+        with torch.device("cpu"):
             self.model_ = Transformer(model_args)

         if "int8" in str(checkpoint_path):

examples/models/llama/rope.py
Lines changed: 49 additions & 4 deletions

@@ -114,6 +114,7 @@ def apply_rotary_emb_to_k(
     return xk_out.type_as(xk)


+# Wrap apply_rotary_emb in a module to enable it to be module swapped out.
 class RotaryEmbedding(torch.nn.Module):
     def __init__(self):
         super().__init__()

@@ -209,18 +210,66 @@ def hf_apply_rotary_emb_to_k(k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return k_embed


+# ======================= Qwen2 Implementation ========================
+
+
+def qwen_precompute_freqs_cis(dim: int, end: int, theta: float = 1_000_000.0):
+    """
+    Precompute frequency tensor for Qwen2-style RoPE.
+    """
+    freqs = 1.0 / (
+        theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
+    )
+    t = torch.arange(end, device=freqs.device)
+    freqs = torch.outer(t, freqs).float()
+    freqs_cos = torch.cos(freqs)
+    freqs_sin = torch.sin(freqs)
+    return freqs_cos, freqs_sin
+
+
+def qwen_apply_rotary_emb(
+    q: torch.Tensor, k: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Apply Qwen2-style RoPE to query and key tensors.
+    """
+    def rotate_half(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    # Reshape cos and sin for broadcasting
+    cos = freqs_cos.unsqueeze(1)  # [seq_len, 1, head_dim]
+    sin = freqs_sin.unsqueeze(1)  # [seq_len, 1, head_dim]
+
+    # Apply rotation
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
 class Rope(torch.nn.Module):
     def __init__(self, params: ModelArgs):
         super().__init__()
         self.params = params
+
+        # Choose the appropriate RoPE implementation
         if self.params.use_hf_rope:
             self.precompute_freqs_cis = hf_precompute_freqs_cis
+            self.apply_rotary_emb = hf_apply_rotary_emb
+        # elif self.params.use_qwen_rope:
+        #     self.precompute_freqs_cis = qwen_precompute_freqs_cis
+        #     self.apply_rotary_emb = qwen_apply_rotary_emb
         else:
             self.precompute_freqs_cis = partial(
                 precompute_freqs_cis,
                 use_scaled=self.params.use_scaled_rope,
                 scale_factor=self.params.rope_scale_factor,
             )
+            self.apply_rotary_emb = RotaryEmbedding()
+
+        # Precompute frequencies
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
             self.params.head_dim,
             (

@@ -232,10 +281,6 @@ def __init__(self, params: ModelArgs):
         )
         self.register_buffer("freqs_cos", freqs_cos, persistent=False)
         self.register_buffer("freqs_sin", freqs_sin, persistent=False)
-        if self.params.use_hf_rope:
-            self.apply_rotary_emb = hf_apply_rotary_emb
-        else:
-            self.apply_rotary_emb = RotaryEmbedding()

     def forward(
         self,
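
The qwen_* helpers above implement the "rotate-half" RoPE formulation used by Hugging Face's Qwen2 code, though they are only wired in as a commented-out branch since use_hf_rope=True already routes through the HF implementation. Below is a self-contained sketch of that formulation with illustrative shapes; note that the cos/sin tables here are duplicated across both halves of head_dim so they broadcast against q and k (the helper above returns only dim // 2 columns, so it would need the same duplication before use):

import torch

def rotate_half(x):
    # Swap and negate the two halves of the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

bsz, seq_len, n_heads, head_dim = 1, 8, 2, 64
q = torch.randn(bsz, seq_len, n_heads, head_dim)
k = torch.randn(bsz, seq_len, n_heads, head_dim)

theta = 1_000_000.0  # Qwen2.5 rope_theta
inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # [seq_len, head_dim // 2]
# Duplicate across both halves so cos/sin cover the full head_dim.
cos = torch.cat((angles, angles), dim=-1).cos()[:, None, :]  # [seq_len, 1, head_dim]
sin = torch.cat((angles, angles), dim=-1).sin()[:, None, :]

q_rot = q * cos + rotate_half(q) * sin
k_rot = k * cos + rotate_half(k) * sin
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 8, 2, 64]) twice
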
New file (Qwen2.5 weight conversion script)
Lines changed: 73 additions & 0 deletions

@@ -0,0 +1,73 @@
+from typing import Dict
+
+from torchtune.training import FullModelHFCheckpointer
+# from torchtune.models import convert_weights
+from torchtune.models.convert_weights import get_mapped_key
+import torch
+
+# Standard _FROM_META weight mapping from TorchTune + additional bias weight mappings.
+_QWEN_2_FROM_META = {
+    "tok_embeddings.weight": "tok_embeddings.weight",
+    "norm.weight": "norm.scale",
+    "output.weight": "output.weight",
+    "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
+    "layers.{}.attention.wk.bias": "layers.{}.attn.k_proj.bias",
+    "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
+    "layers.{}.attention.wq.bias": "layers.{}.attn.q_proj.bias",
+    "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
+    "layers.{}.attention.wv.bias": "layers.{}.attn.v_proj.bias",
+    "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight",
+    "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale",
+    "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale",
+    "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight",
+    "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight",
+    "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
+
+}
+
+def qwen_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """
+    Convert a state dict from torchtune's format to Meta's format. This function
+    doesn't handle any sharding or splitting of state dicts. It follows the
+    state_dict IN -> state_dict OUT pattern.
+
+    Args:
+        state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.
+
+    Returns:
+        Dict[str, torch.Tensor]: State dict in Meta's format.
+    """
+    converted_state_dict = {}
+    inverted_mapping_dict = {v: k for k, v in _QWEN_2_FROM_META.items()}
+
+    for key, value in state_dict.items():
+        new_key = get_mapped_key(key, inverted_mapping_dict)
+        converted_state_dict[new_key] = value
+
+    return converted_state_dict
+
+# TODO: no need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
+checkpointer = FullModelHFCheckpointer(
+    checkpoint_dir='/home/jackzhxng/.cache/huggingface/hub/models--Qwen--Qwen2.5-1.5B/snapshots/8faed761d45a263340a0528343f099c05c9a4323/',
+    checkpoint_files=['model.safetensors'],
+    output_dir='.',
+    model_type='QWEN2'
+)
+
+print("Loading checkpoint")
+sd = checkpointer.load_checkpoint()
+
+print("HF weights:")
+for weight in sd["model"].keys():
+    print(weight)
+print()
+
+# Convert from TorchTune to Meta (PyTorch native)
+sd = qwen_2_tune_to_meta(sd['model'])
+
+print("Meta weights:")
+for weight in sd.keys():
+    print(weight)
+
+print("Saving checkpoint")
+torch.save(sd, "/home/jackzhxng/models/qwen2_5-1_5b.pth")
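
The conversion loop relies on torchtune's get_mapped_key, which abstracts the layer index out of a key, looks the resulting template up in the mapping, and re-inserts the index. A tiny illustration with a hypothetical layer-7 key, using one entry of the inverted (torchtune -> Meta) mapping built above:

from torchtune.models.convert_weights import get_mapped_key

inverted = {"layers.{}.attn.q_proj.bias": "layers.{}.attention.wq.bias"}
print(get_mapped_key("layers.7.attn.q_proj.bias", inverted))
# Expected: layers.7.attention.wq.bias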
