@@ -213,10 +213,10 @@ def __init__(self,
213213 cache_config : Optional [CacheConfig ] = None ,
214214 quant_config : Optional [QuantizationConfig ] = None ):
215215 super ().__init__ ()
216- self .vocab_size = config .vocab_size
216+ self .vocab_size = config .text_config . vocab_size
217217
218- self .embed_tokens = VocabParallelEmbedding (config . vocab_size ,
219- config .hidden_size )
218+ self .embed_tokens = VocabParallelEmbedding (
219+ config . text_config . vocab_size , config .hidden_size )
220220 self .layers = nn .ModuleList ([
221221 PersimmonDecoderLayer (config ,
222222 cache_config = cache_config ,
@@ -257,14 +257,14 @@ def __init__(self,
257257 quant_config : Optional [QuantizationConfig ] = None ):
258258 super ().__init__ ()
259259 self .config = config
260- self .vocab_size = config .vocab_size
260+ self .vocab_size = config .text_config . vocab_size
261261 self .model = PersimmonModel (config ,
262262 cache_config = cache_config ,
263263 quant_config = quant_config )
264- self .lm_head = ParallelLMHead (config .vocab_size ,
264+ self .lm_head = ParallelLMHead (config .text_config . vocab_size ,
265265 config .hidden_size ,
266266 bias = False )
267- self .logits_processor = LogitsProcessor (config .vocab_size )
267+ self .logits_processor = LogitsProcessor (config .text_config . vocab_size )
268268 self .sampler = Sampler ()
269269
270270 def forward (
0 commit comments