@@ -571,6 +571,38 @@ bool llama_batch_allocr::init(
571571 return true ;
572572}
573573
574+ llama_ubatch llama_batch_allocr::reserve_one (uint32_t n_tokens) {
575+ clear ();
576+ split_reset ();
577+
578+ ubatches.emplace_back ();
579+
580+ auto & ubatch = ubatches.back ();
581+
582+ ubatch.token .resize (n_tokens);
583+ ubatch.embd .clear ();
584+ ubatch.pos .resize (n_tokens);
585+ ubatch.n_seq_id .resize (n_tokens);
586+ ubatch.seq_id .resize (n_tokens);
587+ ubatch.output .resize (n_tokens);
588+
589+ llama_ubatch res {
590+ /* .equal_seqs =*/ true ,
591+ /* .n_tokens =*/ n_tokens,
592+ /* .n_seq_tokens =*/ n_tokens,
593+ /* .n_seqs =*/ 1 ,
594+
595+ /* .token =*/ ubatch.token .data (),
596+ /* .embd =*/ nullptr ,
597+ /* .pos =*/ ubatch.pos .data (),
598+ /* .n_seq_id =*/ ubatch.n_seq_id .data (),
599+ /* .seq_id =*/ ubatch.seq_id .data (),
600+ /* .output =*/ ubatch.output .data (),
601+ };
602+
603+ return res;
604+ }
605+
574606const llama_batch & llama_batch_allocr::get_batch () const {
575607 return batch;
576608}
@@ -757,10 +789,11 @@ void llama_batch_allocr::clear() {
757789 n_outputs = 0 ;
758790
759791 batch = {};
760- pos.clear ();
792+
793+ pos .clear ();
761794 n_seq_id.clear ();
762- seq_id.clear ();
763- output.clear ();
795+ seq_id .clear ();
796+ output .clear ();
764797
765798 for (auto & cur : seq_pos) {
766799 cur.clear ();
@@ -786,12 +819,12 @@ llama_ubatch llama_batch_allocr::add_ubatch(const std::vector<int32_t> & idxs, u
786819
787820 auto & ubatch = ubatches.back ();
788821
789- ubatch.token .resize (n_tokens);
790- ubatch.embd .resize ((int64_t ) n_tokens*n_embd);
791- ubatch.pos .resize (n_tokens);
822+ ubatch.token .resize (n_tokens);
823+ ubatch.embd .resize ((int64_t ) n_tokens*n_embd);
824+ ubatch.pos .resize (n_tokens);
792825 ubatch.n_seq_id .resize (n_tokens);
793- ubatch.seq_id .resize (n_tokens);
794- ubatch.output .resize (n_tokens);
826+ ubatch.seq_id .resize (n_tokens);
827+ ubatch.output .resize (n_tokens);
795828
796829 for (size_t i = 0 ; i < idxs.size (); ++i) {
797830 if (batch.token ) {
@@ -839,25 +872,25 @@ struct llama_batch llama_batch_get_one(
839872 llama_token * tokens,
840873 int32_t n_tokens) {
841874 return {
842- /* n_tokens =*/ n_tokens,
843- /* tokens =*/ tokens,
844- /* embd =*/ nullptr ,
845- /* pos =*/ nullptr ,
846- /* n_seq_id =*/ nullptr ,
847- /* seq_id =*/ nullptr ,
848- /* logits =*/ nullptr ,
875+ /* n_tokens =*/ n_tokens,
876+ /* tokens =*/ tokens,
877+ /* embd =*/ nullptr ,
878+ /* pos =*/ nullptr ,
879+ /* n_seq_id =*/ nullptr ,
880+ /* seq_id =*/ nullptr ,
881+ /* logits =*/ nullptr ,
849882 };
850883}
851884
852885struct llama_batch llama_batch_init (int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
853886 llama_batch batch = {
854- /* n_tokens =*/ 0 ,
855- /* tokens =*/ nullptr ,
856- /* embd =*/ nullptr ,
857- /* pos =*/ nullptr ,
858- /* n_seq_id =*/ nullptr ,
859- /* seq_id =*/ nullptr ,
860- /* logits =*/ nullptr ,
887+ /* n_tokens =*/ 0 ,
888+ /* tokens =*/ nullptr ,
889+ /* embd =*/ nullptr ,
890+ /* pos =*/ nullptr ,
891+ /* n_seq_id =*/ nullptr ,
892+ /* seq_id =*/ nullptr ,
893+ /* logits =*/ nullptr ,
861894 };
862895
863896 if (embd) {
0 commit comments