@@ -901,17 +901,17 @@ int llama_context::decode(const llama_batch & batch_inp) {
     const int64_t n_embd = hparams.n_embd;
 
     // when computing embeddings, all tokens are output
-    const bool embd_all = cparams.embeddings;
+    const bool output_all = cparams.embeddings;
 
-    if (!batch_allocr->init(batch_inp, vocab, memory.get(), n_embd, embd_all)) {
+    if (!batch_allocr->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
 
     const uint32_t n_tokens_all  = batch_allocr->get_n_tokens();
     const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
 
-    if (embd_all) {
+    if (output_all) {
         // require that all tokens are output
         if (n_outputs_all != n_tokens_all) {
             LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@@ -940,7 +940,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     llama_memory_state_ptr mstate;
 
     while (true) {
-        mstate = memory->init_batch(batch_allocr.get(), cparams.n_ubatch, embd_all);
+        mstate = memory->init_batch(batch_allocr.get(), cparams.n_ubatch, output_all);
         if (!mstate) {
             return -2;
         }
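
For context, a rough caller-side sketch (not part of this commit) of the path the renamed flag controls: with cparams.embeddings enabled, output_all becomes true in decode() and every token of the batch must be output, which is what the pooled-embedding check above enforces. The model path and token ids below are placeholders, and the llama.cpp C API signatures (llama_model_load_from_file, llama_init_from_model, llama_batch_get_one, llama_get_embeddings_seq) can vary between versions, so treat this as illustrative only.

// Sketch only: exercises the output_all path by enabling embeddings.
// Placeholder model path and token ids; check the API against your llama.h.
#include "llama.h"

#include <cstdio>
#include <vector>

int main() {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path
    if (!model) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                    // -> output_all == true in llama_context::decode()
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN; // pooled embedding per sequence

    llama_context * ctx = llama_init_from_model(model, cparams);
    if (!ctx) {
        return 1;
    }

    // Illustrative token ids; in practice they come from llama_tokenize().
    std::vector<llama_token> tokens = { 1, 2, 3, 4 };
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());

    // With embeddings enabled, decode() requires n_outputs_all == n_tokens_all.
    if (llama_decode(ctx, batch) != 0) {
        fprintf(stderr, "decode failed\n");
        return 1;
    }

    const float * emb = llama_get_embeddings_seq(ctx, 0); // pooled embedding for sequence 0
    if (emb) {
        printf("first embedding component: %f\n", emb[0]);
    }

    llama_free(ctx);
    llama_model_free(model);
    return 0;
}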