 #include "deepseek.h"
+#include "qwen.h"
 
 namespace chatllm::bailing::moe
 {
@@ -90,6 +91,171 @@ namespace chatllm::bailing::moe
             : deepseek::v1_moe::ConditionalGeneration0<NUM_EXPERTS, EXPERTS_PER_TOK, EXPERTS_PER_TOK>(config, runtime_config, MODEL_TYPE_BAILINGMOE, config.head_dim)
         {}
     };
+}
 
-    REGISTER_MODEL_LOADER(BAILINGMOE, bailing::moe, 1);
+namespace chatllm::bailing2::moe
+{
+    struct Config : public bailing::moe::Config
+    {
+        int   rope_dim;              // rotary embedding dimensions
+        int   n_group;               // number of expert groups for routing
+        int   topk_group;            // per-group top-k (see BailingSparseMoE::select_experts)
+        float routed_scaling_factor; // scale applied to the routed experts' output
+    };
+
+    typedef bailing::moe::Tokenizer Tokenizer;
+
+    // fixed MoE shape (256 routed experts, 8 active per token); validated
+    // against the loaded model config in ConditionalGeneration's constructor
+    const int NUM_EXPERTS     = 256;
+    const int EXPERTS_PER_TOK = 8;
+
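+    // Sparse MoE block for Bailing MoE2: sigmoid gating with score correction
+    // (ScoreFunc::Sigmoid, always_scaling) and group-limited expert selection.
+    // n_group / topk_group start out as -1 and are filled in from the model
+    // config by ConditionalGeneration below.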
+    class BailingSparseMoE : public BaseSparseMLP
+    {
+    public:
+        BailingSparseMoE(InitContext *ctx, int hidden_size, int intermediate_size, int num_experts = NUM_EXPERTS, int experts_per_tok = EXPERTS_PER_TOK)
+            : BaseSparseMLP(ctx, hidden_size, intermediate_size, num_experts, experts_per_tok, ActFunc::SILU, true),
+              n_group(-1), topk_group(-1)
+        {
+            score_func = ScoreFunc::Sigmoid;
+            always_scaling = true;
+        }
+    protected:
+        ggml::tensor *select_experts(ComputeContext *ctx, ggml::tensor *corrected_score) override;
+
+    public:
+        int n_group;
+        int topk_group;
+    };
+
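+    // Group-limited routing in the spirit of DeepSeek-V3: view the corrected
+    // scores of all routed experts as `n_group` consecutive groups, keep only
+    // the `topk_group` best scores inside each group (zeroing the rest), then
+    // run a global top-k to pick the final `num_experts_per_tok` experts.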
+    ggml::tensor *BailingSparseMoE::select_experts(ComputeContext *ctx, ggml::tensor *corrected_score)
+    {
+        const int n_expert = num_local_experts;
+        const int experts_per_group = n_expert / n_group;
+        CHATLLM_CHECK(ggml::get_dim(corrected_score, 2) == 1);
+
+        ggml::tensor *selected_experts = nullptr;
+
+        // view scores as [experts_per_group, n_group, n_tokens, 1] and pick the
+        // top `topk_group` entries within each group
+        ggml::tensor *grouped_scores = ggml::reshape_4d(ctx, corrected_score, experts_per_group, n_group,
+                                            ggml::get_dim(corrected_score, 1), ggml::get_dim(corrected_score, 2));
+        selected_experts = ggml::top_k(ctx, grouped_scores, topk_group);
+
+        ggml::tensor *selected_experts_i64 = ggml::cast_int_to_i64(ctx, selected_experts);
+
+        // scatter the surviving scores into an all-zero copy, masking out every
+        // expert that did not make its group's top-k
+        CHATLLM_CHECK(ggml::get_dim(grouped_scores, 3) == 1);
+        grouped_scores = ggml::reshape_4d(ctx, grouped_scores, 1, ggml::get_dim(grouped_scores, 0), ggml::get_dim(grouped_scores, 1), ggml::get_dim(grouped_scores, 2));
+        ggml::tensor *selected_group_scores = ggml::scale(ctx, grouped_scores, 0.0f);
+        grouped_scores = ggml::get_rows(ctx, grouped_scores, selected_experts);
+        selected_group_scores = ggml::set_rows(ctx, selected_group_scores, selected_experts_i64, grouped_scores);
+
+        selected_group_scores = ggml::reshape_3d(ctx, selected_group_scores,
+            ggml::get_dim(corrected_score, 0), ggml::get_dim(corrected_score, 1), ggml::get_dim(corrected_score, 2));
+
+        // final routing decision: a global top-k over the group-masked scores
+        selected_experts = ggml::top_k(ctx, selected_group_scores, num_experts_per_tok);
+
+        return selected_experts;
+    }
+
+    class ConditionalGeneration : public BaseModelForConditionalGeneration
+    {
+    public:
+        typedef CombinedMLP<BailingSparseMoE, SiLUMLP> BailingMoEMLP;
+        typedef LMBlock1<RMSNorm, qwen::v3::QWen3SelfAttention, RMSNorm, BailingMoEMLP> BailingMoEBlock;
+        typedef BaseModelForConditionalGeneration Base;
+        typedef HeterogeneousModel ModelClass;
+    public:
+        ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type = MODEL_TYPE_BAILING_MOE2)
+            : BaseModelForConditionalGeneration(type, config, runtime_config, 4096 * 4),
+              config(config)
+        {
+            const size_t tensor_ovhd = ggml_tensor_overhead();
+            const int moe_layer_num = get_moe_layer_num();
+            const int dense_layer_num = config.num_hidden_layers - moe_layer_num;
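+            // size the weight context exactly: 3 top-level tensors (embedding,
+            // final norm, lm_head) plus a fixed number of weight tensors per
+            // dense and per MoE layer; check_used_mem_size(true) at the end of
+            // the constructor verifies this accounting.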
+            const size_t num_tensors = 3
+                + moe_layer_num * (12 + 7)
+                + dense_layer_num * 14;
+            const size_t ctx_size = num_tensors * tensor_ovhd;
+            w_ctx_.gctx = GGMLContext({.mem_size = ctx_size, .mem_buffer = nullptr, .no_alloc = true});
+            w_ctx_.dtype = config.dtype;
+
+            CHATLLM_CHECK((NUM_EXPERTS == config.n_routed_experts)
+                        && (EXPERTS_PER_TOK == config.num_experts_per_tok))
+                << "unsupported MoE param";
+
+            #define config_rope(attention) do { \
+                attention.freq_base = config.rope_theta; \
+                attention.rope_dim  = config.rope_dim;   \
+            } while (false)
+
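+            // heterogeneous layer factory: MoE layers pair the routed
+            // BailingSparseMoE with a shared-expert SiLUMLP via CombinedMLP,
+            // while all other layers are plain QWen3 dense blocks; both kinds
+            // share the same RoPE configuration.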
+            auto create_layer = [&](InitContext *ctx, int layer_index) -> Block * {
+                if (is_layer_moe(layer_index))
+                {
+                    auto layer = new BailingMoEBlock(ctx, config.hidden_size, config.num_attention_heads, config.intermediate_size,
+                                        config.moe_intermediate_size, config.moe_intermediate_size * config.n_shared_experts,
+                                        config.num_key_value_heads,
+                                        config.head_dim,
+                                        config.max_length);
+                    layer->mlp.mlp1.norm_topk_prob        = config.norm_topk_prob != 0;
+                    layer->mlp.mlp1.routed_scaling_factor = config.routed_scaling_factor;
+                    layer->mlp.mlp1.n_group               = config.n_group;
+                    layer->mlp.mlp1.topk_group            = config.topk_group;
+                    config_rope(layer->attention);
+                    return layer;
+                }
+                else
+                {
+                    auto layer = new qwen::v3::QWen3Block(ctx, config.hidden_size, config.num_attention_heads, config.intermediate_size,
+                                        config.num_key_value_heads, config.head_dim, config.max_length);
+                    config_rope(layer->attention);
+                    return layer;
+                }
+            };
+
+            auto transformer = new ModelClass(&w_ctx_, config.num_hidden_layers, config.hidden_size,
+                                    create_embedding<Embedding>(&w_ctx_, config),
+                                    create_final_norm<RMSNorm>(&w_ctx_, config),
+                                    create_lm_head(&w_ctx_, config, false), create_layer);
+
+            Base::transformer = transformer;
+
+            #undef config_rope
+
+            w_ctx_.check_used_mem_size(true);
+        }
+
+        void load(ModelLoader &loader) override
+        {
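+            // map chatllm's internal module names (CombinedMLP exposes the
+            // routed experts as .mlp1 and the shared experts as .mlp2) back to
+            // the tensor names used in the converted model file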
+            loader.add_tensor_name_translations({
+                {".mlp2.",                           ".shared_experts."},
+                {".mlp1.gate.",                      ".gate."},
+                {".mlp1.experts.",                   ".experts."},
+                {".mlp1.gate_score_correction_bias", ".gate.expert_bias"},
+            });
+
+            BaseModelForConditionalGeneration::load(loader);
+        }
+
+    public:
+        const Config config;
+
+        // the first `first_k_dense_replace` layers stay dense; after that,
+        // every `moe_layer_freq`-th layer is a MoE layer
+        bool is_layer_moe(int layer_index)
+        {
+            return (layer_index >= config.first_k_dense_replace) && (layer_index % config.moe_layer_freq == 0);
+        }
+
+        int get_moe_layer_num()
+        {
+            int r = 0;
+            for (int i = 0; i < config.num_hidden_layers; i++)
+            {
+                if (is_layer_moe(i))
+                    r++;
+            }
+            return r;
+        }
+    };
+}
+
+namespace chatllm
+{
+    REGISTER_MODEL_LOADER(BAILINGMOE,   bailing::moe,  1);
+    REGISTER_MODEL_LOADER(BAILING_MOE2, bailing2::moe, 1);
 }