@@ -22,7 +22,6 @@ namespace chatllm::bailing::moe
     class ChatHistoryEncoder : public BaseHistoryEncoder
     {
     public:
-        void append_sys_prompt(std::vector<int> &ids) const override;
         void append_ai(int round_idx, const std::string &ai, std::vector<int> &ids) const override;
         void append_user(int round_idx, const std::string &user, std::vector<int> &ids) const override;
         void append_ai_opening(int round_idx, std::vector<int> &ids) const override;
@@ -33,8 +32,8 @@ namespace chatllm::bailing::moe
     class Tokenizer : public BaseTokenizer
     {
     public:
-        Tokenizer(const Config &config)
-            : BaseTokenizer(config, &_chat_encoder)
+        Tokenizer(const Config &config, BaseHistoryEncoder *chat_encoder = &_chat_encoder)
+            : BaseTokenizer(config, chat_encoder)
         {
             sys_prompt = "You are Ling, an assistant created by inclusionAI";
         }
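
The constructor change above turns the history encoder into an injection point: a derived tokenizer can keep all of moe::Tokenizer's vocabulary handling while routing prompt construction through its own encoder (the llada tokenizer later in this commit does exactly that). A minimal sketch of the pattern, with all names invented for illustration:

    #include <string>
    #include <vector>

    // hypothetical stand-ins for BaseHistoryEncoder / BaseTokenizer
    struct HistoryEncoder {
        virtual ~HistoryEncoder() = default;
        virtual void append_user(const std::string &user, std::vector<int> &ids) const = 0;
    };

    struct MoeEncoder : HistoryEncoder {
        void append_user(const std::string &, std::vector<int> &ids) const override { ids.push_back(1); }
    };
    static MoeEncoder moe_encoder;

    class MoeTokenizer {
    public:
        // the default argument supplies the stock encoder; subclasses may pass their own
        explicit MoeTokenizer(HistoryEncoder *enc = &moe_encoder) : encoder(enc) {}
        HistoryEncoder *encoder;
    };

    struct LladaEncoder : HistoryEncoder {
        void append_user(const std::string &, std::vector<int> &ids) const override { ids.push_back(2); }
    };
    static LladaEncoder llada_encoder;

    class LladaTokenizer : public MoeTokenizer {
    public:
        LladaTokenizer() : MoeTokenizer(&llada_encoder) {}  // swap the encoder, keep the rest
    };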
@@ -56,12 +55,9 @@ namespace chatllm::bailing::moe
             size_t size = tp->Load(buffer, n_vocab);
 
             role_open_token_id = tp->PieceToId("<role>");
-            eos_token_id = tp->PieceToId("<|role_end|>");
-            mask_token_id = tp->PieceToId("<|mask|>");
 
-            // LlaDA might generate lots of PAD
-            if (mask_token_id >= 0)
-                terminate_ids.insert(pad_token_id);
+            if (role_open_token_id >= 0)
+                terminate_ids.insert(role_open_token_id);
 
             int t = tp->PieceToId("<think>");
             if (t >= 0)
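
Registering <role> itself as a terminator means moe generation stops as soon as the model opens a new role block; the <|role_end|> and <|mask|> handling it replaces moves into the llada tokenizer below. A hedged sketch of how such a terminator set is typically consulted during sampling (the function and its signature are illustrative, not chatllm.cpp's API):

    #include <set>

    // stop on the explicit EOS token or on any registered terminator such as <role>
    bool should_stop(int next_token, int eos_token_id, const std::set<int> &terminate_ids)
    {
        return next_token == eos_token_id || terminate_ids.count(next_token) > 0;
    }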
@@ -78,32 +74,18 @@ namespace chatllm::bailing::moe
 
     protected:
         int role_open_token_id;
-    public:
-        int mask_token_id;
     };
 
-    void ChatHistoryEncoder::append_sys_prompt(std::vector<int> &ids) const
-    {
-        if (tokenizer->get_system_prompt().size() > 0)
-        {
-            tokenizer->encode("<role>SYSTEM</role>", ids);
-            tokenizer->encode(tokenizer->get_system_prompt(), ids);
-            ids.push_back(tokenizer->eos_token_id);
-        }
-    }
-
     void ChatHistoryEncoder::append_ai(int round_idx, const std::string &ai, std::vector<int> &ids) const
     {
         append_ai_opening(round_idx, ids);
         tokenizer->encode(ai, ids);
-        ids.push_back(tokenizer->eos_token_id);
     }
 
     void ChatHistoryEncoder::append_user(int round_idx, const std::string &user, std::vector<int> &ids) const
     {
         tokenizer->encode("<role>HUMAN</role>", ids);
         tokenizer->encode(user, ids);
-        ids.push_back(tokenizer->eos_token_id);
     }
 
     void ChatHistoryEncoder::append_ai_opening(int round_idx, std::vector<int> &ids) const
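
With the two push_back calls gone, the moe encoder emits no explicit end-of-turn token at all: consecutive <role> openers delimit the turns, which is why load() above registers <role> as a terminator. One round of history now encodes approximately as

    <role>HUMAN</role>{user message}<role>ASSISTANT</role>{assistant reply}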
@@ -209,9 +191,14 @@ namespace chatllm::bailing::moe2
         ggml::tensor *attn_scores_to_probs(ComputeContext *ctx, int hidden_size, const int n_past, const int qlen,
             ggml::tensor *attn_scores) override
         {
+            if (nullptr == mask)
+            {
+                return QKNormedAttention<RMSNorm, BaseAttention>::attn_scores_to_probs(ctx, hidden_size, n_past, qlen, attn_scores);
+            }
+
             const int head_size = hidden_size / num_attention_heads;
 
-            ggml::tensor *sub_mask = mask ? ggml::view_2d(ctx, mask, n_past + qlen, qlen, (n_past + qlen) * ggml::element_size(mask), 0) : nullptr;
+            ggml::tensor *sub_mask = ggml::view_2d(ctx, mask, n_past + qlen, qlen, (n_past + qlen) * ggml::element_size(mask), 0);
 
             // attn_probs = soft_max(attn_masked)
             ggml::tensor *attn_probs = ggml::soft_max_ext(ctx, attn_scores, sub_mask, 1.f / sqrtf((float)head_size), 0.0f);
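
The early return hands plain unmasked attention back to the base class, so the view_2d call below may assume mask is non-null: it slices the qlen most recent rows, each n_past + qlen keys wide, out of the cached causal mask. ggml's soft_max_ext then fuses the 1/sqrt(head_size) scaling and the additive mask into a single softmax; a scalar sketch of the same computation, for illustration only:

    #include <cmath>
    #include <vector>

    // softmax(scale * scores + mask) over one row of attention scores;
    // masked-out positions hold -INFINITY and end up with probability 0
    std::vector<float> soft_max_row(const std::vector<float> &scores,
                                    const std::vector<float> &mask, float scale)
    {
        std::vector<float> out(scores.size());
        float max_v = -INFINITY;
        for (size_t i = 0; i < scores.size(); i++)
        {
            out[i] = scale * scores[i] + mask[i];
            if (out[i] > max_v) max_v = out[i];
        }
        float sum = 0.0f;
        for (float &v : out) { v = std::exp(v - max_v); sum += v; }
        for (float &v : out) v /= sum;
        return out;
    }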
@@ -341,7 +328,71 @@ namespace chatllm::bailing::moe2
 namespace chatllm::bailing::llada
 {
     typedef moe2::Config Config;
-    typedef moe2::Tokenizer Tokenizer;
+
+    class ChatHistoryEncoder : public BaseHistoryEncoder
+    {
+    public:
+        void append_sys_prompt(std::vector<int> &ids) const override;
+        void append_ai(int round_idx, const std::string &ai, std::vector<int> &ids) const override;
+        void append_user(int round_idx, const std::string &user, std::vector<int> &ids) const override;
+        void append_ai_opening(int round_idx, std::vector<int> &ids) const override;
+    };
+
+    static ChatHistoryEncoder _chat_encoder;
+
+    class Tokenizer : public moe::Tokenizer
+    {
+    public:
+        Tokenizer(const Config &config)
+            : moe::Tokenizer(config, &_chat_encoder)
+        {
+        }
+
+        size_t load(tokenizer::DataReader *buffer, int n_vocab) override
+        {
+            size_t r = moe::Tokenizer::load(buffer, n_vocab);
+
+            eos_token_id = tp->PieceToId("<|role_end|>");
+            mask_token_id = tp->PieceToId("<|mask|>");
+
+            // LlaDA might generate lots of PAD
+            if (mask_token_id >= 0)
+                terminate_ids.insert(pad_token_id);
+
+            return r;
+        }
+    public:
+        int mask_token_id;
+    };
+
+    void ChatHistoryEncoder::append_sys_prompt(std::vector<int> &ids) const
+    {
+        if (tokenizer->get_system_prompt().size() > 0)
+        {
+            tokenizer->encode("<role>SYSTEM</role>", ids);
+            tokenizer->encode(tokenizer->get_system_prompt(), ids);
+            ids.push_back(tokenizer->eos_token_id);
+        }
+    }
+
+    void ChatHistoryEncoder::append_ai(int round_idx, const std::string &ai, std::vector<int> &ids) const
+    {
+        append_ai_opening(round_idx, ids);
+        tokenizer->encode(ai, ids);
+        ids.push_back(tokenizer->eos_token_id);
+    }
+
+    void ChatHistoryEncoder::append_user(int round_idx, const std::string &user, std::vector<int> &ids) const
+    {
+        tokenizer->encode("<role>HUMAN</role>", ids);
+        tokenizer->encode(user, ids);
+        ids.push_back(tokenizer->eos_token_id);
+    }
+
+    void ChatHistoryEncoder::append_ai_opening(int round_idx, std::vector<int> &ids) const
+    {
+        tokenizer->encode("<role>ASSISTANT</role>", ids);
+    }
 
     class Prelude
     {
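
The encoder and tokenizer relocated here reproduce the prompt format that moe carried before this commit: every role block is closed by <|role_end|> (now mapped to eos_token_id), and PAD is registered as a terminator because, as the comment notes, LlaDA tends to emit long runs of it. A full round therefore encodes approximately as

    <role>SYSTEM</role>{system prompt}<|role_end|>
    <role>HUMAN</role>{user message}<|role_end|>
    <role>ASSISTANT</role>{assistant reply}<|role_end|>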
@@ -535,7 +586,7 @@ namespace chatllm::bailing::llada
             << "requested max_length (" << gen_config.max_length << ") is larger than model's max_length ("
             << config_.max_length << ")";
 
-        const int mask_id = dynamic_cast<bailing::moe::Tokenizer *>(tokenizer)->mask_token_id;
+        const int mask_id = dynamic_cast<Tokenizer *>(tokenizer)->mask_token_id;
         std::vector<int> num_transfer_tokens_schedule;
         for (int i = 0; i < steps; i++) num_transfer_tokens_schedule.push_back(block_length / steps);
         const int remain = block_length % steps;
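
The loop gives every diffusion step an even share of the block_length tokens to un-mask; distributing the remainder happens just past the end of this hunk. A self-contained sketch of the scheme (spreading the leftover one token per step over the first steps is an assumption, not shown in the diff):

    #include <vector>

    // e.g. block_length = 32, steps = 5 -> {7, 7, 6, 6, 6}
    std::vector<int> make_schedule(int block_length, int steps)
    {
        std::vector<int> schedule(steps, block_length / steps);
        const int remain = block_length % steps;     // leftover tokens
        for (int i = 0; i < remain; i++)             // assumed distribution
            schedule[i]++;
        return schedule;
    }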