
Commit 43a84eb

support ernie-thinking
1 parent 83ae708 commit 43a84eb

4 files changed: +125 additions, −12 deletions

docs/models.md

Lines changed: 2 additions & 1 deletion
@@ -64,7 +64,8 @@
     Two optimization modes are defined: speed (default) and memory. See `BaseMLAttention`.
 
 * ERNIE (`Ernie4_5_ForCausalLM`, `Ernie4_5_MoeForCausalLM`)
-    * [x] [0.3B](https://huggingface.co/baidu/ERNIE-4.5-0.3B-PT/tree/c163aa422d265f995b024d1322d91c4e3cb52ec8), [A3B](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-PT/tree/b24b8917f5379129992dad46c279683c7b845c96)
+    * [x] Non-thinking: [0.3B](https://huggingface.co/baidu/ERNIE-4.5-0.3B-PT/tree/c163aa422d265f995b024d1322d91c4e3cb52ec8), [A3B](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-PT/tree/b24b8917f5379129992dad46c279683c7b845c96)
+    * [x] Thinking: [A3B](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking/tree/78d7a200cddb8132b074adffcd5aa2ef3361b0ae)
 
 * EXAONE (`ExaoneForCausalLM`)
     * [x] v3.5: [Instruct-2.4B](https://huggingface.co/LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct), [Instruct-7.8B](https://huggingface.co/LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct), [Instruct-32B](https://huggingface.co/LGAI-EXAONE/EXAONE-3.5-32B-Instruct)
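
Note: no separate flag is needed for the Thinking checkpoint. As the models/ernie.cpp change below shows, Tokenizer::load_config() switches to the thinking chat encoder whenever the bundled chat template contains "think_mode=True".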

models/ernie.cpp

Lines changed: 110 additions & 3 deletions
@@ -11,11 +11,72 @@ namespace chatllm::ernie::dense
         void append_ai_opening(int round_idx, std::vector<int> &ids) const override;
     };
 
-    static ChatHistoryEncoder _chat_encoder;
+    class ChatHistoryThinkingEncoder : public BaseHistoryEncoder
+    {
+    public:
+        void append_sys_prompt(std::vector<int> &ids) const override;
+        void append_ai(int round_idx, const std::string &ai, std::vector<int> &ids) const override;
+        void append_user(int round_idx, const std::string &user, std::vector<int> &ids) const override;
+        void append_ai_opening(int round_idx, std::vector<int> &ids) const override;
+        void append_user_opening(int round_idx, std::vector<int> &ids) const override;
+    };
+
+    static ChatHistoryEncoder _chat_encoder;
+    static ChatHistoryThinkingEncoder _chat_thinking_encoder;
 
     Tokenizer::Tokenizer(const BaseConfig &config)
-        : chatllm::llama::v2::Tokenizer(config, &_chat_encoder)
-    {}
+        : chatllm::llama::v2::Tokenizer(config, &_chat_encoder),
+          im_start_token_id(-1), im_end_token_id(-1),
+          nl_token_id(-1), think_start_token_id(-1),
+          think_end_token_id(-1)
+    {
+        sys_prompt = "";
+    }
+
+    size_t Tokenizer::load(tokenizer::DataReader *buffer, int n_vocab)
+    {
+        size_t size = chatllm::llama::v2::Tokenizer::load(buffer, n_vocab);
+        im_start_token_id = tp->PieceToId("<|im_start|>");
+        im_end_token_id = tp->PieceToId("<|im_end|>");
+        std::vector<int> ids;
+        tp->Encode("\n", &ids);
+        nl_token_id = ids[0];
+
+        think_start_token_id = tp->PieceToId("<think>");
+        think_end_token_id = tp->PieceToId("</think>");
+        if (im_end_token_id >= 0)
+            terminate_ids.emplace(im_end_token_id);
+        return size;
+    }
+
+    void Tokenizer::encode_role(const std::string &role, const std::string &text, std::vector<int> &ids) const
+    {
+        ids.push_back(im_start_token_id);
+        BaseTokenizer::encode(role, ids);
+        ids.push_back(nl_token_id);
+        BaseTokenizer::encode(text, ids);
+        ids.push_back(im_end_token_id);
+        ids.push_back(nl_token_id);
+        ids.push_back(nl_token_id);
+    }
+
+    void Tokenizer::encode_role(const std::string &role, std::vector<int> &ids) const
+    {
+        ids.push_back(im_start_token_id);
+        BaseTokenizer::encode(role, ids);
+    }
+
+    bool Tokenizer::load_config(const json::JSON &config)
+    {
+        auto cfg = config["tokenizer_config.json"];
+        std::string s = cfg["chat_template"].ToString();
+        if (s.find("think_mode=True") != std::string::npos)
+        {
+            set_chat_encoder(&_chat_thinking_encoder);
+        }
+
+        return true;
+    }
 
     void ChatHistoryEncoder::append_ai(int round_idx, const std::string &ai, std::vector<int> &ids) const
     {
@@ -54,6 +115,46 @@ namespace chatllm::ernie::dense
         tok->encode("Assistant: ", ids);
     }
 
+    void ChatHistoryThinkingEncoder::append_ai(int round_idx, const std::string &ai, std::vector<int> &ids) const
+    {
+        Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+        tok->encode_role("assistant", ai, ids);
+    }
+
+    void ChatHistoryThinkingEncoder::append_sys_prompt(std::vector<int> &ids) const
+    {
+        Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+        std::ostringstream oss_prompt;
+
+        ids.push_back(tok->bos_token_id);
+        if (tok->get_system_prompt().size() > 0)
+        {
+            oss_prompt << "<system_setting>\n" << tok->get_system_prompt() << "\n</system_setting>\n\n";
+        }
+        oss_prompt << "<global_setting>\n"
+                   << "think_mode=True\n"
+                   << "</global_setting>";
+        tok->encode_role("system", oss_prompt.str(), ids);
+    }
+
+    void ChatHistoryThinkingEncoder::append_user(int round_idx, const std::string &user, std::vector<int> &ids) const
+    {
+        Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+        tok->encode_role("user", user, ids);
+    }
+
+    void ChatHistoryThinkingEncoder::append_ai_opening(int round_idx, std::vector<int> &ids) const
+    {
+        Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+        tok->encode_role("assistant", ids);
+    }
+
+    void ChatHistoryThinkingEncoder::append_user_opening(int round_idx, std::vector<int> &ids) const
+    {
+        Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+        tok->encode_role("user", ids);
+    }
+
     ConditionalGeneration::ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type)
         : chatllm::llama::v2::GenericConditionalGeneration<LlamaBlock>(config, runtime_config, type,
               config.num_key_value_heads, config.head_dim, config.max_length, 12, config.tie_word_embeddings != 0)
@@ -196,4 +297,10 @@ namespace chatllm::ernie::moe
 
         ModelProxy::load(loader);
     }
+}
+
+namespace chatllm
+{
+    REGISTER_MODEL_LOADER(ERNIE_DENSE, ernie::dense, 1);
+    REGISTER_MODEL_LOADER(ERNIE_MOE, ernie::moe, 1);
 }
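
Taken together, encode_role() and ChatHistoryThinkingEncoder produce a prompt shaped roughly as follows for a single round. This is a sketch inferred from the code above, not output captured from the tokenizer; <bos> stands for the BOS token, and {...} are placeholders:

    <bos><|im_start|>system
    <system_setting>
    {system prompt, only when one is set}
    </system_setting>

    <global_setting>
    think_mode=True
    </global_setting><|im_end|>

    <|im_start|>user
    {user message}<|im_end|>

    <|im_start|>assistant

The two-argument encode_role() adds no newline after the role, so generation begins immediately after "assistant". think_start_token_id and think_end_token_id are resolved in load() but never emitted here; presumably they exist so downstream code can recognize the model's own <think>...</think> span.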

models/ernie.h

Lines changed: 10 additions & 6 deletions
@@ -19,6 +19,16 @@ namespace chatllm::ernie::dense
     {
     public:
         Tokenizer(const BaseConfig &config);
+        size_t load(tokenizer::DataReader *buffer, int n_vocab) override;
+        void encode_role(const std::string &role, const std::string &text, std::vector<int> &ids) const;
+        void encode_role(const std::string &role, std::vector<int> &ids) const;
+        bool load_config(const json::JSON &config) override;
+    public:
+        int im_start_token_id;
+        int im_end_token_id;
+        int nl_token_id;
+        int think_start_token_id;
+        int think_end_token_id;
     };
 
     class ConditionalGeneration : public chatllm::llama::v2::GenericConditionalGeneration<LlamaBlock>
@@ -55,10 +65,4 @@ namespace chatllm::ernie::moe
         ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config);
        void load(ModelLoader &loader);
     };
-}
-
-namespace chatllm
-{
-    REGISTER_MODEL_LOADER(ERNIE_DENSE, ernie::dense, 1);
-    REGISTER_MODEL_LOADER(ERNIE_MOE, ernie::moe, 1);
 }
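
Moving the REGISTER_MODEL_LOADER calls out of ernie.h and into models/ernie.cpp (see above) confines the registration to one translation unit; a registration macro left in the header would expand in every file that includes it. That rationale is an inference from the diff, not something the commit states.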

src/tokenizer.cpp

Lines changed: 3 additions & 2 deletions
@@ -273,8 +273,9 @@ int Processor::PieceToId(std::string_view piece) const
 
 const std::string Processor::IdToPiece(int id) const
 {
-    if (token_override.contains(id))
-        return token_override.find(id)->second;
+    auto iter = token_override.find(id);
+    if (iter != token_override.end())
+        return iter->second;
 
     if (id < 0) return token_unk_id;
     return id < (int)vocab_.id_to_token.size() ? vocab_.id_to_token[id].tok : token_unk_id;
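
The rewrite replaces contains() followed by find() with a single find(), so the map is traversed once instead of twice (and, since contains() is C++20, the code no longer depends on it). A minimal standalone sketch of the idiom, with an illustrative std::map standing in for chatllm's token_override:

    #include <map>
    #include <string>

    // Look up an override, falling back to a default: find() returns an
    // iterator we can both test and dereference, so one traversal suffices.
    std::string lookup(const std::map<int, std::string> &overrides, int id,
                       const std::string &fallback)
    {
        auto iter = overrides.find(id);
        if (iter != overrides.end())
            return iter->second;
        return fallback;
    }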
