@@ -23,6 +23,7 @@ class TensorNameMap:
2323 "model.embedding" , # mamba-qbert
2424 "backbone.embedding" , # mamba
2525 "backbone.embeddings" , # mamba-hf
26+ "transformer.in_out_embed" , # Grok
2627 ),
2728
2829 # Token type embeddings
@@ -66,6 +67,7 @@ class TensorNameMap:
6667 "lm_head.ln" , # phi2
6768 "model.norm_f" , # mamba-qbert
6869 "backbone.norm_f" , # mamba
70+ "transformer.rms_norm" , # Grok
6971 ),
7072
7173 # Rope frequencies
@@ -93,6 +95,7 @@ class TensorNameMap:
9395 "model.layers.{bid}.attention_norm" , # internlm2
9496 "model.layers.{bid}.norm" , # mamba-qbert
9597 "backbone.layers.{bid}.norm" , # mamba
98+ "transformer.decoder_layer.{bid}.rms_norm" , # Grok
9699 ),
97100
98101 # Attention norm 2
@@ -116,32 +119,35 @@ class TensorNameMap:

         # Attention query
         MODEL_TENSOR.ATTN_Q: (
-            "model.layers.{bid}.self_attn.q_proj",       # llama-hf
-            "layers.{bid}.attention.wq",                 # llama-pth
-            "encoder.layer.{bid}.attention.self.query",  # bert
-            "transformer.h.{bid}.attn.q_proj",           # gpt-j
-            "model.layers.layers.{bid}.self_attn.q_proj", # plamo
-            "model.layers.{bid}.attention.wq"            # internlm2
+            "model.layers.{bid}.self_attn.q_proj",                       # llama-hf
+            "layers.{bid}.attention.wq",                                 # llama-pth
+            "encoder.layer.{bid}.attention.self.query",                  # bert
+            "transformer.h.{bid}.attn.q_proj",                           # gpt-j
+            "model.layers.layers.{bid}.self_attn.q_proj",                # plamo
+            "model.layers.{bid}.attention.wq",                           # internlm2
+            "transformer.decoder_layer.{bid}.multi_head_attention.query" # Grok
         ),

         # Attention key
         MODEL_TENSOR.ATTN_K: (
-            "model.layers.{bid}.self_attn.k_proj",       # llama-hf
-            "layers.{bid}.attention.wk",                 # llama-pth
-            "encoder.layer.{bid}.attention.self.key",    # bert
-            "transformer.h.{bid}.attn.k_proj",           # gpt-j
-            "model.layers.layers.{bid}.self_attn.k_proj", # plamo
-            "model.layers.{bid}.attention.wk"            # internlm2
+            "model.layers.{bid}.self_attn.k_proj",                     # llama-hf
+            "layers.{bid}.attention.wk",                               # llama-pth
+            "encoder.layer.{bid}.attention.self.key",                  # bert
+            "transformer.h.{bid}.attn.k_proj",                         # gpt-j
+            "model.layers.layers.{bid}.self_attn.k_proj",              # plamo
+            "model.layers.{bid}.attention.wk",                         # internlm2
+            "transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok
         ),

         # Attention value
         MODEL_TENSOR.ATTN_V: (
-            "model.layers.{bid}.self_attn.v_proj",       # llama-hf
-            "layers.{bid}.attention.wv",                 # llama-pth
-            "encoder.layer.{bid}.attention.self.value",  # bert
-            "transformer.h.{bid}.attn.v_proj",           # gpt-j
-            "model.layers.layers.{bid}.self_attn.v_proj", # plamo
-            "model.layers.{bid}.attention.wv"            # internlm2
+            "model.layers.{bid}.self_attn.v_proj",                       # llama-hf
+            "layers.{bid}.attention.wv",                                 # llama-pth
+            "encoder.layer.{bid}.attention.self.value",                  # bert
+            "transformer.h.{bid}.attn.v_proj",                           # gpt-j
+            "model.layers.layers.{bid}.self_attn.v_proj",                # plamo
+            "model.layers.{bid}.attention.wv",                           # internlm2
+            "transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok
         ),

         # Attention output
@@ -162,12 +168,14 @@ class TensorNameMap:
162168 "model.layers.layers.{bid}.self_attn.o_proj" , # plamo
163169 "model.layers.{bid}.attention.wo" , # internlm2
164170 "encoder.layers.{bid}.attn.out_proj" , # nomic-bert
171+ "transformer.decoder_layer.{bid}.multi_head_attention.linear" # Grok
165172 ),
166173
167174 # Attention output norm
168175 MODEL_TENSOR .ATTN_OUT_NORM : (
169176 "encoder.layer.{bid}.attention.output.LayerNorm" , # bert
170177 "encoder.layers.{bid}.norm1" , # nomic-bert
178+ "transformer.decoder_layer.{bid}.rms_norm_1" , # Grok
171179 ),
172180
173181 # Rotary embeddings
@@ -190,11 +198,13 @@ class TensorNameMap:
190198 "model.layers.{bid}.ln2" , # yi
191199 "h.{bid}.ln_2" , # gpt2
192200 "model.layers.{bid}.ffn_norm" , # internlm2
201+ "transformer.decoder_layer.{bid}.rms_norm_2" , # Grok
193202 ),
194203
195204 MODEL_TENSOR .FFN_GATE_INP : (
196205 "layers.{bid}.feed_forward.gate" , # mixtral
197206 "model.layers.{bid}.block_sparse_moe.gate" , # mixtral
207+ "transformer.decoder_layer.{bid}.router" # Grok
198208 ),
199209
200210 # Feed-forward up
@@ -223,6 +233,7 @@ class TensorNameMap:
         MODEL_TENSOR.FFN_UP_EXP: (
             "layers.{bid}.feed_forward.experts.{xid}.w3",           # mixtral
             "model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral
+            "transformer.decoder_layer.{bid}.moe.{xid}.linear_v",   # Grok
         ),

         # AWQ-activation gate
@@ -243,6 +254,7 @@ class TensorNameMap:
         MODEL_TENSOR.FFN_GATE_EXP: (
             "layers.{bid}.feed_forward.experts.{xid}.w1",           # mixtral
             "model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", # mixtral
+            "transformer.decoder_layer.{bid}.moe.{xid}.linear"      # Grok
         ),

         # Feed-forward down
@@ -270,6 +282,8 @@ class TensorNameMap:
         MODEL_TENSOR.FFN_DOWN_EXP: (
             "layers.{bid}.feed_forward.experts.{xid}.w2",           # mixtral
             "model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", # mixtral
+            "transformer.decoder_layer.{bid}.moe.{xid}.linear_1",   # Grok
+
         ),

         MODEL_TENSOR.ATTN_Q_NORM: (
@@ -287,8 +301,9 @@ class TensorNameMap:
         ),

         MODEL_TENSOR.LAYER_OUT_NORM: (
-            "encoder.layer.{bid}.output.LayerNorm", # bert
-            "encoder.layers.{bid}.norm2",           # nomic-bert
+            "encoder.layer.{bid}.output.LayerNorm",       # bert
+            "encoder.layers.{bid}.norm2",                 # nomic-bert
+            "transformer.decoder_layer.{bid}.rms_norm_3", # Grok
         ),

         MODEL_TENSOR.SSM_IN: (
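# Illustrative sketch (not part of the patch): how block-level entries like the
# Grok names added above are consumed. The table and helper below are hypothetical
# stand-ins for gguf-py's TensorNameMap; the real class expands each "{bid}"
# template for every layer index and reverse-maps checkpoint tensor names onto
# internal tensor types.

GROK_BLOCK_MAPPINGS = {
    "ATTN_NORM":    "transformer.decoder_layer.{bid}.rms_norm",
    "ATTN_Q":       "transformer.decoder_layer.{bid}.multi_head_attention.query",
    "ATTN_OUT":     "transformer.decoder_layer.{bid}.multi_head_attention.linear",
    "FFN_GATE_INP": "transformer.decoder_layer.{bid}.router",
}

def build_reverse_map(n_blocks: int) -> dict:
    """Expand the {bid} placeholder for every layer and key by the concrete name."""
    reverse = {}
    for tensor_type, template in GROK_BLOCK_MAPPINGS.items():
        for bid in range(n_blocks):
            reverse[template.format(bid=bid)] = (tensor_type, bid)
    return reverse

# A concrete Grok checkpoint name then resolves to (internal tensor type, layer index).
reverse = build_reverse_map(n_blocks=64)  # layer count is illustrative, not taken from the patch
print(reverse["transformer.decoder_layer.3.multi_head_attention.query"])  # ('ATTN_Q', 3)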