@@ -85,10 +85,12 @@ class MODEL_ARCH(IntEnum):
     GPTNEOX        : int = auto()
     MPT            : int = auto()
     STARCODER      : int = auto()
+    BERT           : int = auto()
 
 
 class MODEL_TENSOR(IntEnum):
     TOKEN_EMBD    : int = auto()
+    TOKEN_TYPES   : int = auto()
     POS_EMBD      : int = auto()
     OUTPUT        : int = auto()
     OUTPUT_NORM   : int = auto()
@@ -116,10 +118,12 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.GPTNEOX:        "gptneox",
     MODEL_ARCH.MPT:            "mpt",
     MODEL_ARCH.STARCODER:      "starcoder",
+    MODEL_ARCH.BERT:           "bert",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.TOKEN_EMBD:    "token_embd",
+    MODEL_TENSOR.TOKEN_TYPES:   "token_types",
     MODEL_TENSOR.POS_EMBD:      "position_embd",
     MODEL_TENSOR.OUTPUT_NORM:   "output_norm",
     MODEL_TENSOR.OUTPUT:        "output",
@@ -206,6 +210,43 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.BERT: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_TYPES,
+        MODEL_TENSOR.POS_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.MPT: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.GPTJ: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.GPT2: [
         # TODO
     ],
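For context, a rough sketch of how a per-architecture list such as the new MODEL_ARCH.BERT entry can be combined with TENSOR_NAMES to enumerate the GGUF tensor names a converted checkpoint is expected to supply. The enclosing dict name (MODEL_TENSORS), the "blk.{bid}." templates assumed for block-level tensors, and the helper itself are illustrative assumptions, not part of this change.

# Hypothetical helper, for illustration only; assumes the dict edited above
# is named MODEL_TENSORS and that block-level TENSOR_NAMES entries use
# "blk.{bid}.<suffix>" templates, as elsewhere in this module.
def expected_gguf_tensor_names(arch: MODEL_ARCH, n_blocks: int) -> list[str]:
    names: list[str] = []
    for tensor in MODEL_TENSORS[arch]:
        template = TENSOR_NAMES[tensor]
        if "{bid}" in template:
            # One copy per transformer block: blk.0.attn_q, blk.1.attn_q, ...
            names.extend(template.format(bid=bid) for bid in range(n_blocks))
        else:
            names.append(template)
    return names

# e.g. expected_gguf_tensor_names(MODEL_ARCH.BERT, n_blocks=12)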
@@ -229,31 +270,40 @@ class TensorNameMap:
     mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
         # Token embeddings
         MODEL_TENSOR.TOKEN_EMBD: (
-            "gpt_neox.embed_in",           # gptneox
-            "transformer.wte",             # gpt2 mpt
-            "transformer.word_embeddings", # falcon
-            "model.embed_tokens",          # llama-hf
-            "tok_embeddings",              # llama-pth
+            "gpt_neox.embed_in",            # gptneox
+            "transformer.wte",              # gpt2 gpt-j mpt
+            "transformer.word_embeddings",  # falcon
+            "model.embed_tokens",           # llama-hf
+            "tok_embeddings",               # llama-pth
+            "embeddings.word_embeddings",   # bert
+        ),
+
+        # Token type embeddings
+        MODEL_TENSOR.TOKEN_TYPES: (
+            "embeddings.token_type_embeddings",  # bert
         ),
 
         # Position embeddings
         MODEL_TENSOR.POS_EMBD: (
-            "transformer.wpe", # gpt2
+            "transformer.wpe",                 # gpt2
+            "embeddings.position_embeddings",  # bert
         ),
 
         # Output
         MODEL_TENSOR.OUTPUT: (
-            "embed_out", # gptneox
-            "lm_head",   # gpt2 mpt falcon llama-hf baichuan
-            "output",    # llama-pth
+            "embed_out",   # gptneox
+            "lm_head",     # gpt2 gpt-j mpt falcon llama-hf baichuan
+            "output",      # llama-pth
         ),
 
         # Output norm
         MODEL_TENSOR.OUTPUT_NORM: (
-            "gpt_neox.final_layer_norm", # gptneox
-            "transformer.ln_f",          # gpt2 falcon
-            "model.norm",                # llama-hf baichuan
-            "norm",                      # llama-pth
+            "gpt_neox.final_layer_norm",  # gptneox
+            "transformer.ln_f",           # gpt2 gpt-j falcon
+            "model.norm",                 # llama-hf baichuan
+            "norm",                       # llama-pth
+            "embeddings.LayerNorm",       # bert
+            "transformer.norm_f",         # mpt
         ),
 
         # Rope frequencies
@@ -265,13 +315,14 @@ class TensorNameMap:
     block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
         # Attention norm
         MODEL_TENSOR.ATTN_NORM: (
-            "gpt_neox.layers.{bid}.input_layernorm", # gptneox
-            "transformer.h.{bid}.ln_1",              # gpt2
-            "transformer.blocks.{bid}.norm_1",       # mpt
-            "transformer.h.{bid}.input_layernorm",   # falcon7b
-            "transformer.h.{bid}.ln_mlp",            # falcon40b
-            "model.layers.{bid}.input_layernorm",    # llama-hf
-            "layers.{bid}.attention_norm",           # llama-pth
+            "gpt_neox.layers.{bid}.input_layernorm",           # gptneox
+            "transformer.h.{bid}.ln_1",                        # gpt2 gpt-j
+            "transformer.blocks.{bid}.norm_1",                 # mpt
+            "transformer.h.{bid}.input_layernorm",             # falcon7b
+            "transformer.h.{bid}.ln_mlp",                      # falcon40b
+            "model.layers.{bid}.input_layernorm",              # llama-hf
+            "layers.{bid}.attention_norm",                     # llama-pth
+            "encoder.layer.{bid}.attention.output.LayerNorm",  # bert
         ),
 
         # Attention norm 2
@@ -281,38 +332,46 @@ class TensorNameMap:
 
         # Attention query-key-value
         MODEL_TENSOR.ATTN_QKV: (
-            "gpt_neox.layers.{bid}.attention.query_key_value",    # gptneox
-            "transformer.h.{bid}.attn.c_attn",                    # gpt2
-            "transformer.blocks.{bid}.attn.Wqkv",                 # mpt
-            "transformer.h.{bid}.self_attention.query_key_value", # falcon
+            "gpt_neox.layers.{bid}.attention.query_key_value",      # gptneox
+            "transformer.h.{bid}.attn.c_attn",                      # gpt2
+            "transformer.blocks.{bid}.attn.Wqkv",                   # mpt
+            "transformer.h.{bid}.self_attention.query_key_value",   # falcon
         ),
 
         # Attention query
         MODEL_TENSOR.ATTN_Q: (
-            "model.layers.{bid}.self_attn.q_proj", # llama-hf
-            "layers.{bid}.attention.wq",           # llama-pth
+            "model.layers.{bid}.self_attn.q_proj",       # llama-hf
+            "layers.{bid}.attention.wq",                 # llama-pth
+            "encoder.layer.{bid}.attention.self.query",  # bert
+            "transformer.h.{bid}.attn.q_proj",           # gpt-j
         ),
 
         # Attention key
         MODEL_TENSOR.ATTN_K: (
-            "model.layers.{bid}.self_attn.k_proj", # llama-hf
-            "layers.{bid}.attention.wk",           # llama-pth
+            "model.layers.{bid}.self_attn.k_proj",     # llama-hf
+            "layers.{bid}.attention.wk",               # llama-pth
+            "encoder.layer.{bid}.attention.self.key",  # bert
+            "transformer.h.{bid}.attn.k_proj",         # gpt-j
        ),
 
         # Attention value
         MODEL_TENSOR.ATTN_V: (
-            "model.layers.{bid}.self_attn.v_proj", # llama-hf
-            "layers.{bid}.attention.wv",           # llama-pth
+            "model.layers.{bid}.self_attn.v_proj",       # llama-hf
+            "layers.{bid}.attention.wv",                 # llama-pth
+            "encoder.layer.{bid}.attention.self.value",  # bert
+            "transformer.h.{bid}.attn.v_proj",           # gpt-j
         ),
 
         # Attention output
         MODEL_TENSOR.ATTN_OUT: (
-            "gpt_neox.layers.{bid}.attention.dense",    # gptneox
-            "transformer.h.{bid}.attn.c_proj",          # gpt2
-            "transformer.blocks.{bid}.attn.out_proj",   # mpt
-            "transformer.h.{bid}.self_attention.dense", # falcon
-            "model.layers.{bid}.self_attn.o_proj",      # llama-hf
-            "layers.{bid}.attention.wo",                # llama-pth
+            "gpt_neox.layers.{bid}.attention.dense",       # gptneox
+            "transformer.h.{bid}.attn.c_proj",             # gpt2
+            "transformer.blocks.{bid}.attn.out_proj",      # mpt
+            "transformer.h.{bid}.self_attention.dense",    # falcon
+            "model.layers.{bid}.self_attn.o_proj",         # llama-hf
+            "layers.{bid}.attention.wo",                   # llama-pth
+            "encoder.layer.{bid}.attention.output.dense",  # bert
+            "transformer.h.{bid}.attn.out_proj",           # gpt-j
         ),
 
         # Rotary embeddings
@@ -323,21 +382,24 @@ class TensorNameMap:
 
         # Feed-forward norm
         MODEL_TENSOR.FFN_NORM: (
-            "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
-            "transformer.h.{bid}.ln_2",                       # gpt2
-            "transformer.blocks.{bid}.norm_2",                # mpt
-            "model.layers.{bid}.post_attention_layernorm",    # llama-hf
-            "layers.{bid}.ffn_norm",                          # llama-pth
+            "gpt_neox.layers.{bid}.post_attention_layernorm",  # gptneox
+            "transformer.h.{bid}.ln_2",                        # gpt2
+            "transformer.blocks.{bid}.norm_2",                 # mpt
+            "model.layers.{bid}.post_attention_layernorm",     # llama-hf
+            "layers.{bid}.ffn_norm",                           # llama-pth
+            "encoder.layer.{bid}.output.LayerNorm",            # bert
         ),
 
         # Feed-forward up
         MODEL_TENSOR.FFN_UP: (
-            "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
-            "transformer.h.{bid}.mlp.c_fc",            # gpt2
-            "transformer.blocks.{bid}.ffn.up_proj",    # mpt
-            "transformer.h.{bid}.mlp.dense_h_to_4h",   # falcon
-            "model.layers.{bid}.mlp.up_proj",          # llama-hf
-            "layers.{bid}.feed_forward.w3",            # llama-pth
+            "gpt_neox.layers.{bid}.mlp.dense_h_to_4h",  # gptneox
+            "transformer.h.{bid}.mlp.c_fc",             # gpt2
+            "transformer.blocks.{bid}.ffn.up_proj",     # mpt
+            "transformer.h.{bid}.mlp.dense_h_to_4h",    # falcon
+            "model.layers.{bid}.mlp.up_proj",           # llama-hf
+            "layers.{bid}.feed_forward.w3",             # llama-pth
+            "encoder.layer.{bid}.intermediate.dense",   # bert
+            "transformer.h.{bid}.mlp.fc_in",            # gpt-j
         ),
 
         # Feed-forward gate
@@ -348,12 +410,14 @@ class TensorNameMap:
 
         # Feed-forward down
         MODEL_TENSOR.FFN_DOWN: (
-            "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
-            "transformer.h.{bid}.mlp.c_proj",          # gpt2
-            "transformer.blocks.{bid}.ffn.down_proj",  # mpt
-            "transformer.h.{bid}.mlp.dense_4h_to_h",   # falcon
-            "model.layers.{bid}.mlp.down_proj",        # llama-hf
-            "layers.{bid}.feed_forward.w2",            # llama-pth
+            "gpt_neox.layers.{bid}.mlp.dense_4h_to_h",  # gptneox
+            "transformer.h.{bid}.mlp.c_proj",           # gpt2
+            "transformer.blocks.{bid}.ffn.down_proj",   # mpt
+            "transformer.h.{bid}.mlp.dense_4h_to_h",    # falcon
+            "model.layers.{bid}.mlp.down_proj",         # llama-hf
+            "layers.{bid}.feed_forward.w2",             # llama-pth
+            "encoder.layer.{bid}.output.dense",         # bert
+            "transformer.h.{bid}.mlp.fc_out",           # gpt-j
         ),
     }
 
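The payoff of the bert and gpt-j additions above is that checkpoints using those naming schemes now resolve to the same GGUF tensor slots as the other architectures. Below is a minimal, self-contained sketch of the template matching this enables, assuming a GGUF-side "blk.{bid}.attn_q" naming pattern; the helper is hypothetical, and only the source-name templates are taken from block_mappings_cfg.

# Illustrative sketch, not the library's actual API: resolve an upstream
# attention-query tensor name to an assumed GGUF-style name by trying each
# known template with every block index substituted.
ATTN_Q_SOURCES = (
    "model.layers.{bid}.self_attn.q_proj",       # llama-hf
    "layers.{bid}.attention.wq",                 # llama-pth
    "encoder.layer.{bid}.attention.self.query",  # bert
    "transformer.h.{bid}.attn.q_proj",           # gpt-j
)

def map_attn_q(src_name: str, n_blocks: int = 64):
    """Return an assumed GGUF-style name for an attention-query tensor, or None."""
    for bid in range(n_blocks):
        for template in ATTN_Q_SOURCES:
            if src_name == template.format(bid=bid):
                return f"blk.{bid}.attn_q"  # assumed GGUF naming pattern
    return None

# A BERT query projection and a GPT-J query projection land in the same slot:
#   map_attn_q("encoder.layer.3.attention.self.query")  -> "blk.3.attn_q"
#   map_attn_q("transformer.h.3.attn.q_proj")           -> "blk.3.attn_q"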