@@ -2337,17 +2337,17 @@ void llama_kv_cache_recurrent::defrag_sched(float thold) {
23372337void llama_kv_cache_recurrent::set_full () {
23382338 n = size;
23392339 head = 0 ;
2340+ rs_z = 0 ;
23402341}
23412342
23422343bool llama_kv_cache_recurrent::find_slot (const llama_ubatch & ubatch) {
2343- const uint32_t n_tokens = ubatch.n_tokens ;
2344- const uint32_t n_seqs = ubatch.n_seqs ;
2344+ const uint32_t n_seqs = ubatch.n_seqs ;
23452345
23462346 const uint32_t n_seq_tokens = ubatch.n_seq_tokens ;
23472347
23482348 // if we have enough unused cells before the current head ->
23492349 // better to start searching from the beginning of the cache, hoping to fill it
2350- if (head > used + 2 *n_tokens ) {
2350+ if (head > used + 2 *n_seqs ) {
23512351 head = 0 ;
23522352 }
23532353
@@ -2443,16 +2443,16 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
24432443 empty_cell.src = orig_cell.src ;
24442444 orig_cell.seq_id .erase (seq_id);
24452445 empty_cell.seq_id .insert (seq_id); // will be overwritten
2446+ GGML_ASSERT (!orig_cell.is_empty ()); // has at least one remaining seq_id
24462447 }
24472448 seq_meta.tail = next_empty_cell;
24482449 // find next empty cell
24492450 if (s + 1 < n_seqs) {
2450- next_empty_cell += 1 ;
24512451 for (uint32_t i = 0 ; i < size; ++i) {
2452+ next_empty_cell += 1 ;
24522453 if (next_empty_cell >= size) { next_empty_cell -= size; }
24532454 kv_cell & cell = cells[next_empty_cell];
24542455 if (cell.is_empty ()) { break ; }
2455- next_empty_cell += 1 ;
24562456 }
24572457 }
24582458 }
@@ -2472,12 +2472,14 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
24722472 std::swap (dst_cell.src , src_cell.src );
24732473 std::swap (dst_cell.seq_id , src_cell.seq_id );
24742474
2475- // swap tails (assuming they NEVER overlap)
2476- for (const llama_seq_id seq_id : src_cell.seq_id ) {
2477- cells[seq_id].tail = src_id;
2478- }
2479- for (const llama_seq_id seq_id : dst_cell.seq_id ) {
2480- cells[seq_id].tail = dst_id;
2475+ // swap tails
2476+ for (uint32_t i = 0 ; i < size; ++i) {
2477+ int32_t & tail = cells[i].tail ;
2478+ if (tail == src_id) {
2479+ tail = dst_id;
2480+ } else if (tail == dst_id) {
2481+ tail = src_id;
2482+ }
24812483 }
24822484 }
24832485 }
@@ -2506,13 +2508,18 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
25062508 // Find first to-be-cleared cell
25072509 rs_z = -1 ;
25082510 for (int i = min; i <= max; ++i) {
2509- if (rs_z < 0 && cells[i].src == -1 ) {
2510- rs_z = i;
2511+ if (cells[i].src == -1 ) {
2512+ if (rs_z < 0 ) {
2513+ rs_z = i;
2514+ }
2515+
2516+ cells[i].src0 = rs_z;
2517+ } else {
2518+ // Stage the source ids for all used cells to allow correct seq_* behavior
2519+ // and still make these values available when setting the inputs
2520+ cells[i].src0 = cells[i].src ;
25112521 }
2512- // Stage the source ids for all used cells to allow correct seq_* behavior
2513- // and still make these values available when setting the inputs
2514- cells[i].src0 = cells[i].src ;
2515- cells[i].src = i;
2522+ cells[i].src = i; // avoid moving or clearing twice
25162523 }
25172524
25182525 // allow getting the range of used cells, from head to head + n
0 commit comments