33#include < cstring>
44#include < functional>
55#include " deepseek.h"
6- #include " qwen.h"
76
87namespace chatllm ::bailing
98{
@@ -60,6 +59,10 @@ namespace chatllm::bailing::moe
6059 eos_token_id = tp->PieceToId (" <|role_end|>" );
6160 mask_token_id = tp->PieceToId (" <|mask|>" );
6261
62+ // LlaDA might generate lots of PAD
63+ if (mask_token_id >= 0 )
64+ terminate_ids.insert (pad_token_id);
65+
6366 int t = tp->PieceToId (" <think>" );
6467 if (t >= 0 )
6568 {
@@ -180,15 +183,55 @@ namespace chatllm::bailing::moe2
180183 return selected_experts;
181184 }
182185
// Global construction-time knobs shared by all attention layers in this
// namespace.
//
// `custom_mask` is set (before the model is constructed — see llada::Prelude)
// by models that need a non-causal, externally supplied attention mask.
// When non-zero, every SelfAttention instance allocates its own mask tensor.
class AttnParams
{
public:
    static int custom_mask;
};

// Default: causal attention, no custom mask tensor is allocated.
// (Was `= false`; use an int literal for an int member.)
int AttnParams::custom_mask = 0;
193+
194+ class SelfAttention : public QKNormedAttention <RMSNorm, BaseAttention>
195+ {
196+ public:
197+ SelfAttention (InitContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int head_dim, int max_length):
198+ QKNormedAttention<RMSNorm, BaseAttention>(ctx, hidden_size, num_attention_heads, num_kv_heads, head_dim, max_length, false , false ),
199+ mask (nullptr )
200+ {
201+ if (AttnParams::custom_mask)
202+ {
203+ // reverse some data
204+ mask = ggml::new_tensor_2d (ctx, ggml::type::GGML_TYPE_F16, max_length, 32 );
205+ ctx->get_allocator ()->alloc (mask);
206+ }
207+ }
208+
209+ ggml::tensor *attn_scores_to_probs (ComputeContext *ctx, int hidden_size, const int n_past, const int qlen,
210+ ggml::tensor *attn_scores) override
211+ {
212+ const int head_size = hidden_size / num_attention_heads;
213+
214+ ggml::tensor * sub_mask = mask ? ggml::view_2d (ctx, mask, n_past + qlen, qlen, (n_past + qlen) * ggml::element_size (mask), 0 ) : nullptr ;
215+
216+ // attn_probs = soft_max(attn_masked)
217+ ggml::tensor * attn_probs = ggml::soft_max_ext (ctx, attn_scores, sub_mask, 1 .f / sqrtf ((float )head_size), 0 .0f );
218+
219+ return attn_probs;
220+ }
221+ public:
222+ ggml_tensor *mask;
223+ };
224+
183225 class ConditionalGeneration : public BaseModelForConditionalGeneration
184226 {
185227 public:
186228 typedef CombinedMLP<BailingSparseMoE, SiLUMLP> BailingMoEMLP;
187- typedef LMBlock1<RMSNorm, qwen::v3::QWen3SelfAttention, RMSNorm, BailingMoEMLP> BailingMoEBlock;
229+ typedef LMBlock1<RMSNorm, SelfAttention, RMSNorm, BailingMoEMLP> BailingMoEBlock;
230+ typedef LMBlock1<RMSNorm, SelfAttention, RMSNorm, SiLUMLP> BailingDenseBlock;
188231 typedef BaseModelForConditionalGeneration Base;
189232 typedef HeterogeneousModel ModelClass;
190233 public:
191- ConditionalGeneration (const Config &config, const RuntimeConfig &runtime_config, ModelType type = MODEL_TYPE_BAILING_MOE2, bool causal = true )
234+ ConditionalGeneration (const Config &config, const RuntimeConfig &runtime_config, ModelType type = MODEL_TYPE_BAILING_MOE2)
192235 : BaseModelForConditionalGeneration(type, config, runtime_config, 4096 * 4 ),
193236 config (config)
194237 {
@@ -197,7 +240,8 @@ namespace chatllm::bailing::moe2
197240 const int dense_layer_num = config.num_hidden_layers - moe_layer_num;
198241 const size_t num_tensors = 3
199242 + moe_layer_num * (12 + 7 )
200- + dense_layer_num * 14 ;
243+ + dense_layer_num * 14
244+ + (AttnParams::custom_mask ? config.num_hidden_layers : 0 );
201245 const size_t ctx_size = num_tensors * tensor_ovhd;
202246 w_ctx_.gctx = GGMLContext ({.mem_size = ctx_size, .mem_buffer = nullptr , .no_alloc = true });
203247 w_ctx_.dtype = config.dtype ;
@@ -223,15 +267,13 @@ namespace chatllm::bailing::moe2
223267 layer->mlp .mlp1 .routed_scaling_factor = config.routed_scaling_factor ;
224268 layer->mlp .mlp1 .n_group = config.n_group ;
225269 layer->mlp .mlp1 .topk_group = config.topk_group ;
226- layer->attention .causal = causal;
227270 config_rope (layer->attention );
228271 return layer;
229272 }
230273 else
231274 {
232- auto layer = new qwen::v3::QWen3Block (ctx, config.hidden_size , config.num_attention_heads , config.intermediate_size ,
275+ auto layer = new BailingDenseBlock (ctx, config.hidden_size , config.num_attention_heads , config.intermediate_size ,
233276 config.num_key_value_heads , config.head_dim , config.max_length );
234- layer->attention .causal = causal;
235277 config_rope (layer->attention );
236278 return layer;
237279 }
@@ -249,6 +291,20 @@ namespace chatllm::bailing::moe2
249291 w_ctx_.check_used_mem_size (true );
250292 }
251293
294+ SelfAttention *get_attn_of_layer (int layer_index)
295+ {
296+ if (is_layer_moe (layer_index))
297+ {
298+ auto layer = dynamic_cast <BailingMoEBlock *>(transformer->get_layer (layer_index));
299+ return &layer->attention ;
300+ }
301+ else
302+ {
303+ auto layer = dynamic_cast <BailingDenseBlock *>(transformer->get_layer (layer_index));
304+ return &layer->attention ;
305+ }
306+ }
307+
252308 void load (ModelLoader &loader) override
253309 {
254310 loader.add_tensor_name_translations ({
@@ -287,11 +343,21 @@ namespace chatllm::bailing::llada
287343 typedef moe2::Config Config;
288344 typedef moe2::Tokenizer Tokenizer;
289345
290- class ConditionalGeneration : public moe2 ::ConditionalGeneration
    // Helper base class whose sole job is to flip moe2::AttnParams::custom_mask
    // on. llada::ConditionalGeneration inherits Prelude *before*
    // moe2::ConditionalGeneration, and base classes are constructed in
    // declaration order, so the flag is guaranteed to be set before the moe2
    // constructor builds the attention layers (which read the flag to decide
    // whether to allocate their custom mask tensors).
    class Prelude
    {
    public:
        Prelude()
        {
            moe2::AttnParams::custom_mask = true;
        }
    };
354+
355+ class ConditionalGeneration : public Prelude , public moe2 ::ConditionalGeneration
291356 {
292357 public:
293358 ConditionalGeneration (const Config &config, const RuntimeConfig &runtime_config, ModelType type = (ModelType)MODEL_TYPE_LLADA2):
294- moe2::ConditionalGeneration (config, runtime_config, type, false )
359+ Prelude (),
360+ moe2::ConditionalGeneration (config, runtime_config, type)
295361 {
296362 final_steps = dynamic_cast <LMFinalSteps *>(transformer->get_final_steps ());
297363 }
@@ -304,13 +370,13 @@ namespace chatllm::bailing::llada
304370 BaseStreamer *streamer = nullptr ) override ;
305371 protected:
306372 bool generate_next_block (const int *input_ids, const int ids_count, const GenerationConfig &gen_config,
307- std::vector<float > *logits_output,
308- std::vector<int > *logits_orderring, const int batch_size = 1 );
373+ std::vector<float > *logits_output, const int batch_size = 1 );
309374 bool run_model (const int *input_ids, const int ids_count,
310375 const GenerationConfig &gen_config,
311376 int past,
312- std::vector<float > *logits_output,
313- std::vector<int > *logits_orderring, const int batch_size);
377+ std::vector<float > *logits_output, const int batch_size);
378+
379+ void update_mask (int past, int qlen);
314380 public:
315381 int block_length = 32 ;
316382 int steps = 32 ;
@@ -330,14 +396,57 @@ namespace chatllm::bailing::llada
330396 threshold = (float )utils::get_opt (args, " threshold" , threshold);
331397 final_steps->set_read_last_n (block_length);
332398 steps = std::min (steps, block_length);
399+
400+
401+ }
402+
403+ void ConditionalGeneration::update_mask (int past, int qlen)
404+ {
405+ CHATLLM_CHECK ((past % block_length) == 0 );
406+ CHATLLM_CHECK ((qlen % block_length) == 0 );
407+
408+ std::vector<float > mask;
409+ const int col_num = qlen + past;
410+ mask.resize (qlen * col_num);
411+ for (int i = 0 ; i < qlen / block_length; i++)
412+ {
413+ for (int j = 0 ; j < col_num / block_length; j++)
414+ {
415+ const float v = i + (past / block_length) >= j ? 0 .0f : -INFINITY;
416+ for (int ii = 0 ; ii < block_length; ii++)
417+ {
418+ const int row_index = i * block_length + ii;
419+ for (int jj = 0 ; jj < block_length; jj++)
420+ {
421+ const int col_index = j * block_length + jj;
422+ const int index = row_index * col_num + col_index;
423+ mask[index] = v;
424+ }
425+ }
426+ }
427+ }
428+
429+ std::vector<uint8_t > buf;
430+ buf.resize (ggml::element_size (get_attn_of_layer (0 )->mask ) * mask.size ());
431+ ggml::from_float (ggml::type_of (get_attn_of_layer (0 )->mask ), mask.data (), buf.data (), mask.size (), 1 );
432+
433+ for (int i = 0 ; i < config.num_hidden_layers ; i++)
434+ {
435+ auto m = get_attn_of_layer (i)->mask ;
436+ ggml::set_dim (m, 0 , col_num);
437+ ggml::set_dim (m, 1 , qlen);
438+ Backend::write_tensor_data (m, buf.data (), 0 , buf.size ());
439+ }
333440 }
334441
335442 bool ConditionalGeneration::run_model (const int *input_ids, const int ids_count,
336443 const GenerationConfig &gen_config,
337444 int past,
338- std::vector<float > *logits_output,
339- std::vector<int > *logits_orderring, const int batch_size)
445+ std::vector<float > *logits_output, const int batch_size)
340446 {
447+ CHATLLM_CHECK (batch_size == 1 );
448+ CHATLLM_CHECK ((ids_count % block_length) == 0 );
449+
341450 if (!initial_run)
342451 {
343452 initial_run = true ;
@@ -349,16 +458,18 @@ namespace chatllm::bailing::llada
349458
350459 ForwardContext ctx (&backend_context);
351460 ctx.user_options = w_ctx_.user_options ;
461+ LMFinalStepsDisabler disabler (final_steps, logits_output == nullptr );
352462
353463 ctx.gctx = GGMLContext ({.mem_size = backend_context.buf_compute_meta .size (), .mem_buffer = backend_context.buf_compute_meta .data (), .no_alloc = true });
354464 ctx.gf = ggml::new_graph_custom (&ctx, GRAPH_SIZE, false );
355465
356466 set_dbg_ctx (&ctx);
357467
468+ update_mask (past, ids_count);
469+
358470 ctx.move_to_layer (LayerAllocatorManager::MiscLayer::Prolog);
359471 ggml::tensor *input_ids_tensor = ggml::new_tensor_2d (&ctx, GGML_TYPE_I32, ids_count, batch_size);
360472
361- final_steps->set_do_orderring (logits_orderring != nullptr );
362473 ggml::tensor *r = transformer->forward (&ctx, input_ids_tensor, past);
363474
364475 ctx.move_to_layer (LayerAllocatorManager::MiscLayer::Epilog);
@@ -372,8 +483,6 @@ namespace chatllm::bailing::llada
372483
373484 if (logits_output)
374485 logits_output->resize (ggml::nelements (r));
375- if (logits_orderring)
376- logits_orderring->resize (ggml::nelements (final_steps->get_orderring_result ()));
377486
378487 if (!ctx.allocate ()) return false ;
379488
@@ -389,17 +498,14 @@ namespace chatllm::bailing::llada
389498
390499 if (logits_output)
391500 Backend::read_tensor_data (r, logits_output->data ());
392- if (logits_orderring)
393- Backend::read_tensor_data (final_steps->get_orderring_result (), logits_orderring->data ());
394501
395502 ctx.reset ();
396503
397504 return true ;
398505 }
399506
400507 bool ConditionalGeneration::generate_next_block (const int *input_ids, const int ids_count, const GenerationConfig &gen_config,
401- std::vector<float > *logits_output,
402- std::vector<int > *logits_orderring, const int batch_size)
508+ std::vector<float > *logits_output, const int batch_size)
403509 {
404510 int batch = batch_input > 1 ? batch_input : 1 ;
405511 batch = (batch / block_length) * block_length;
@@ -411,11 +517,11 @@ namespace chatllm::bailing::llada
411517
412518 for (; (remain > batch) && !aborted; p += batch, remain -= batch, past += batch)
413519 {
414- if (!run_model (p, batch, gen_config, past, nullptr , nullptr , 1 ))
520+ if (!run_model (p, batch, gen_config, past, nullptr , batch_size ))
415521 return false ;
416522 }
417523
418- return run_model (p, remain, gen_config, past, logits_output, logits_orderring, 1 );
524+ return run_model (p, remain, gen_config, past, logits_output, batch_size );
419525 }
420526
421527 std::vector<int > ConditionalGeneration::generate (const std::vector<int > &input_ids, const GenerationConfig &gen_config,
@@ -506,7 +612,7 @@ namespace chatllm::bailing::llada
506612 const int prefill_len = prefill_block_num * block_length;
507613 if (prefill_len > 0 )
508614 {
509- generate_next_block (curr_input_ids.data (), prefill_len, gen_config, nullptr , nullptr );
615+ generate_next_block (curr_input_ids.data (), prefill_len, gen_config, nullptr );
510616 n_past += prefill_len;
511617 curr_input_ids.erase (curr_input_ids.begin (), curr_input_ids.begin () + prefill_len);
512618
@@ -526,7 +632,7 @@ namespace chatllm::bailing::llada
526632 {
527633 // Note: we have to run a whole block again and again.
528634 std::vector<float > lm_logits;
529- generate_next_block (block_result.data (), block_length, gen_config, &lm_logits, nullptr );
635+ generate_next_block (block_result.data (), block_length, gen_config, &lm_logits);
530636
531637 struct candidate
532638 {
@@ -591,8 +697,10 @@ namespace chatllm::bailing::llada
591697
592698 // block is now finalized
593699 if (next_pos_to_add == block_length)
594- generate_next_block (block_result.data (), next_pos_to_add, gen_config, nullptr , nullptr );
595- n_past += next_pos_to_add;
700+ {
701+ generate_next_block (block_result.data (), next_pos_to_add, gen_config, nullptr );
702+ n_past += next_pos_to_add;
703+ }
596704
597705 if (performance)
598706 performance->Accumulate (ModelPerfInfo::Type::Generation, block_length - block_prefilled_size);
0 commit comments