@@ -11,11 +11,72 @@ namespace chatllm::ernie::dense
         void append_ai_opening(int round_idx, std::vector<int> &ids) const override;
     };
 
-    static ChatHistoryEncoder _chat_encoder;
+    class ChatHistoryThinkingEncoder : public BaseHistoryEncoder
+    {
+    public:
+        void append_sys_prompt(std::vector<int> &ids) const override;
+        void append_ai(int round_idx, const std::string &ai, std::vector<int> &ids) const override;
+        void append_user(int round_idx, const std::string &user, std::vector<int> &ids) const override;
+        void append_ai_opening(int round_idx, std::vector<int> &ids) const override;
+        void append_user_opening(int round_idx, std::vector<int> &ids) const override;
+    };
+
+    static ChatHistoryEncoder _chat_encoder;
+    static ChatHistoryThinkingEncoder _chat_thinking_encoder;
 
     Tokenizer::Tokenizer(const BaseConfig &config)
-        : chatllm::llama::v2::Tokenizer(config, &_chat_encoder)
-    {}
+        : chatllm::llama::v2::Tokenizer(config, &_chat_encoder),
+          im_start_token_id(-1), im_end_token_id(-1),
+          nl_token_id(-1), think_start_token_id(-1),
+          think_end_token_id(-1)
+    {
+        sys_prompt = "";
+    }
+
+    size_t Tokenizer::load(tokenizer::DataReader *buffer, int n_vocab)
+    {
+        size_t size = chatllm::llama::v2::Tokenizer::load(buffer, n_vocab);
+        im_start_token_id = tp->PieceToId("<|im_start|>");
+        im_end_token_id = tp->PieceToId("<|im_end|>");
+        std::vector<int> ids;
+        tp->Encode("\n", &ids);
+        nl_token_id = ids[0];
+
+        think_start_token_id = tp->PieceToId("<think>");
+        think_end_token_id = tp->PieceToId("</think>");
+        if (im_end_token_id >= 0)
+            terminate_ids.emplace(im_end_token_id);
+        return size;
+    }
+
+    void Tokenizer::encode_role(const std::string &role, const std::string &text, std::vector<int> &ids) const
+    {
+        ids.push_back(im_start_token_id);
+        BaseTokenizer::encode(role, ids);
+        ids.push_back(nl_token_id);
+        BaseTokenizer::encode(text, ids);
+        ids.push_back(im_end_token_id);
+        ids.push_back(nl_token_id);
+        ids.push_back(nl_token_id);
+    }
+
+    void Tokenizer::encode_role(const std::string &role, std::vector<int> &ids) const
+    {
+        ids.push_back(im_start_token_id);
+        BaseTokenizer::encode(role, ids);
+    }
+
+    bool Tokenizer::load_config(const json::JSON &config)
+    {
+        auto cfg = config["tokenizer_config.json"];
+        std::string s = cfg["chat_template"].ToString();
+        if (s.find("think_mode=True") != std::string::npos)
+        {
+            set_chat_encoder(&_chat_thinking_encoder);
+        }
+
+        return true;
+    }
 
     void ChatHistoryEncoder::append_ai(int round_idx, const std::string &ai, std::vector<int> &ids) const
     {
@@ -54,6 +115,46 @@ namespace chatllm::ernie::dense
         tok->encode("Assistant: ", ids);
     }
 
+    void ChatHistoryThinkingEncoder::append_ai(int round_idx, const std::string &ai, std::vector<int> &ids) const
+    {
+        Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+        tok->encode_role("assistant", ai, ids);
+    }
+
+    void ChatHistoryThinkingEncoder::append_sys_prompt(std::vector<int> &ids) const
+    {
+        Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+        std::ostringstream oss_prompt;
+
+        ids.push_back(tok->bos_token_id);
+        if (tok->get_system_prompt().size() > 0)
+        {
+            oss_prompt << "<system_setting>\n" << tok->get_system_prompt() << "\n</system_setting>\n\n";
+        }
+        oss_prompt << "<global_setting>\n"
+                   << "think_mode=True\n"
+                   << "</global_setting>";
+        tok->encode_role("system", oss_prompt.str(), ids);
+    }
+
+    void ChatHistoryThinkingEncoder::append_user(int round_idx, const std::string &user, std::vector<int> &ids) const
+    {
+        Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+        tok->encode_role("user", user, ids);
+    }
+
+    void ChatHistoryThinkingEncoder::append_ai_opening(int round_idx, std::vector<int> &ids) const
+    {
+        Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+        tok->encode_role("assistant", ids);
+    }
+
+    void ChatHistoryThinkingEncoder::append_user_opening(int round_idx, std::vector<int> &ids) const
+    {
+        Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+        tok->encode_role("user", ids);
+    }
+
     ConditionalGeneration::ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type)
         : chatllm::llama::v2::GenericConditionalGeneration<LlamaBlock>(config, runtime_config, type,
             config.num_key_value_heads, config.head_dim, config.max_length, 12, config.tie_word_embeddings != 0)
@@ -196,4 +297,10 @@ namespace chatllm::ernie::moe
 
         ModelProxy::load(loader);
     }
+}
+
+namespace chatllm
+{
+    REGISTER_MODEL_LOADER(ERNIE_DENSE, ernie::dense, 1);
+    REGISTER_MODEL_LOADER(ERNIE_MOE, ernie::moe, 1);
 }
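
Note that Tokenizer::load_config selects the thinking encoder with a plain substring search: any tokenizer_config.json whose chat_template contains "think_mode=True" switches the model to the thinking template. A hypothetical, heavily abridged fragment that would trigger it:

    {
      "chat_template": "... <global_setting>\nthink_mode=True\n</global_setting> ..."
    }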
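
Tracing encode_role together with the ChatHistoryThinkingEncoder methods above, the decoded prompt for a single round lays out as follows. Here {system prompt} and {user message} are placeholders, [bos] marks the BOS token id pushed ahead of the system block, and the <system_setting> section only appears when a system prompt is actually set (the constructor defaults it to empty):

    [bos]<|im_start|>system
    <system_setting>
    {system prompt}
    </system_setting>

    <global_setting>
    think_mode=True
    </global_setting><|im_end|>

    <|im_start|>user
    {user message}<|im_end|>

    <|im_start|>assistant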