@@ -248,31 +248,70 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv     = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;

-    GGML_ASSERT(kq_mask);
-    GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-
-    float * data = (float *) kq_mask->data;
-
-    for (int h = 0; h < 1; ++h) {
-        for (int i1 = 0; i1 < n_tokens; ++i1) {
-            const llama_seq_id s1 = ubatch->seq_id[i1][0];
-
-            for (int i0 = 0; i0 < n_tokens; ++i0) {
-                float f = -INFINITY;
-
-                for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
-                    const llama_seq_id s0 = ubatch->seq_id[i0][0];
-
-                    // TODO: reimplement this like in llama_kv_cache_unified
-                    if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
-                        if (hparams.use_alibi) {
-                            f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
-                        } else {
-                            f = 0.0f;
+            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
+            float * data = (float *) kq_mask->data;
+
+            for (int h = 0; h < 1; ++h) {
+                for (int s1 = 0; s1 < n_seqs; ++s1) {
+                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
+
+                    for (int j = 0; j < n_seq_tokens; ++j) {
+                        const int32_t tj = s1*n_seq_tokens + j;
+
+                        for (int s0 = 0; s0 < n_seqs; ++s0) {
+                            for (int i = 0; i < n_seq_tokens; ++i) {
+                                const int32_t ti = s0*n_seq_tokens + i;
+                                float f = -INFINITY;
+
+                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
+                                    if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
+                                        if (hparams.use_alibi) {
+                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
+                                        } else {
+                                            f = 0.0f;
+                                        }
+                                        break;
+                                    }
+                                }
+
+                                data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
+                            }
+                        }
                         }
-                        break;
                     }
                 }
+            }
+        } else {
+            const int64_t n_tokens     = ubatch->n_tokens;
+            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+            const int64_t n_seqs       = ubatch->n_seqs;
+            const int64_t n_stride     = ubatch->n_tokens;
+
+            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
+
+            float * data = (float *) kq_mask->data;
+
+            for (int h = 0; h < 1; ++h) {
+                for (int s1 = 0; s1 < n_seqs; ++s1) {
+                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
+
+                    for (int j = 0; j < n_seq_tokens; ++j) {
+                        const int32_t tj = s1*n_seq_tokens + j;
+
+                        for (int s0 = 0; s0 < n_seqs; ++s0) {
+                            for (int i = 0; i < n_seq_tokens; ++i) {
+                                const int32_t ti = s0*n_seq_tokens + i;
+                                float f = -INFINITY;
+
+                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
+                                    if (ubatch->seq_id[s0][s] == seq_id) {
+                                        if (hparams.use_alibi) {
+                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
+                                        } else {
+                                            f = 0.0f;
+                                        }
+                                        break;
+                                    }
+                                }

                 data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
             }
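
Both sides of this hunk implement the same masking rule: a query token tj may attend to a key token ti only if the two tokens share a sequence id and, in the causal case, pos[ti] <= pos[tj]; allowed pairs get 0.0f (or -|pos[ti] - pos[tj]| when ALiBi is used) and everything else stays at -INFINITY, stored row-major as data[h*(n_kv*n_tokens) + tj*n_kv + ti]. A minimal standalone sketch of that rule, with the ubatch fields flattened into plain vectors (illustrative names, not the llama.cpp API):

#include <cstdio>
#include <cstdlib>
#include <limits>
#include <vector>

// Illustrative sketch: fill a row-major [n_tokens x n_kv] mask the same way
// the loops in the diff above do, for a toy batch where n_kv == n_tokens.
int main() {
    const std::vector<int> seq_id = {0, 0, 1, 1}; // one sequence id per token (assumed)
    const std::vector<int> pos    = {0, 1, 0, 1}; // position of each token in its sequence
    const bool causal    = true;
    const bool use_alibi = false;

    const int n_tokens = (int) seq_id.size();
    const int n_kv     = n_tokens;                // no KV cache: the keys are the batch itself

    std::vector<float> mask(n_kv * n_tokens, -std::numeric_limits<float>::infinity());

    for (int tj = 0; tj < n_tokens; ++tj) {       // query token
        for (int ti = 0; ti < n_kv; ++ti) {       // key token
            const bool same_seq = seq_id[ti] == seq_id[tj];
            const bool visible  = !causal || pos[ti] <= pos[tj];
            if (same_seq && visible) {
                mask[tj*n_kv + ti] = use_alibi ? -std::abs(pos[ti] - pos[tj]) : 0.0f;
            }
        }
    }

    // rows are queries, columns are keys; -inf means "masked out"
    for (int tj = 0; tj < n_tokens; ++tj) {
        for (int ti = 0; ti < n_kv; ++ti) {
            std::printf("%6.1f ", mask[tj*n_kv + ti]);
        }
        std::printf("\n");
    }
    return 0;
}

For this toy batch the printout is a block-diagonal causal mask: each sequence only sees its own earlier tokens.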
@@ -600,24 +639,23 @@ ggml_tensor * llm_graph_context::build_ffn(
             } break;
         case LLM_FFN_SWIGLU:
             {
-                cur = ggml_swiglu(ctx0, cur);
-                cb(cur, "ffn_swiglu", il);
-            } break;
-        case LLM_FFN_GEGLU:
-            {
-                cur = ggml_geglu(ctx0, cur);
-                cb(cur, "ffn_geglu", il);
-            } break;
-        case LLM_FFN_REGLU:
-            {
-                cur = ggml_reglu(ctx0, cur);
-                cb(cur, "ffn_reglu", il);
+                // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
+                int64_t split_point = cur->ne[0] / 2;
+                // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
+                ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
+                ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
+
+                x0 = ggml_silu(ctx0, x0);
+                cb(cur, "ffn_silu", il);
+
+                cur = ggml_mul(ctx0, x0, x1);
+                cb(cur, "ffn_mul", il);
             } break;
         case LLM_FFN_GEGLU:
             {
                 // Split into two equal parts
                 int64_t split_point = cur->ne[0] / 2;
-                // TODO: these conts should not be needed
+                // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
                 ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
                 ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));

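
The replacement SWIGLU branch computes SwiGLU by hand: the up-projection has double width, it is split at ne[0]/2 into x0 and x1, SiLU is applied to x0, and the result is the elementwise product x0*x1 (see the GLU-variants paper linked in the comment). A standalone sketch of the same computation on a plain vector, with hypothetical helper names rather than ggml calls:

#include <cmath>
#include <vector>

// SiLU (a.k.a. swish): x * sigmoid(x)
static float silu(float x) {
    return x / (1.0f + std::exp(-x));
}

// Illustrative SwiGLU on a single row: `up` holds the doubled-width projection,
// laid out as [x0 | x1]; the result has half the width of `up`.
static std::vector<float> swiglu(const std::vector<float> & up) {
    const size_t split_point = up.size() / 2;
    std::vector<float> out(split_point);
    for (size_t i = 0; i < split_point; ++i) {
        out[i] = silu(up[i]) * up[split_point + i]; // silu(x0) * x1
    }
    return out;
}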
@@ -1300,15 +1338,12 @@ ggml_tensor * llm_graph_context::build_attn(

     const bool is_swa = hparams.is_swa(il);

-    const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
-
-    // optionally store to KV cache
-    if (k_cur) {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-    }
+    const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();

-    if (v_cur) {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
     }

     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
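
In both versions of this hunk, is_swa only selects which KV cache and which KQ mask a sliding-window layer uses; the window restriction itself is expressed through the mask. Conceptually, an SWA mask adds one more condition on top of the causal check: the key must lie within the last n_swa positions of the query. A small sketch of that visibility test (the n_swa width and the exact boundary convention are assumptions for illustration, not the llama.cpp implementation):

#include <cstdint>

// Illustrative visibility test: causal attention, optionally restricted to a
// sliding window of the last n_swa positions (n_swa <= 0 means "no window").
static bool is_visible(int64_t pos_q, int64_t pos_k, int64_t n_swa) {
    if (pos_k > pos_q) {
        return false; // causal: no attending to the future
    }
    if (n_swa > 0 && pos_q - pos_k >= n_swa) {
        return false; // outside the sliding window
    }
    return true;
}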
@@ -1390,121 +1425,30 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }

-ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_mem_hybrid * inp,
-        ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
-        ggml_tensor * q_cur,
-        ggml_tensor * k_cur,
-        ggml_tensor * v_cur,
-        ggml_tensor * kq_b,
-        ggml_tensor * v_mla,
-        float kq_scale,
-        int il) const {
-    // these nodes are added to the graph together so that they are not reordered
-    // by doing so, the number of splits in the graph is reduced
-    ggml_build_forward_expand(gf, q_cur);
-    ggml_build_forward_expand(gf, k_cur);
-    ggml_build_forward_expand(gf, v_cur);
-
-    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
-
-    // store to KV cache
-    {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
-    }
-
-    const auto & kq_mask = inp->get_kq_mask();
-
-    ggml_tensor * q = q_cur;
-    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
-    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
-    cb(cur, "kqv_out", il);
-
-    if (wo) {
-        cur = build_lora_mm(wo, cur);
-        if (arch == LLM_ARCH_GLM4) {
-            // GLM4 seems to have numerical issues with half-precision accumulators
-            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
-        }
-    }
-
-    if (wo_b) {
-        cur = ggml_add(ctx0, cur, wo_b);
-    }
-
-    return cur;
-}
-
-llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
-
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
-
-    {
-        const auto n_kv = mctx_cur->get_base()->get_n_kv();
-
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
-        ggml_set_input(inp->self_kq_mask);
-
-        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
-    }
-
-    {
-        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
-
-        const auto n_kv = mctx_cur->get_swa()->get_n_kv();
-
-        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
-        ggml_set_input(inp->self_kq_mask_swa);
-
-        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
-    }
-
-    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
-}
-
-ggml_tensor * llm_graph_context::build_rs(
-        ggml_cgraph * gf,
-        ggml_tensor * s,
-        ggml_tensor * state_copy,
-        int32_t state_size,
-        int32_t n_seqs,
-        uint32_t n_kv,
-        uint32_t kv_head,
-        uint32_t kv_size,
-        int32_t rs_zero,
-        bool avoid_copies) const {
+ggml_tensor * llm_graph_context::build_copy_mask_state(
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        ggml_tensor * state_copy,
+        ggml_tensor * state_mask,
+        int32_t n_state,
+        int32_t n_seqs) const {
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);

-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
+    const auto n_kv    = kv_state->get_n_kv();
+    const auto kv_head = kv_state->get_head();

-    // Clear a single state which will then be copied to the other cleared states.
-    // Note that this is a no-op when the view is zero-sized.
-    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
-    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());

-    ggml_tensor * output_states;
+    // copy states
+    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+    // this shrinks the tensors's ne[1] to n_kv
+    states = ggml_get_rows(ctx0, states, state_copy);

-    if (!avoid_copies) {
-        // copy states
-        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-        // {state_size, kv_size} -> {state_size, n_seqs}
-        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
-        ggml_build_forward_expand(gf, output_states);
-    } else {
-        // FIXME: make the gathering operation happen before the copy below
-        //        (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
-        output_states = states;
-    }
+    // clear states of sequences which are starting at the beginning of this batch
+    // FIXME: zero-out NANs?
+    states = ggml_mul(ctx0, states, state_mask);

-    // copy extra states which won't be changed further (between n_seqs and n_kv)
-    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
+    // copy states which won't be changed further (between n_seqs and n_kv)
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
             states_extra,
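
Both versions of this function perform the same two data movements on the recurrent-state buffer: gather the rows selected by a copy index, and zero the rows of sequences that start fresh in this batch (the removed build_rs does the clearing through a dedicated rs_zero slot, the added build_copy_mask_state through an explicit state_mask multiply). A standalone sketch of the gather-then-mask idea on plain buffers, with illustrative names rather than the ggml graph API:

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative sketch: gather recurrent states by row index, then zero the
// rows whose sequences start fresh in this batch (state_mask[r] == 0.0f).
// `states` is row-major: kv_size rows of n_state floats each.
static std::vector<float> copy_mask_states(
        const std::vector<float>   & states,
        const std::vector<int32_t> & state_copy, // source row for each destination row
        const std::vector<float>   & state_mask, // 1.0f = keep, 0.0f = clear
        int32_t n_state) {
    std::vector<float> out(state_copy.size() * n_state);
    for (size_t r = 0; r < state_copy.size(); ++r) {
        const float * src = &states[state_copy[r] * n_state]; // "get_rows"
        for (int32_t c = 0; c < n_state; ++c) {
            out[r*n_state + c] = src[c] * state_mask[r];       // "mul by mask"
        }
    }
    return out;
}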