@@ -18,6 +18,8 @@ llama_batch_allocr::llama_batch_allocr() {
1818 for (auto & cur : seq_cpl) {
1919 cur.resize (LLAMA_MAX_SEQ);
2020 }
21+
22+ seq_idx.resize (LLAMA_MAX_SEQ, -1 );
2123}
2224
2325bool llama_batch_allocr::init (
@@ -137,22 +139,23 @@ bool llama_batch_allocr::init(
137139 // compute stats
138140 //
139141
142+ this ->n_embd = n_embd;
143+
140144 for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
141145 n_outputs += batch.logits [i] != 0 ;
142146 }
143147
144- this ->n_embd = n_embd;
145-
146148 // determine coupled sequences
147149 // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
148150 for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
151+ const llama_seq_id s0 = batch.seq_id [i][0 ];
152+
149153 for (int32_t s = 0 ; s < batch.n_seq_id [i]; ++s) {
150- seq_pos[ batch.seq_id [i][s]]. insert (batch. pos [i]) ;
154+ const llama_seq_id s1 = batch.seq_id [i][s];
151155
152- if (s > 0 ) {
153- const llama_seq_id s0 = batch.seq_id [i][0 ];
154- const llama_seq_id s1 = batch.seq_id [i][s];
156+ seq_pos[s1].insert (batch.pos [i]);
155157
158+ if (s > 0 ) {
156159 // mark that sequence s1 is coupled to s0
157160 seq_cpl[s1][s0] = true ;
158161
@@ -162,14 +165,28 @@ bool llama_batch_allocr::init(
162165 }
163166 }
164167
165- for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
166- seq_set_t cur;
167- for (int32_t s = 0 ; s < batch.n_seq_id [i]; ++s) {
168- cur.set (batch.seq_id [i][s]);
168+ {
169+ seq_set_t seq_set_unq;
170+
171+ for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
172+ seq_set_t cur;
173+ for (int32_t s = 0 ; s < batch.n_seq_id [i]; ++s) {
174+ const llama_seq_id s0 = batch.seq_id [i][s];
175+
176+ cur.set (s0);
177+ seq_set_unq.set (s0);
178+ }
179+
180+ seq_set.push_back (cur);
181+ seq_set_map[cur].push_back (i);
169182 }
170183
171- seq_set.push_back (cur);
172- seq_set_map[cur].push_back (i);
184+ for (int32_t s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
185+ if (seq_set_unq.test (s)) {
186+ seq_idx[s] = seq_id_unq.size ();
187+ seq_id_unq.push_back (s);
188+ }
189+ }
173190 }
174191
175192 if (debug > 0 ) {
@@ -180,11 +197,14 @@ bool llama_batch_allocr::init(
180197 /* .n_tokens =*/ (uint32_t ) batch.n_tokens ,
181198 /* .n_seq_tokens =*/ (uint32_t ) 1 ,
182199 /* .n_seqs =*/ (uint32_t ) batch.n_tokens ,
200+ /* .n_seqs_unq =*/ (uint32_t ) this ->seq_id_unq .size (),
183201 /* .token =*/ batch.token ,
184202 /* .embd =*/ batch.embd ,
185203 /* .pos =*/ batch.pos ,
186204 /* .n_seq_id =*/ batch.n_seq_id ,
187205 /* .seq_id =*/ batch.seq_id ,
206+ /* .seq_id_unq =*/ this ->seq_id_unq .data (),
207+ /* .seq_idx =*/ this ->seq_idx .data (),
188208 /* .output =*/ batch.logits ,
189209 };
190210
@@ -270,32 +290,44 @@ bool llama_batch_allocr::init(
270290 return true ;
271291}
272292
273- llama_ubatch llama_batch_allocr::ubatch_reserve (uint32_t n_tokens) {
293+ llama_ubatch llama_batch_allocr::ubatch_reserve (uint32_t n_seq_tokens, uint32_t n_seqs) {
294+ const uint32_t n_tokens = n_seq_tokens*n_seqs;
295+
274296 clear ();
275297 split_reset ();
276298
277299 ubatches.emplace_back ();
278300
279301 auto & ubatch = ubatches.back ();
280302
281- ubatch.token .resize (n_tokens);
282- ubatch.embd .clear ();
283- ubatch.pos .resize (n_tokens);
284- ubatch.n_seq_id .resize (n_tokens);
285- ubatch.seq_id .resize (n_tokens);
286- ubatch.output .resize (n_tokens);
303+ ubatch.token .resize (n_tokens);
304+ ubatch.embd .clear ();
305+ ubatch.pos .resize (n_tokens);
306+ ubatch.n_seq_id .resize (n_tokens);
307+ ubatch.seq_id .resize (n_tokens);
308+ ubatch.seq_id_unq .resize (0 );
309+ ubatch.seq_idx .resize (LLAMA_MAX_SEQ, -1 );
310+ ubatch.output .resize (n_tokens);
311+
312+ for (uint32_t s = 0 ; s < n_seqs; ++s) {
313+ ubatch.seq_idx [s] = s;
314+ ubatch.seq_id_unq .push_back (s);
315+ }
287316
288317 llama_ubatch res {
289318 /* .equal_seqs =*/ true ,
290319 /* .n_tokens =*/ n_tokens,
291- /* .n_seq_tokens =*/ n_tokens,
292- /* .n_seqs =*/ 1 ,
320+ /* .n_seq_tokens =*/ n_seq_tokens,
321+ /* .n_seqs =*/ n_seqs,
322+ /* .n_seqs_unq =*/ n_seqs,
293323
294324 /* .token =*/ ubatch.token .data (),
295325 /* .embd =*/ nullptr ,
296326 /* .pos =*/ ubatch.pos .data (),
297327 /* .n_seq_id =*/ ubatch.n_seq_id .data (),
298328 /* .seq_id =*/ ubatch.seq_id .data (),
329+ /* .seq_id_unq =*/ ubatch.seq_id_unq .data (),
330+ /* .seq_idx =*/ ubatch.seq_idx .data (),
299331 /* .output =*/ ubatch.output .data (),
300332 };
301333
@@ -489,10 +521,11 @@ void llama_batch_allocr::clear() {
489521
490522 batch = {};
491523
492- pos .clear ();
493- n_seq_id.clear ();
494- seq_id .clear ();
495- output .clear ();
524+ pos .clear ();
525+ n_seq_id .clear ();
526+ seq_id .clear ();
527+ seq_id_unq.clear ();
528+ output .clear ();
496529
497530 for (auto & cur : seq_pos) {
498531 cur.clear ();
@@ -516,12 +549,16 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
516549
517550 auto & ubatch = ubatches.back ();
518551
519- ubatch.token .resize (n_tokens);
520- ubatch.embd .resize ((int64_t ) n_tokens*n_embd);
521- ubatch.pos .resize (n_tokens);
522- ubatch.n_seq_id .resize (n_tokens);
523- ubatch.seq_id .resize (n_tokens);
524- ubatch.output .resize (n_tokens);
552+ ubatch.token .resize (n_tokens);
553+ ubatch.embd .resize ((int64_t ) n_tokens*n_embd);
554+ ubatch.pos .resize (n_tokens);
555+ ubatch.n_seq_id .resize (n_tokens);
556+ ubatch.seq_id .resize (n_tokens);
557+ ubatch.seq_id_unq .resize (0 );
558+ ubatch.seq_idx .resize (LLAMA_MAX_SEQ, -1 );
559+ ubatch.output .resize (n_tokens);
560+
561+ seq_set_t seq_set_unq;
525562
526563 for (size_t i = 0 ; i < idxs.size (); ++i) {
527564 if (batch.token ) {
@@ -537,22 +574,36 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
537574 ubatch.seq_id [i] = batch.seq_id [idxs[i]];
538575 ubatch.output [i] = batch.logits [idxs[i]];
539576
577+ for (int s = 0 ; s < ubatch.n_seq_id [i]; ++s) {
578+ seq_set_unq.set (ubatch.seq_id [i][s]);
579+ }
580+
540581 if (ubatch.output [i]) {
541582 out_ids.push_back (idxs[i]);
542583 }
543584 }
544585
586+ for (int32_t s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
587+ if (seq_set_unq.test (s)) {
588+ ubatch.seq_idx [s] = ubatch.seq_id_unq .size ();
589+ ubatch.seq_id_unq .push_back (s);
590+ }
591+ }
592+
545593 llama_ubatch res {
546594 /* .equal_seqs =*/ equal_seqs,
547595 /* .n_tokens =*/ n_tokens,
548596 /* .n_seq_tokens =*/ n_tokens/n_seqs,
549597 /* .n_seqs =*/ n_seqs,
598+ /* .n_seqs_unq =*/ (uint32_t ) ubatch.seq_id_unq .size (),
550599
551600 /* .token =*/ batch.token ? ubatch.token .data () : nullptr ,
552601 /* .embd =*/ batch.embd ? ubatch.embd .data () : nullptr ,
553602 /* .pos =*/ ubatch.pos .data (),
554603 /* .n_seq_id =*/ ubatch.n_seq_id .data (),
555604 /* .seq_id =*/ ubatch.seq_id .data (),
605+ /* .seq_id_unq =*/ ubatch.seq_id_unq .data (),
606+ /* .seq_idx =*/ ubatch.seq_idx .data (),
556607 /* .output =*/ ubatch.output .data (),
557608 };
558609
@@ -571,14 +622,38 @@ void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
571622 LLAMA_LOG_DEBUG (" %s: n_tokens = %d\n " , __func__, ubatch.n_tokens );
572623 LLAMA_LOG_DEBUG (" %s: n_seq_tokens = %d\n " , __func__, ubatch.n_seq_tokens );
573624 LLAMA_LOG_DEBUG (" %s: n_seqs = %d\n " , __func__, ubatch.n_seqs );
625+ LLAMA_LOG_DEBUG (" %s: n_seqs_unq = %d\n " , __func__, ubatch.n_seqs_unq );
626+
627+ std::stringstream ss_seq_id_unq;
628+ std::stringstream ss_seq_idx;
629+
630+ ss_seq_id_unq << " [ " ;
631+ ss_seq_idx << " [" ;
632+
633+ for (uint32_t s = 0 ; s < ubatch.n_seqs_unq ; ++s) {
634+ ss_seq_id_unq << ubatch.seq_id_unq [s] << " " ;
635+ }
636+
637+ for (uint32_t s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
638+ if (ubatch.seq_idx [s] >= 0 ) {
639+ ss_seq_idx << ubatch.seq_idx [s]%10 ;
640+ } else {
641+ ss_seq_idx << " ." ;
642+ }
643+ }
644+
645+ ss_seq_id_unq << " ]" ;
646+ ss_seq_idx << " ]" ;
574647
575- LLAMA_LOG_DEBUG (" %s: token = %p\n " , __func__, (void *) ubatch.token );
576- LLAMA_LOG_DEBUG (" %s: embd = %p\n " , __func__, (void *) ubatch.embd );
577- LLAMA_LOG_DEBUG (" %s: pos = %p\n " , __func__, (void *) ubatch.pos );
578- LLAMA_LOG_DEBUG (" %s: n_seq_id = %p\n " , __func__, (void *) ubatch.n_seq_id );
579- LLAMA_LOG_DEBUG (" %s: seq_id = %p\n " , __func__, (void *) ubatch.seq_id );
580- LLAMA_LOG_DEBUG (" %s: output = %p\n " , __func__, (void *) ubatch.output );
581- LLAMA_LOG_DEBUG (" %s: n_outputs = %d\n " , __func__, n_outputs);
648+ LLAMA_LOG_DEBUG (" %s: token = %p\n " , __func__, (void *) ubatch.token );
649+ LLAMA_LOG_DEBUG (" %s: embd = %p\n " , __func__, (void *) ubatch.embd );
650+ LLAMA_LOG_DEBUG (" %s: pos = %p\n " , __func__, (void *) ubatch.pos );
651+ LLAMA_LOG_DEBUG (" %s: n_seq_id = %p\n " , __func__, (void *) ubatch.n_seq_id );
652+ LLAMA_LOG_DEBUG (" %s: seq_id = %p\n " , __func__, (void *) ubatch.seq_id );
653+ LLAMA_LOG_DEBUG (" %s: seq_id_unq = %s\n " , __func__, ss_seq_id_unq.str ().c_str ());
654+ LLAMA_LOG_DEBUG (" %s: seq_idx = %s\n " , __func__, ss_seq_idx.str ().c_str ());
655+ LLAMA_LOG_DEBUG (" %s: output = %p\n " , __func__, (void *) ubatch.output );
656+ LLAMA_LOG_DEBUG (" %s: n_outputs = %d\n " , __func__, n_outputs);
582657
583658 if (debug > 1 ) {
584659 int seq_id_max = 0 ;
0 commit comments