
Commit d4d465b

graph : support cacheless embeddings with FA and iSWA

1 parent 41aac5c · commit d4d465b

4 files changed, +105 −48 lines


src/llama-graph.cpp

Lines changed: 95 additions & 42 deletions
@@ -261,12 +261,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }

-static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
     LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
-    const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE)      ? "LLAMA_SWA_TYPE_NONE" :
-                                (swa_type == LLAMA_SWA_TYPE_STANDARD)  ? "LLAMA_SWA_TYPE_STANDARD" :
-                                (swa_type == LLAMA_SWA_TYPE_CHUNKED)   ? "LLAMA_SWA_TYPE_CHUNKED" :
-                                (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
+    const char * swa_type_str = "unknown";
+
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE:      swa_type_str = "LLAMA_SWA_TYPE_NONE";      break;
+        case LLAMA_SWA_TYPE_STANDARD:  swa_type_str = "LLAMA_SWA_TYPE_STANDARD";  break;
+        case LLAMA_SWA_TYPE_CHUNKED:   swa_type_str = "LLAMA_SWA_TYPE_CHUNKED";   break;
+        case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
+    };
+
     LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
     LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
     LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
@@ -295,50 +300,88 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv     = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;

-    GGML_ASSERT(kq_mask);
-    GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
+    {
+        GGML_ASSERT(self_kq_mask);
+        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));

-    float * data = (float *) kq_mask->data;
+        float * data = (float *) self_kq_mask->data;

-    // [TAG_NO_CACHE_ISWA]
-    GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
+        for (int h = 0; h < 1; ++h) {
+            for (int i1 = 0; i1 < n_tokens; ++i1) {
+                const llama_seq_id s1 = ubatch->seq_id[i1][0];

-    for (int h = 0; h < 1; ++h) {
-        for (int i1 = 0; i1 < n_tokens; ++i1) {
-            const llama_seq_id s1 = ubatch->seq_id[i1][0];
+                for (int i0 = 0; i0 < n_tokens; ++i0) {
+                    float f = -INFINITY;

-            for (int i0 = 0; i0 < n_tokens; ++i0) {
-                float f = -INFINITY;
+                    for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
+                        const llama_seq_id s0 = ubatch->seq_id[i0][0];

-                for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
-                    const llama_seq_id s0 = ubatch->seq_id[i0][0];
+                        if (s0 != s1) {
+                            continue; // skip different sequences
+                        }

-                    if (s0 != s1) {
-                        continue; // skip different sequences
-                    }
+                        if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
+                            continue; // skip future tokens for causal attention
+                        }

-                    if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
-                        continue; // skip future tokens for causal attention
+                        // TODO: reimplement this like in llama_kv_cache_unified
+                        if (hparams.use_alibi) {
+                            f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
+                        } else {
+                            f = 0.0f;
+                        }
                     }
+                    data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
+                }
+            }
+        }
+        if (debug) {
+            print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+        }
+    }

-                    // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
-                    //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
-                    //    continue; // skip masked tokens for SWA
-                    //}
+    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+        GGML_ASSERT(self_kq_mask_swa);
+        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));

-                    // TODO: reimplement this like in llama_kv_cache_unified
-                    if (hparams.use_alibi) {
-                        f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
-                    } else {
-                        f = 0.0f;
+        float * data = (float *) self_kq_mask_swa->data;
+
+        for (int h = 0; h < 1; ++h) {
+            for (int i1 = 0; i1 < n_tokens; ++i1) {
+                const llama_seq_id s1 = ubatch->seq_id[i1][0];
+
+                for (int i0 = 0; i0 < n_tokens; ++i0) {
+                    float f = -INFINITY;
+
+                    for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
+                        const llama_seq_id s0 = ubatch->seq_id[i0][0];
+
+                        if (s0 != s1) {
+                            continue; // skip different sequences
+                        }
+
+                        if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
+                            continue; // skip future tokens for causal attention
+                        }
+
+                        if (llama_hparams::is_masked_swa(hparams.n_swa, hparams.swa_type, ubatch->pos[i0], ubatch->pos[i1])) {
+                            continue; // skip masked tokens for SWA
+                        }
+
+                        // TODO: reimplement this like in llama_kv_cache_unified
+                        if (hparams.use_alibi) {
+                            f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
+                        } else {
+                            f = 0.0f;
+                        }
                     }
+                    data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
                 }
-                data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
             }
         }
-    }
-    if (debug) {
-        print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+        if (debug) {
+            print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+        }
     }
 }
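The reworked set_input fills the regular cacheless mask exactly as before and, for models that use sliding-window attention, a second mask that additionally applies the window predicate via llama_hparams::is_masked_swa. The standalone sketch below illustrates that second mask, assuming the standard window rule (a key at position p0 is masked for a query at p1 once p1 - p0 >= n_swa); the real predicate also handles the chunked and symmetric variants. It uses the same flat indexing as data[h*(n_kv*n_tokens) + i1*n_kv + i0] above.

// Standalone illustration (not the llama.cpp implementation) of a sliding-window
// mask for the cacheless case, where n_kv == n_tokens. The window rule below is
// an assumption standing in for llama_hparams::is_masked_swa.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// assumed LLAMA_SWA_TYPE_STANDARD rule: keys further than n_swa positions back are masked
static bool is_masked_swa_standard(int32_t n_swa, int32_t p0, int32_t p1) {
    return p1 - p0 >= n_swa;
}

int main() {
    const int n_tokens = 6;
    const int n_kv     = n_tokens; // no KV cache: the keys are the batch tokens themselves
    const int n_swa    = 3;

    std::vector<float> mask(n_kv*n_tokens, -INFINITY);

    for (int i1 = 0; i1 < n_tokens; ++i1) {     // query token (row)
        for (int i0 = 0; i0 < n_tokens; ++i0) { // key token (column)
            if (i0 > i1) {
                continue; // causal: no attending to future tokens
            }
            if (is_masked_swa_standard(n_swa, i0, i1)) {
                continue; // outside the sliding window
            }
            mask[i1*n_kv + i0] = 0.0f; // same flat layout as data[h*(n_kv*n_tokens) + i1*n_kv + i0]
        }
    }

    for (int i1 = 0; i1 < n_tokens; ++i1) {
        for (int i0 = 0; i0 < n_kv; ++i0) {
            printf("%s", mask[i1*n_kv + i0] == 0.0f ? " 0" : " ∞");
        }
        printf("\n");
    }
    return 0;
}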

@@ -1299,12 +1342,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);

-    const auto n_kv = k->ne[1];
-
     ggml_tensor * cur;

     // TODO: replace hardcoded padding with ggml-provided padding
-    if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
+    if (cparams.flash_attn && kq_b == nullptr) {
         GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");

         if (v_trans) {
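The flash-attention gate in build_attn_mha no longer requires the KV length to be a multiple of 256, so cacheless batches of arbitrary size can take the FA path (the no-KQ-bias requirement stays). The small program below just evaluates the old and new conditions from the diff side by side:

// Compare the removed gate with the current one for a few KV lengths.
#include <cstdio>

int main() {
    const bool flash_attn = true;   // stands in for cparams.flash_attn
    const bool has_kq_b   = false;  // stands in for kq_b != nullptr

    for (int n_kv : {128, 256, 300, 512, 777}) {
        const bool old_fa = flash_attn && (n_kv % 256 == 0) && !has_kq_b; // removed condition
        const bool new_fa = flash_attn && !has_kq_b;                      // current condition
        printf("n_kv = %3d : old FA path = %-3s new FA path = %s\n",
               n_kv, old_fa ? "yes" : "no", new_fa ? "yes" : "no");
    }
    return 0;
}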
@@ -1419,10 +1460,20 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);

     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
-    ggml_set_input(inp->kq_mask);
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+    ggml_set_input(inp->self_kq_mask);

-    inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
+    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+
+    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    } else {
+        inp->self_kq_mask_swa     = nullptr;
+        inp->self_kq_mask_swa_cnv = nullptr;
+    }

     return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
 }
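Both mask tensors are allocated with their second dimension rounded up via GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), and both are cast to F16 when flash attention is enabled. A quick sketch of that rounding; the helper below mirrors GGML_PAD's effect (round up to a multiple of the pad size), and the concrete pad value is an assumption used for illustration only:

// Illustrate the padded second dimension of the mask tensors above.
#include <cstdio>

// equivalent rounding to ggml's GGML_PAD macro (round x up to a multiple of n)
static int pad_to(int x, int n) {
    return (x + n - 1) / n * n;
}

int main() {
    const int kq_mask_pad = 64; // assumed value of GGML_KQ_MASK_PAD, for illustration only

    for (int n_tokens : {1, 63, 64, 65, 200}) {
        // both masks are allocated as [n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1]
        printf("n_tokens = %3d -> padded second dim = %3d\n",
               n_tokens, pad_to(n_tokens, kq_mask_pad));
    }
    return 0;
}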
@@ -1447,7 +1498,9 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);

-    const auto & kq_mask = inp->get_kq_mask();
+    const bool is_swa = hparams.is_swa(il);
+
+    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();

     // [TAG_NO_CACHE_PAD]
     // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
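build_attn now chooses the mask per layer: layers reported as SWA by hparams.is_swa(il) read the windowed mask, the others read the full mask. A toy dispatch under an assumed interleaving pattern (the real pattern is model-specific and comes entirely from hparams.is_swa):

// Toy per-layer mask selection; the 5:1 interleaving is an assumption for illustration.
#include <cstdio>

static bool is_swa_layer(int il) {
    return (il % 6) != 5; // assumed: five SWA layers, then one full-attention layer
}

int main() {
    for (int il = 0; il < 12; ++il) {
        const bool is_swa = is_swa_layer(il);
        printf("layer %2d uses %s\n",
               il, is_swa ? "self_kq_mask_swa_cnv (windowed)" : "self_kq_mask_cnv (full)");
    }
    return 0;
}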

src/llama-graph.h

Lines changed: 7 additions & 3 deletions
@@ -257,10 +257,14 @@ class llm_graph_input_attn_no_cache : public llm_graph_input_i {

     void set_input(const llama_ubatch * ubatch) override;

-    ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
+    ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
+    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }

-    ggml_tensor * kq_mask     = nullptr; // F32 [n_tokens, n_batch, 1, 1]
-    ggml_tensor * kq_mask_cnv = nullptr; //     [n_tokens, n_batch, 1, 1]
+    // n_tokens == n_batch
+    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]

     const llama_hparams hparams;
     const llama_cparams cparams;

src/llama-model.cpp

Lines changed: 2 additions & 3 deletions
@@ -11376,8 +11376,7 @@ struct llm_build_gemma_embedding_iswa : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();

-        // TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA]
-        auto * inp_attn = build_attn_inp_kv_iswa();
+        auto * inp_attn = build_attn_inp_no_cache();

         ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -19378,7 +19377,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_WAVTOKENIZER_DEC:
-        //case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA]
+        case LLM_ARCH_GEMMA_EMBEDDING:
         case LLM_ARCH_DREAM:
         case LLM_ARCH_LLADA:
         case LLM_ARCH_LLADA_MOE:
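With the Gemma embedding graph switched to build_attn_inp_no_cache and LLM_ARCH_GEMMA_EMBEDDING restored to the memory-less list, embeddings for these models run without allocating a KV cache. The sketch below shows one plausible way to exercise that path through the public llama.h API; the model path is a placeholder, and the exact context-parameter fields (for example how flash attention is toggled) should be checked against the llama.h shipped with this commit:

// Sketch: compute a pooled embedding without a KV cache (cacheless graph path).
#include <cstdio>
#include <cstring>
#include <vector>

#include "llama.h"

int main() {
    llama_backend_init();

    const char * model_path = "gemma-embedding.gguf"; // placeholder model file

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file(model_path, mparams);
    if (!model) {
        fprintf(stderr, "failed to load %s\n", model_path);
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                    // request embedding output
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN; // sequence-level pooling
    // flash attention is controlled via the context params as well; the field has
    // changed across llama.h revisions, so it is left at its default here

    llama_context * ctx = llama_init_from_model(model, cparams);
    if (!ctx) {
        fprintf(stderr, "failed to create context\n");
        return 1;
    }

    const llama_vocab * vocab = llama_model_get_vocab(model);
    const char * text = "hello world";

    std::vector<llama_token> tokens(512);
    const int n_tok = llama_tokenize(vocab, text, (int) strlen(text),
                                     tokens.data(), (int) tokens.size(),
                                     /*add_special*/ true, /*parse_special*/ false);
    if (n_tok <= 0) {
        fprintf(stderr, "tokenization failed\n");
        return 1;
    }
    tokens.resize(n_tok);

    llama_batch batch = llama_batch_get_one(tokens.data(), (int) tokens.size());
    if (llama_encode(ctx, batch) != 0) {
        fprintf(stderr, "encode failed\n");
        return 1;
    }

    const float * emb = llama_get_embeddings_seq(ctx, 0);
    const int n_embd  = llama_model_n_embd(model);
    printf("got %d-dimensional embedding, first value = %f\n", n_embd, emb ? emb[0] : 0.0f);

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}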

src/llama.cpp

Lines changed: 1 addition & 0 deletions
@@ -312,6 +312,7 @@ struct llama_model * llama_model_load_from_splits(
         LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__);
         return nullptr;
     }
+    splits.reserve(n_paths);
     for (size_t i = 0; i < n_paths; ++i) {
         splits.push_back(paths[i]);
     }
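The only change here is a reserve() before the push_back loop, which avoids repeated reallocation while the split paths are copied. A minimal standalone version of the pattern (file names are hypothetical):

// reserve() up front, then append: a single allocation instead of repeated growth.
#include <cstddef>
#include <string>
#include <vector>

int main() {
    const char * paths[] = { "model-00001-of-00003.gguf",
                             "model-00002-of-00003.gguf",
                             "model-00003-of-00003.gguf" }; // hypothetical split names
    const size_t n_paths = sizeof(paths)/sizeof(paths[0]);

    std::vector<std::string> splits;
    splits.reserve(n_paths);            // one allocation up front
    for (size_t i = 0; i < n_paths; ++i) {
        splits.push_back(paths[i]);
    }
    return splits.size() == n_paths ? 0 : 1;
}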
