@@ -118,76 +118,97 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.STARCODER: "starcoder",
 }
 
-MODEL_TENSOR_NAMES: dict[MODEL_ARCH, dict[MODEL_TENSOR, str]] = {
-    MODEL_ARCH.LLAMA: {
-        MODEL_TENSOR.TOKEN_EMBD: "token_embd",
-        MODEL_TENSOR.OUTPUT_NORM: "output_norm",
-        MODEL_TENSOR.OUTPUT: "output",
-        MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
-        MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
-        MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
-        MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
-        MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
-        MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
-        MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
-        MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
-        MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
-        MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
-        MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
-    },
-    MODEL_ARCH.GPTNEOX: {
-        MODEL_TENSOR.TOKEN_EMBD: "token_embd",
-        MODEL_TENSOR.OUTPUT_NORM: "output_norm",
-        MODEL_TENSOR.OUTPUT: "output",
-        MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
-        MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
-        MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
-        MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
-        MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
-        MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
-    },
-    MODEL_ARCH.FALCON: {
-        MODEL_TENSOR.TOKEN_EMBD: "token_embd",
-        MODEL_TENSOR.OUTPUT_NORM: "output_norm",
-        MODEL_TENSOR.OUTPUT: "output",
-        MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
-        MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2",
-        MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
-        MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
-        MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
-        MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
-    },
-    MODEL_ARCH.BAICHUAN: {
-        MODEL_TENSOR.TOKEN_EMBD: "token_embd",
-        MODEL_TENSOR.OUTPUT_NORM: "output_norm",
-        MODEL_TENSOR.OUTPUT: "output",
-        MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
-        MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
-        MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
-        MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
-        MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
-        MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
-        MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
-        MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
-        MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
-        MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
-        MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
-    },
-    MODEL_ARCH.STARCODER: {
-        MODEL_TENSOR.TOKEN_EMBD: "token_embd",
-        MODEL_TENSOR.POS_EMBD: "position_embd",
-        MODEL_TENSOR.OUTPUT_NORM: "output_norm",
-        MODEL_TENSOR.OUTPUT: "output",
-        MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
-        MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
-        MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
-        MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
-        MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
-        MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
-    },
-    MODEL_ARCH.GPT2: {
+TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
+    MODEL_TENSOR.TOKEN_EMBD: "token_embd",
+    MODEL_TENSOR.POS_EMBD: "position_embd",
+    MODEL_TENSOR.OUTPUT_NORM: "output_norm",
+    MODEL_TENSOR.OUTPUT: "output",
+    MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
+
+    MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
+    MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2",
+    MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
+    MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
+    MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
+    MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
+    MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
+    MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
+    MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
+    MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
+    MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
+    MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
+}
+
+MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
+    MODEL_ARCH.LLAMA: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.GPTNEOX: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.FALCON: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_NORM_2,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.BAICHUAN: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.STARCODER: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.POS_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.GPT2: [
         # TODO
-    },
+    ],
     # TODO
 }
 
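For context, the per-architecture name table that the removed MODEL_TENSOR_NAMES provided can still be derived by joining the two new tables. A minimal sketch of that composition (not part of this commit; the helper name tensor_names_for_arch is hypothetical, the table names come from the hunk above):

    def tensor_names_for_arch(arch: MODEL_ARCH) -> dict[MODEL_TENSOR, str]:
        # Keep only the tensors this architecture declares, paired with their shared names.
        return {tensor: TENSOR_NAMES[tensor] for tensor in MODEL_TENSORS[arch]}

    # e.g. tensor_names_for_arch(MODEL_ARCH.FALCON)[MODEL_TENSOR.ATTN_QKV] -> "blk.{bid}.attn_qkv"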
@@ -338,28 +359,24 @@ class TensorNameMap:
 
     mapping: dict[str, tuple[MODEL_TENSOR, str]]
 
-    tensor_names: dict[MODEL_TENSOR, str]
-
     def __init__(self, arch: MODEL_ARCH, n_blocks: int):
-        mapping = self.mapping = {}
-        tensor_names = self.tensor_names = MODEL_TENSOR_NAMES[arch]
+        self.mapping = {}
         for tensor, keys in self.mappings_cfg.items():
-            tensor_name = tensor_names.get(tensor)
-            if tensor_name is None:
+            if tensor not in MODEL_TENSORS[arch]:
                 continue
-            mapping[tensor_name] = (tensor, tensor_name)
+            tensor_name = TENSOR_NAMES[tensor]
+            self.mapping[tensor_name] = (tensor, tensor_name)
             for key in keys:
-                mapping[key] = (tensor, tensor_name)
+                self.mapping[key] = (tensor, tensor_name)
         for bid in range(n_blocks):
             for tensor, keys in self.block_mappings_cfg.items():
-                tensor_name = tensor_names.get(tensor)
-                if tensor_name is None:
+                if tensor not in MODEL_TENSORS[arch]:
                     continue
-                tensor_name = tensor_name.format(bid = bid)
-                mapping[tensor_name] = (tensor, tensor_name)
+                tensor_name = TENSOR_NAMES[tensor].format(bid = bid)
+                self.mapping[tensor_name] = (tensor, tensor_name)
                 for key in keys:
                     key = key.format(bid = bid)
-                    mapping[key] = (tensor, tensor_name)
+                    self.mapping[key] = (tensor, tensor_name)
 
     def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
         result = self.mapping.get(key)
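A rough usage sketch of the rewritten constructor (illustrative only; the block count, the HF-style source key, and the exact contents of mappings_cfg/block_mappings_cfg are assumptions, not shown in this hunk):

    # Build the lookup for a 32-block LLaMA model, then resolve an upstream tensor key.
    names = TensorNameMap(MODEL_ARCH.LLAMA, n_blocks = 32)
    hit = names.get_type_and_name("model.layers.0.self_attn.q_proj.weight", try_suffixes = (".weight", ".bias"))
    # If that key is among the configured mappings, hit would be something like
    # (MODEL_TENSOR.ATTN_Q, "blk.0.attn_q.weight"); otherwise None.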
@@ -800,22 +817,25 @@ class SpecialVocab:
     special_token_types: tuple[str, ...] = ('bos', 'eos', 'unk', 'sep', 'pad')
     special_token_ids: dict[str, int] = {}
 
-    def __init__(self, path: Path, load_merges: bool = False, special_token_types: tuple[str, ...] | None = None):
+    def __init__(
+        self, path: str | os.PathLike[str], load_merges: bool = False,
+        special_token_types: tuple[str, ...] | None = None,
+    ):
         self.special_token_ids = {}
         self.load_merges = load_merges
         if special_token_types is not None:
             self.special_token_types = special_token_types
-        self.load(path)
+        self._load(Path(path))
 
-    def load(self, path: Path):
-        if not self.try_load_from_tokenizer_json(path):
-            self.try_load_from_config_json(path)
+    def _load(self, path: Path) -> None:
+        if not self._try_load_from_tokenizer_json(path):
+            self._try_load_from_config_json(path)
 
-    def try_load_from_tokenizer_json(self, path: Path) -> bool:
+    def _try_load_from_tokenizer_json(self, path: Path) -> bool:
         tokenizer_file = path / 'tokenizer.json'
         if not tokenizer_file.is_file():
             return False
-        with open(tokenizer_file, 'r', encoding = 'utf-8') as f:
+        with open(tokenizer_file, encoding = 'utf-8') as f:
             tokenizer = json.load(f)
         if self.load_merges:
             merges = tokenizer.get('model', {}).get('merges')
@@ -825,7 +845,7 @@ def try_load_from_tokenizer_json(self, path: Path) -> bool:
         added_tokens = tokenizer.get('added_tokens')
         if added_tokens is None or not tokenizer_config_file.is_file():
             return True
-        with open(tokenizer_config_file, 'r', encoding = 'utf-8') as f:
+        with open(tokenizer_config_file, encoding = 'utf-8') as f:
             tokenizer_config = json.load(f)
         for typ in self.special_token_types:
             entry = tokenizer_config.get(f'{typ}_token')
@@ -844,19 +864,19 @@ def try_load_from_tokenizer_json(self, path: Path) -> bool:
                 break
         return True
 
-    def try_load_from_config_json(self, path: Path) -> bool:
+    def _try_load_from_config_json(self, path: Path) -> bool:
         config_file = path / 'config.json'
         if not config_file.is_file():
             return False
-        with open(config_file, 'r', encoding = 'utf-8') as f:
+        with open(config_file, encoding = 'utf-8') as f:
             config = json.load(f)
         for typ in self.special_token_types:
             maybe_token_id = config.get(f'{typ}_token_id')
             if isinstance(maybe_token_id, int) and maybe_token_id >= 0:
                 self.special_token_ids[typ] = maybe_token_id
         return True
 
-    def add_to_gguf(self, gw: GGUFWriter):
+    def add_to_gguf(self, gw: GGUFWriter) -> None:
         if len(self.merges) > 0:
             print(f'gguf: Adding {len(self.merges)} merge(s).')
             gw.add_token_merges(self.merges)
@@ -868,8 +888,8 @@ def add_to_gguf(self, gw: GGUFWriter):
             print(f'gguf: Setting special token type {typ} to {tokid}')
             handler(tokid)
 
-    def __repr__(self):
-        return f'<SpecialVocab with {len(self.merges)} merges and special tokens {self.special_token_ids if self.special_token_ids else "unset"}>'
+    def __repr__(self) -> str:
+        return f'<SpecialVocab with {len(self.merges)} merges and special tokens {self.special_token_ids or "unset"}>'
 
 
 # Example usage:
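A hedged usage sketch of the updated SpecialVocab API (the model directory path and the prepared GGUFWriter instance are assumptions, not taken from this commit):

    # The constructor now accepts a plain str or os.PathLike, not only pathlib.Path.
    special_vocab = SpecialVocab("models/my-model", load_merges = True)
    special_vocab.add_to_gguf(gguf_writer)   # gguf_writer: an already configured GGUFWriter
    print(special_vocab)                     # __repr__ reports merge count and special token ids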