@@ -548,7 +548,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
548548 if (cells.is_empty (i)) {
549549 ss += ' .' ;
550550 } else {
551- ss += ' x ' ;
551+ ss += std::to_string (cells. seq_get (i)) ;
552552 }
553553 if (i%256 == 255 ) {
554554 ss += ' \n ' ;
@@ -557,6 +557,10 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
557557 }
558558 LLAMA_LOG_WARN (" \n %s\n " , ss.c_str ());
559559 }
560+
561+ LLAMA_LOG_WARN (" kv_cells: n_swa = %4d, min[0] = %5d, max[0] = %5d\n " , n_swa, cells.seq_pos_min (0 ), cells.seq_pos_max (0 ));
562+ LLAMA_LOG_WARN (" kv_cells: n_swa = %4d, min[1] = %5d, max[1] = %5d\n " , n_swa, cells.seq_pos_min (1 ), cells.seq_pos_max (1 ));
563+ LLAMA_LOG_WARN (" kv_cells: n_swa = %4d, min[2] = %5d, max[2] = %5d\n " , n_swa, cells.seq_pos_min (2 ), cells.seq_pos_max (2 ));
560564#endif
561565
562566 uint32_t n_tested = 0 ;
@@ -568,24 +572,44 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
568572 continue ;
569573 }
570574
575+ // keep track of what the minimum sequence positions would be if we accept the ubatch
576+ llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
577+ for (int s = 0 ; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
578+ seq_pos_min[s] = cells.seq_pos_min (s);
579+ }
580+
571581 bool found = true ;
572582 for (uint32_t i = 0 ; i < n_tokens; i++) {
573583 const llama_pos pos = ubatch.pos [i];
574584 const llama_seq_id seq_id = ubatch.seq_id [i][0 ];
575585
576586 // can we use this cell? either:
577587 // - the cell is empty
578- // - the cell is occupied only by the same sequence, and the pos is masked
579- const bool can_use =
580- cells.is_empty (head_cur + i) ||
581- (
582- cells.seq_has (head_cur + i, seq_id) && // sequence mask
583- cells.seq_count (head_cur + i) == 1 &&
584- (
585- cells.pos_get (head_cur + i) >= pos || // causal mask
586- is_masked_swa (cells.pos_get (head_cur + i), ubatch.seq_pos_min [seq_id]) // SWA mask
587- )
588- );
588+ // - the cell is occupied only by one sequence:
589+ // - mask causally, if the sequence is the same as the one we are inserting
590+ // - mask SWA, using current max pos for that sequence in the cache
591+ // always insert in the cell with minimum pos
592+ bool can_use = cells.is_empty (head_cur + i);
593+
594+ if (!can_use && cells.seq_count (head_cur + i) == 1 ) {
595+ const llama_pos pos_cell = cells.pos_get (head_cur + i);
596+
597+ // causal mask
598+ if (cells.seq_has (head_cur + i, seq_id)) {
599+ can_use = pos_cell >= pos;
600+ }
601+
602+ if (!can_use) {
603+ const llama_seq_id seq_id_cell = cells.seq_get (head_cur + i);
604+
605+ // SWA mask
606+ if (pos_cell == seq_pos_min[seq_id_cell] &&
607+ is_masked_swa (pos_cell, cells.seq_pos_max (seq_id_cell) + 1 )) {
608+ seq_pos_min[seq_id_cell]++;
609+ can_use = true ;
610+ }
611+ }
612+ }
589613
590614 if (!can_use) {
591615 found = false ;
@@ -613,9 +637,7 @@ void llama_kv_cache_unified::fill_slot(uint32_t head_cur, const llama_ubatch & u
613637
614638 for (uint32_t i = 0 ; i < ubatch.n_tokens ; ++i) {
615639 if (!cells.is_empty (head + i)) {
616- cells.pos_chg (head + i, ubatch.pos [i]);
617-
618- continue ;
640+ cells.rm (head + i);
619641 }
620642
621643 cells.pos_set (head + i, ubatch.pos [i]);
0 commit comments