Commit f2421c6

correct implementation of mask for LLaDA and a small improvement
1 parent f4ec257 commit f2421c6

File tree

3 files changed: +159 additions, -28 deletions


models/bailing.cpp

Lines changed: 136 additions & 28 deletions
@@ -3,7 +3,6 @@
 #include <cstring>
 #include <functional>
 #include "deepseek.h"
-#include "qwen.h"
 
 namespace chatllm::bailing
 {
@@ -60,6 +59,10 @@ namespace chatllm::bailing::moe
         eos_token_id  = tp->PieceToId("<|role_end|>");
         mask_token_id = tp->PieceToId("<|mask|>");
 
+        // LLaDA might generate lots of PAD tokens
+        if (mask_token_id >= 0)
+            terminate_ids.insert(pad_token_id);
+
         int t = tp->PieceToId("<think>");
         if (t >= 0)
         {
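For context: `terminate_ids` is the set of token ids that ends generation, so inserting `pad_token_id` stops a LLaDA-style diffusion model from streaming padding after the answer. A minimal sketch of how such a set is typically consulted in a decode loop (the `sample_next` callback and the loop itself are illustrative assumptions, not the project's actual API):

    #include <functional>
    #include <set>
    #include <vector>

    // Hypothetical decode loop: stop as soon as a terminating id appears.
    std::vector<int> decode_until_terminated(const std::set<int> &terminate_ids,
                                             const std::function<int()> &sample_next,
                                             int max_tokens)
    {
        std::vector<int> out;
        for (int i = 0; i < max_tokens; i++)
        {
            const int id = sample_next();
            if (terminate_ids.count(id) > 0)   // e.g. eos_token_id or pad_token_id
                break;
            out.push_back(id);
        }
        return out;
    }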
@@ -180,15 +183,55 @@ namespace chatllm::bailing::moe2
         return selected_experts;
     }
 
+    class AttnParams
+    {
+    public:
+        static int custom_mask;
+    };
+
+    int AttnParams::custom_mask = false;
+
+    class SelfAttention : public QKNormedAttention<RMSNorm, BaseAttention>
+    {
+    public:
+        SelfAttention(InitContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int head_dim, int max_length):
+            QKNormedAttention<RMSNorm, BaseAttention>(ctx, hidden_size, num_attention_heads, num_kv_heads, head_dim, max_length, false, false),
+            mask(nullptr)
+        {
+            if (AttnParams::custom_mask)
+            {
+                // reserve some data
+                mask = ggml::new_tensor_2d(ctx, ggml::type::GGML_TYPE_F16, max_length, 32);
+                ctx->get_allocator()->alloc(mask);
+            }
+        }
+
+        ggml::tensor *attn_scores_to_probs(ComputeContext *ctx, int hidden_size, const int n_past, const int qlen,
+                                           ggml::tensor *attn_scores) override
+        {
+            const int head_size = hidden_size / num_attention_heads;
+
+            ggml::tensor *sub_mask = mask ? ggml::view_2d(ctx, mask, n_past + qlen, qlen, (n_past + qlen) * ggml::element_size(mask), 0) : nullptr;
+
+            // attn_probs = soft_max(attn_masked)
+            ggml::tensor *attn_probs = ggml::soft_max_ext(ctx, attn_scores, sub_mask, 1.f / sqrtf((float)head_size), 0.0f);
+
+            return attn_probs;
+        }
+    public:
+        ggml_tensor *mask;
+    };
+
     class ConditionalGeneration : public BaseModelForConditionalGeneration
     {
     public:
         typedef CombinedMLP<BailingSparseMoE, SiLUMLP> BailingMoEMLP;
-        typedef LMBlock1<RMSNorm, qwen::v3::QWen3SelfAttention, RMSNorm, BailingMoEMLP> BailingMoEBlock;
+        typedef LMBlock1<RMSNorm, SelfAttention, RMSNorm, BailingMoEMLP> BailingMoEBlock;
+        typedef LMBlock1<RMSNorm, SelfAttention, RMSNorm, SiLUMLP> BailingDenseBlock;
         typedef BaseModelForConditionalGeneration Base;
         typedef HeterogeneousModel ModelClass;
     public:
-        ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type = MODEL_TYPE_BAILING_MOE2, bool causal = true)
+        ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type = MODEL_TYPE_BAILING_MOE2)
             : BaseModelForConditionalGeneration(type, config, runtime_config, 4096 * 4),
               config(config)
         {
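The override replaces the default causal mask with a view into the per-layer `mask` tensor; `ggml::soft_max_ext` then computes `softmax(scale * scores + mask)` row by row, with `scale = 1.f / sqrtf(head_size)`. A standalone sketch of that arithmetic for one row (plain C++, assuming the usual convention of 0.0f for visible entries and -INFINITY for blocked ones; not ggml code):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // softmax(scale * scores + mask) over one attention row.
    // mask[i] is 0.0f (visible) or -INFINITY (blocked); at least one
    // entry is assumed visible so the row sum is non-zero.
    std::vector<float> masked_softmax_row(const std::vector<float> &scores,
                                          const std::vector<float> &mask,
                                          float scale)
    {
        std::vector<float> p(scores.size());
        float max_v = -INFINITY;
        for (size_t i = 0; i < scores.size(); i++)
        {
            p[i] = scale * scores[i] + mask[i];
            max_v = std::max(max_v, p[i]);
        }
        float sum = 0.0f;
        for (float &v : p) { v = std::exp(v - max_v); sum += v; }
        for (float &v : p) v /= sum;
        return p;
    }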
@@ -197,7 +240,8 @@ namespace chatllm::bailing::moe2
         const int dense_layer_num = config.num_hidden_layers - moe_layer_num;
         const size_t num_tensors = 3
                                  + moe_layer_num * (12 + 7)
-                                 + dense_layer_num * 14;
+                                 + dense_layer_num * 14
+                                 + (AttnParams::custom_mask ? config.num_hidden_layers : 0);
         const size_t ctx_size = num_tensors * tensor_ovhd;
         w_ctx_.gctx = GGMLContext({.mem_size = ctx_size, .mem_buffer = nullptr, .no_alloc = true});
         w_ctx_.dtype = config.dtype;
@@ -223,15 +267,13 @@ namespace chatllm::bailing::moe2
                 layer->mlp.mlp1.routed_scaling_factor = config.routed_scaling_factor;
                 layer->mlp.mlp1.n_group = config.n_group;
                 layer->mlp.mlp1.topk_group = config.topk_group;
-                layer->attention.causal = causal;
                 config_rope(layer->attention);
                 return layer;
             }
             else
             {
-                auto layer = new qwen::v3::QWen3Block(ctx, config.hidden_size, config.num_attention_heads, config.intermediate_size,
+                auto layer = new BailingDenseBlock(ctx, config.hidden_size, config.num_attention_heads, config.intermediate_size,
                                                       config.num_key_value_heads, config.head_dim, config.max_length);
-                layer->attention.causal = causal;
                 config_rope(layer->attention);
                 return layer;
             }
@@ -249,6 +291,20 @@ namespace chatllm::bailing::moe2
             w_ctx_.check_used_mem_size(true);
         }
 
+        SelfAttention *get_attn_of_layer(int layer_index)
+        {
+            if (is_layer_moe(layer_index))
+            {
+                auto layer = dynamic_cast<BailingMoEBlock *>(transformer->get_layer(layer_index));
+                return &layer->attention;
+            }
+            else
+            {
+                auto layer = dynamic_cast<BailingDenseBlock *>(transformer->get_layer(layer_index));
+                return &layer->attention;
+            }
+        }
+
         void load(ModelLoader &loader) override
         {
             loader.add_tensor_name_translations({
@@ -287,11 +343,21 @@ namespace chatllm::bailing::llada
     typedef moe2::Config Config;
     typedef moe2::Tokenizer Tokenizer;
 
-    class ConditionalGeneration : public moe2::ConditionalGeneration
+    class Prelude
+    {
+    public:
+        Prelude()
+        {
+            moe2::AttnParams::custom_mask = true;
+        }
+    };
+
+    class ConditionalGeneration : public Prelude, public moe2::ConditionalGeneration
     {
     public:
         ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type = (ModelType)MODEL_TYPE_LLADA2):
-            moe2::ConditionalGeneration(config, runtime_config, type, false)
+            Prelude(),
+            moe2::ConditionalGeneration(config, runtime_config, type)
         {
             final_steps = dynamic_cast<LMFinalSteps *>(transformer->get_final_steps());
         }
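`Prelude` exists only for its constructor's side effect: C++ initializes base classes in declaration order, so `Prelude()` sets `AttnParams::custom_mask` before the `moe2::ConditionalGeneration` base constructor builds the attention layers that read the flag. A minimal reproduction of the idiom (all names here are illustrative):

    #include <cassert>

    struct Params { static bool flag; };
    bool Params::flag = false;

    struct Base
    {
        bool saw_flag;
        Base() : saw_flag(Params::flag) {}   // reads the flag during construction
    };

    struct Prelude
    {
        Prelude() { Params::flag = true; }   // side effect only
    };

    // Bases are initialized in declaration order: Prelude first, then Base.
    struct Derived : Prelude, Base
    {
    };

    int main()
    {
        Derived d;
        assert(d.saw_flag);   // true: the flag was already set when Base ran
        return 0;
    }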
@@ -304,13 +370,13 @@ namespace chatllm::bailing::llada
                                        BaseStreamer *streamer = nullptr) override;
     protected:
         bool generate_next_block(const int *input_ids, const int ids_count, const GenerationConfig &gen_config,
-                                 std::vector<float> *logits_output,
-                                 std::vector<int> *logits_orderring, const int batch_size = 1);
+                                 std::vector<float> *logits_output, const int batch_size = 1);
         bool run_model(const int *input_ids, const int ids_count,
                        const GenerationConfig &gen_config,
                        int past,
-                       std::vector<float> *logits_output,
-                       std::vector<int> *logits_orderring, const int batch_size);
+                       std::vector<float> *logits_output, const int batch_size);
+
+        void update_mask(int past, int qlen);
     public:
         int block_length = 32;
         int steps = 32;
@@ -330,14 +396,57 @@ namespace chatllm::bailing::llada
         threshold = (float)utils::get_opt(args, "threshold", threshold);
         final_steps->set_read_last_n(block_length);
         steps = std::min(steps, block_length);
+
+
+    }
+
+    void ConditionalGeneration::update_mask(int past, int qlen)
+    {
+        CHATLLM_CHECK((past % block_length) == 0);
+        CHATLLM_CHECK((qlen % block_length) == 0);
+
+        std::vector<float> mask;
+        const int col_num = qlen + past;
+        mask.resize(qlen * col_num);
+        for (int i = 0; i < qlen / block_length; i++)
+        {
+            for (int j = 0; j < col_num / block_length; j++)
+            {
+                const float v = i + (past / block_length) >= j ? 0.0f : -INFINITY;
+                for (int ii = 0; ii < block_length; ii++)
+                {
+                    const int row_index = i * block_length + ii;
+                    for (int jj = 0; jj < block_length; jj++)
+                    {
+                        const int col_index = j * block_length + jj;
+                        const int index = row_index * col_num + col_index;
+                        mask[index] = v;
+                    }
+                }
+            }
+        }
+
+        std::vector<uint8_t> buf;
+        buf.resize(ggml::element_size(get_attn_of_layer(0)->mask) * mask.size());
+        ggml::from_float(ggml::type_of(get_attn_of_layer(0)->mask), mask.data(), buf.data(), mask.size(), 1);
+
+        for (int i = 0; i < config.num_hidden_layers; i++)
+        {
+            auto m = get_attn_of_layer(i)->mask;
+            ggml::set_dim(m, 0, col_num);
+            ggml::set_dim(m, 1, qlen);
+            Backend::write_tensor_data(m, buf.data(), 0, buf.size());
+        }
     }
 
     bool ConditionalGeneration::run_model(const int *input_ids, const int ids_count,
                                           const GenerationConfig &gen_config,
                                           int past,
-                                          std::vector<float> *logits_output,
-                                          std::vector<int> *logits_orderring, const int batch_size)
+                                          std::vector<float> *logits_output, const int batch_size)
     {
+        CHATLLM_CHECK(batch_size == 1);
+        CHATLLM_CHECK((ids_count % block_length) == 0);
+
         if (!initial_run)
         {
             initial_run = true;
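`update_mask` fills a block-causal pattern: query block `i` (shifted by the `past / block_length` already-decoded blocks) may attend to key block `j` iff `i + past / block_length >= j`, so every token sees its whole current block bidirectionally plus all earlier blocks — the attention layout used by block-diffusion decoding. A self-contained sketch that prints the pattern for two 2-token blocks with no past (sizes shrunk for readability):

    #include <cstdio>

    int main()
    {
        const int block_length = 2, past = 0, qlen = 4;  // tiny sizes for display
        const int col_num = qlen + past;

        for (int row = 0; row < qlen; row++)
        {
            for (int col = 0; col < col_num; col++)
            {
                const int i = row / block_length;   // query block index
                const int j = col / block_length;   // key block index
                // '#' = attend (0.0f in the real mask), '.' = blocked (-INFINITY)
                putchar(i + past / block_length >= j ? '#' : '.');
            }
            putchar('\n');
        }
        // Output:
        // ##..
        // ##..
        // ####
        // ####
        return 0;
    }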
@@ -349,16 +458,18 @@ namespace chatllm::bailing::llada
 
         ForwardContext ctx(&backend_context);
         ctx.user_options = w_ctx_.user_options;
+        LMFinalStepsDisabler disabler(final_steps, logits_output == nullptr);
 
         ctx.gctx = GGMLContext({.mem_size = backend_context.buf_compute_meta.size(), .mem_buffer = backend_context.buf_compute_meta.data(), .no_alloc = true});
         ctx.gf = ggml::new_graph_custom(&ctx, GRAPH_SIZE, false);
 
         set_dbg_ctx(&ctx);
 
+        update_mask(past, ids_count);
+
         ctx.move_to_layer(LayerAllocatorManager::MiscLayer::Prolog);
         ggml::tensor *input_ids_tensor = ggml::new_tensor_2d(&ctx, GGML_TYPE_I32, ids_count, batch_size);
 
-        final_steps->set_do_orderring(logits_orderring != nullptr);
         ggml::tensor *r = transformer->forward(&ctx, input_ids_tensor, past);
 
         ctx.move_to_layer(LayerAllocatorManager::MiscLayer::Epilog);
@@ -372,8 +483,6 @@ namespace chatllm::bailing::llada
 
         if (logits_output)
             logits_output->resize(ggml::nelements(r));
-        if (logits_orderring)
-            logits_orderring->resize(ggml::nelements(final_steps->get_orderring_result()));
 
         if (!ctx.allocate()) return false;
 
@@ -389,17 +498,14 @@ namespace chatllm::bailing::llada
 
         if (logits_output)
             Backend::read_tensor_data(r, logits_output->data());
-        if (logits_orderring)
-            Backend::read_tensor_data(final_steps->get_orderring_result(), logits_orderring->data());
 
         ctx.reset();
 
         return true;
     }
 
     bool ConditionalGeneration::generate_next_block(const int *input_ids, const int ids_count, const GenerationConfig &gen_config,
-                                                    std::vector<float> *logits_output,
-                                                    std::vector<int> *logits_orderring, const int batch_size)
+                                                    std::vector<float> *logits_output, const int batch_size)
     {
         int batch = batch_input > 1 ? batch_input : 1;
         batch = (batch / block_length) * block_length;
@@ -411,11 +517,11 @@ namespace chatllm::bailing::llada
 
         for (; (remain > batch) && !aborted; p += batch, remain -= batch, past += batch)
        {
-            if (!run_model(p, batch, gen_config, past, nullptr, nullptr, 1))
+            if (!run_model(p, batch, gen_config, past, nullptr, batch_size))
                 return false;
         }
 
-        return run_model(p, remain, gen_config, past, logits_output, logits_orderring, 1);
+        return run_model(p, remain, gen_config, past, logits_output, batch_size);
     }
 
     std::vector<int> ConditionalGeneration::generate(const std::vector<int> &input_ids, const GenerationConfig &gen_config,
@@ -506,7 +612,7 @@ namespace chatllm::bailing::llada
         const int prefill_len = prefill_block_num * block_length;
         if (prefill_len > 0)
         {
-            generate_next_block(curr_input_ids.data(), prefill_len, gen_config, nullptr, nullptr);
+            generate_next_block(curr_input_ids.data(), prefill_len, gen_config, nullptr);
             n_past += prefill_len;
             curr_input_ids.erase(curr_input_ids.begin(), curr_input_ids.begin() + prefill_len);
 
@@ -526,7 +632,7 @@ namespace chatllm::bailing::llada
         {
             // Note: we have to run a whole block again and again.
             std::vector<float> lm_logits;
-            generate_next_block(block_result.data(), block_length, gen_config, &lm_logits, nullptr);
+            generate_next_block(block_result.data(), block_length, gen_config, &lm_logits);
 
             struct candidate
             {
@@ -591,8 +697,10 @@ namespace chatllm::bailing::llada
 
         // block is now finalized
         if (next_pos_to_add == block_length)
-            generate_next_block(block_result.data(), next_pos_to_add, gen_config, nullptr, nullptr);
-        n_past += next_pos_to_add;
+        {
+            generate_next_block(block_result.data(), next_pos_to_add, gen_config, nullptr);
+            n_past += next_pos_to_add;
+        }
 
         if (performance)
             performance->Accumulate(ModelPerfInfo::Type::Generation, block_length - block_prefilled_size);

src/models.cpp

Lines changed: 2 additions & 0 deletions
@@ -1477,6 +1477,8 @@ namespace chatllm
         const int last_n = qlen >= this->last_n ? this->last_n : qlen;
         order = nullptr;
 
+        if (disable_head) return hidden_states;
+
         hidden_states = ggml::view_3d(ctx, hidden_states, model->hidden_size, last_n, batch,
                                       ggml::row_size(hidden_states),
                                       ggml::row_size(hidden_states) * qlen,

src/models.h

Lines changed: 21 additions & 0 deletions
@@ -99,9 +99,12 @@ namespace chatllm
         std::unique_ptr<ModelFinalSteps> final_steps;
     };
 
+    class LMFinalStepsDisabler;
+
     class LMFinalSteps : public ModelFinalSteps
     {
     public:
+        friend LMFinalStepsDisabler;
         ggml::tensor *forward(HeterogeneousModel *model, ComputeContext *ctx, ggml::tensor *input_ids, ggml::tensor *hidden_states) override;
         void set_read_last_n(int n);
         void set_do_orderring(bool flag); // descending
@@ -110,6 +113,24 @@ namespace chatllm
         bool do_orderring = false;
         int last_n = 1;
         ggml::tensor *order = nullptr;
+        bool disable_head = false;
+    };
+
+    class LMFinalStepsDisabler
+    {
+    public:
+        LMFinalStepsDisabler(LMFinalSteps *target, bool active = true) : target(target), state(target->disable_head)
+        {
+            if (active)
+                target->disable_head = true;
+        }
+        ~LMFinalStepsDisabler()
+        {
+            target->disable_head = state;
+        }
+    private:
+        LMFinalSteps *target;
+        bool state;
     };
 
     class EmbeddingPoolingFinalSteps : public ModelFinalSteps
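`LMFinalStepsDisabler` is a scope guard: the constructor saves the current `disable_head` value and sets it when active, and the destructor restores the saved value on every exit path (see its use in `run_model` above, where the lm_head is skipped for prefill-only passes). The same save-set-restore idiom in generic form (a sketch, not the project's class):

    // Generic RAII flag guard: set a bool for the lifetime of a scope and
    // restore the previous value on destruction, even on early return.
    class FlagGuard
    {
    public:
        FlagGuard(bool &flag, bool active) : flag(flag), saved(flag)
        {
            if (active)
                flag = true;
        }
        ~FlagGuard() { flag = saved; }
    private:
        bool &flag;
        bool saved;
    };

    bool disable_head = false;   // the flag the guard manages

    bool forward_pass(bool need_logits)
    {
        FlagGuard guard(disable_head, /* active = */ !need_logits);
        // ... graph construction consults disable_head here ...
        return true;
    }   // disable_head restored automatically when guard goes out of scope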
