@@ -261,12 +261,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
     LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
-    const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE)      ? "LLAMA_SWA_TYPE_NONE" :
-                                (swa_type == LLAMA_SWA_TYPE_STANDARD)  ? "LLAMA_SWA_TYPE_STANDARD" :
-                                (swa_type == LLAMA_SWA_TYPE_CHUNKED)   ? "LLAMA_SWA_TYPE_CHUNKED" :
-                                (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
+    const char * swa_type_str = "unknown";
+
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE:      swa_type_str = "LLAMA_SWA_TYPE_NONE";      break;
+        case LLAMA_SWA_TYPE_STANDARD:  swa_type_str = "LLAMA_SWA_TYPE_STANDARD";  break;
+        case LLAMA_SWA_TYPE_CHUNKED:   swa_type_str = "LLAMA_SWA_TYPE_CHUNKED";   break;
+        case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
+    };
+
     LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
     LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
     LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
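Replacing the chained ternaries with a switch keeps the fallback in the initializer and omits a `default:` label, so compilers that warn on unhandled enum values (`-Wswitch`) will flag any future `llama_swa_type` addition. If other debug paths ever need the same mapping, it could be factored into a small helper along these lines (hypothetical name and placement, not part of this change):

```cpp
// Hypothetical helper, mirroring the switch introduced above: map a
// llama_swa_type value (from the llama.cpp headers) to its enumerator name.
static const char * llama_swa_type_name(llama_swa_type swa_type) {
    switch (swa_type) {
        case LLAMA_SWA_TYPE_NONE:      return "LLAMA_SWA_TYPE_NONE";
        case LLAMA_SWA_TYPE_STANDARD:  return "LLAMA_SWA_TYPE_STANDARD";
        case LLAMA_SWA_TYPE_CHUNKED:   return "LLAMA_SWA_TYPE_CHUNKED";
        case LLAMA_SWA_TYPE_SYMMETRIC: return "LLAMA_SWA_TYPE_SYMMETRIC";
    }
    return "unknown";
}
```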
@@ -295,50 +300,88 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv     = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;
 
-    GGML_ASSERT(kq_mask);
-    GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
+    {
+        GGML_ASSERT(self_kq_mask);
+        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
 
-    float * data = (float *) kq_mask->data;
+        float * data = (float *) self_kq_mask->data;
 
-    // [TAG_NO_CACHE_ISWA]
-    GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
+        for (int h = 0; h < 1; ++h) {
+            for (int i1 = 0; i1 < n_tokens; ++i1) {
+                const llama_seq_id s1 = ubatch->seq_id[i1][0];
 
-    for (int h = 0; h < 1; ++h) {
-        for (int i1 = 0; i1 < n_tokens; ++i1) {
-            const llama_seq_id s1 = ubatch->seq_id[i1][0];
+                for (int i0 = 0; i0 < n_tokens; ++i0) {
+                    float f = -INFINITY;
 
-            for (int i0 = 0; i0 < n_tokens; ++i0) {
-                float f = -INFINITY;
+                    for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
+                        const llama_seq_id s0 = ubatch->seq_id[i0][0];
 
-                for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
-                    const llama_seq_id s0 = ubatch->seq_id[i0][0];
+                        if (s0 != s1) {
+                            continue; // skip different sequences
+                        }
 
-                    if (s0 != s1) {
-                        continue; // skip different sequences
-                    }
+                        if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
+                            continue; // skip future tokens for causal attention
+                        }
 
-                    if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
-                        continue; // skip future tokens for causal attention
+                        // TODO: reimplement this like in llama_kv_cache_unified
+                        if (hparams.use_alibi) {
+                            f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
+                        } else {
+                            f = 0.0f;
+                        }
                     }
+                    data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
+                }
+            }
+        }
+        if (debug) {
+            print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+        }
+    }
 
-                    // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
-                    // if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
-                    //     continue; // skip masked tokens for SWA
-                    // }
+    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+        GGML_ASSERT(self_kq_mask_swa);
+        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
 
-                    // TODO: reimplement this like in llama_kv_cache_unified
-                    if (hparams.use_alibi) {
-                        f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
-                    } else {
-                        f = 0.0f;
+        float * data = (float *) self_kq_mask_swa->data;
+
+        for (int h = 0; h < 1; ++h) {
+            for (int i1 = 0; i1 < n_tokens; ++i1) {
+                const llama_seq_id s1 = ubatch->seq_id[i1][0];
+
+                for (int i0 = 0; i0 < n_tokens; ++i0) {
+                    float f = -INFINITY;
+
+                    for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
+                        const llama_seq_id s0 = ubatch->seq_id[i0][0];
+
+                        if (s0 != s1) {
+                            continue; // skip different sequences
+                        }
+
+                        if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
+                            continue; // skip future tokens for causal attention
+                        }
+
+                        if (llama_hparams::is_masked_swa(hparams.n_swa, hparams.swa_type, ubatch->pos[i0], ubatch->pos[i1])) {
+                            continue; // skip masked tokens for SWA
+                        }
+
+                        // TODO: reimplement this like in llama_kv_cache_unified
+                        if (hparams.use_alibi) {
+                            f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
+                        } else {
+                            f = 0.0f;
+                        }
                     }
+                    data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
                 }
-                data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
             }
         }
-    }
-    if (debug) {
-        print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+        if (debug) {
+            print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+        }
     }
 }
 
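Both mask buffers use the same row-major layout, `data[i1*n_kv + i0]`, with rows indexed by the query token `i1` and columns by the key/value token `i0`, which is exactly what `print_mask` visualizes. The SWA branch applies the same sequence and causality filters as the dense mask and additionally rejects pairs via `llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)`. Below is a rough sketch of the window semantics that check is expected to encode, assuming the usual standard/chunked/symmetric definitions; the actual implementation (presumably in `llama-hparams.cpp`) may differ in edge cases such as boundary inclusivity.

```cpp
#include <cstdlib> // std::abs

// Sketch only, not the llama.cpp implementation. p0 is the key/value position,
// p1 is the query position; returns true when the pair must be masked out.
// llama_pos and llama_swa_type are the types used by the surrounding code.
static bool is_masked_swa_sketch(int32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
    switch (swa_type) {
        case LLAMA_SWA_TYPE_NONE:
            return false;                         // full attention, nothing extra is masked
        case LLAMA_SWA_TYPE_STANDARD:
            return p1 - p0 >= n_swa;              // key lies too far behind the query
        case LLAMA_SWA_TYPE_CHUNKED:
            return p0 < (p1 / n_swa) * n_swa;     // key lies before the query's chunk
        case LLAMA_SWA_TYPE_SYMMETRIC:
            return std::abs(p1 - p0) > n_swa / 2; // key lies outside a window centered on the query
    }
    return false;
}
```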
@@ -1299,12 +1342,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
-    const auto n_kv = k->ne[1];
-
     ggml_tensor * cur;
 
     // TODO: replace hardcoded padding with ggml-provided padding
-    if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
+    if (cparams.flash_attn && kq_b == nullptr) {
         GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
 
         if (v_trans) {
@@ -1419,10 +1460,20 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
 
     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
-    ggml_set_input(inp->kq_mask);
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+    ggml_set_input(inp->self_kq_mask);
 
-    inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
+    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+
+    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    } else {
+        inp->self_kq_mask_swa     = nullptr;
+        inp->self_kq_mask_swa_cnv = nullptr;
+    }
 
     return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
 }
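Both the regular and the SWA mask are allocated as F32 tensors of shape `[n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1]` and are cast to F16 when flash attention is enabled. `GGML_PAD(x, n)` from ggml.h rounds `x` up to the next multiple of `n`, so the query-row dimension of the mask is over-allocated to a padded size. A minimal stand-in for illustration (the real macro may be implemented differently, e.g. with bit masking):

```cpp
#include <cstdint>

// Illustrative stand-in for GGML_PAD: round x up to the next multiple of n.
static int64_t pad_to_multiple(int64_t x, int64_t n) {
    return ((x + n - 1) / n) * n;
}

// e.g. a 100-token batch with a pad of 64 gets 128 query rows in the mask:
//      pad_to_multiple(100, 64) == 128
```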
@@ -1447,7 +1498,9 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto & kq_mask = inp->get_kq_mask();
+    const bool is_swa = hparams.is_swa(il);
+
+    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     // [TAG_NO_CACHE_PAD]
     // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
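With `hparams.is_swa(il)` selecting the mask per layer, interleaved-SWA (iSWA) models can mix sliding-window and full-attention layers in a single graph, which is what the removed [TAG_NO_CACHE_ISWA] TODOs were pointing at. For reference, a sketch of how such a per-layer predicate is often derived from a repeating layer pattern; this is an assumption about typical iSWA layouts, not the code path touched here, and in llama.cpp the per-layer flags are filled in when the model hyperparameters are loaded.

```cpp
#include <cstdint>

// Sketch under the assumption of a repeating block of n_pattern layers in which
// the last layer of each block uses full attention and the rest use SWA
// (e.g. n_pattern == 6: layers 0-4 sliding-window, layer 5 full, then repeat).
static bool is_swa_layer_sketch(uint32_t il, uint32_t n_pattern) {
    if (n_pattern == 0) {
        return false; // no sliding-window layers at all
    }
    return (il % n_pattern) < (n_pattern - 1);
}
```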