@@ -219,16 +219,31 @@ void sample_top_a(llama_token_data_array * candidates, float a, size_t min_keep)
219219 candidates->size = last_idx;
220220}
221221
222- void apply_penalties (int n_ctx, int rep_pen_range, float rep_pen, llama_token_data_array & candidates_p)
222+ void sample_rep_pen (int n_ctx, int rep_pen_range, float rep_pen, llama_token_data_array * candidates_p)
223223{
224224 auto last_n_repeat = std::min (std::min ((int )last_n_tokens.size (), rep_pen_range), n_ctx);
225- llama_sample_repetition_penalty (nullptr , & candidates_p,
225+ llama_sample_repetition_penalty (nullptr , candidates_p,
226226 last_n_tokens.data () + last_n_tokens.size () - last_n_repeat,
227227 last_n_repeat, rep_pen);
228228}
229229
230+ void sample_temperature (llama_token_data_array * candidates_p, float temp)
231+ {
232+ if (temp <= 0 )
233+ {
234+ // Imitate greedy sampling
235+ temp = 0 .01f ; // cannot be zero else div0
236+ llama_sample_temperature (nullptr , candidates_p, temp);
237+ llama_sample_top_k (nullptr , candidates_p, 1 , 1 ); // only want first candidate
238+ }
239+ else
240+ {
241+ llama_sample_temperature (nullptr , candidates_p, temp);
242+ }
243+ }
244+
230245int SampleLogits (const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_a, float top_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
231- int mirostat, float mirostat_tau, float mirostat_eta, uint sampler_len, const samplers sampler_order[KCPP_SAMPLER_MAX] )
246+ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector< samplers> & sampler_order)
232247{
233248 int id = 0 ;
234249 std::vector<llama_token_data> candidates;
@@ -239,78 +254,54 @@ int mirostat, float mirostat_tau, float mirostat_eta, uint sampler_len, const sa
239254
240255 llama_token_data_array candidates_p = { candidates.data (), candidates.size (), false };
241256
242- // Run this except for when we are going to do the sampler reordering case below
243- if (temp <= 0 || mirostat > 0 || sampler_len == 0 )
244- {
245- apply_penalties (n_ctx, rep_pen_range, rep_pen, candidates_p);
246- }
247-
248- // llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p,
249- // last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
250- // last_n_repeat, alpha_frequency, alpha_presence);
251-
252- if (temp <= 0 )
253- {
254- // Greedy sampling
255- id = llama_sample_token_greedy (nullptr , &candidates_p);
256- }
257- else
257+ if (mirostat == 1 || mirostat == 2 )
258258 {
259+ static float mirostat_mu = 2 .0f * mirostat_tau;
260+ const int mirostat_m = 100 ;
261+ sample_rep_pen (n_ctx, rep_pen_range, rep_pen, &candidates_p);
262+ sample_temperature (&candidates_p, temp);
259263 if (mirostat == 1 )
260264 {
261- static float mirostat_mu = 2 .0f * mirostat_tau;
262- const int mirostat_m = 100 ;
263- llama_sample_temperature (nullptr , &candidates_p, temp);
264265 id = sample_token_mirostat (n_vocab, &candidates_p, rng, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
265266 }
266- else if (mirostat == 2 )
267+ else
267268 {
268- static float mirostat_mu = 2 .0f * mirostat_tau;
269- llama_sample_temperature (nullptr , &candidates_p, temp);
270269 id = sample_token_mirostat_v2 (&candidates_p, rng, mirostat_tau, mirostat_eta, &mirostat_mu);
271270 }
272- else if (sampler_len > 0 )
271+ }
272+ else
273+ {
274+ for (int i = 0 ; i < sampler_order.size (); i++)
273275 {
274- for ( int i = 0 ; i < sampler_len; i++) {
275- switch (sampler_order[i]) {
276- case KCPP_SAMPLER_TOP_K:
277- llama_sample_top_k (nullptr , &candidates_p, top_k,1 );
278- break ;
279- case KCPP_SAMPLER_TOP_A:
280- sample_top_a (&candidates_p,top_a,1 );
281- break ;
282- case KCPP_SAMPLER_TOP_P:
283- llama_sample_top_p (nullptr , &candidates_p, top_p,1 );
284- break ;
285- case KCPP_SAMPLER_TFS:
286- llama_sample_tail_free (nullptr , &candidates_p, tfs,1 );
287- break ;
288- case KCPP_SAMPLER_TYP:
289- llama_sample_typical (nullptr , &candidates_p, typical_p,1 );
290- break ;
291- case KCPP_SAMPLER_TEMP:
292- llama_sample_temperature ( nullptr , &candidates_p, temp);
293- break ;
294- case KCPP_SAMPLER_REP_PEN:
295- apply_penalties (n_ctx, rep_pen_range, rep_pen, candidates_p);
296- break ;
297- default :
298- break ;
299- }
276+ switch (sampler_order[i])
277+ {
278+ case KCPP_SAMPLER_TOP_K:
279+ llama_sample_top_k (nullptr , &candidates_p, top_k,1 );
280+ break ;
281+ case KCPP_SAMPLER_TOP_A:
282+ sample_top_a (&candidates_p,top_a,1 );
283+ break ;
284+ case KCPP_SAMPLER_TOP_P:
285+ llama_sample_top_p (nullptr , &candidates_p, top_p,1 );
286+ break ;
287+ case KCPP_SAMPLER_TFS:
288+ llama_sample_tail_free (nullptr , &candidates_p, tfs,1 );
289+ break ;
290+ case KCPP_SAMPLER_TYP:
291+ llama_sample_typical (nullptr , &candidates_p, typical_p,1 );
292+ break ;
293+ case KCPP_SAMPLER_TEMP:
294+ sample_temperature ( &candidates_p, temp);
295+ break ;
296+ case KCPP_SAMPLER_REP_PEN:
297+ sample_rep_pen (n_ctx, rep_pen_range, rep_pen, & candidates_p);
298+ break ;
299+ default :
300+ printf ( " \n SampleLogits: Unknown Sampler : %d " ,sampler_order[i]) ;
301+ break ;
300302 }
301- id = sample_token (&candidates_p, rng);
302- }
303- else
304- {
305- // Temperature sampling
306- llama_sample_top_k (nullptr , &candidates_p, top_k,1 );
307- sample_top_a (&candidates_p,top_a,1 );
308- llama_sample_tail_free (nullptr , &candidates_p, tfs,1 );
309- llama_sample_typical (nullptr , &candidates_p, typical_p,1 );
310- llama_sample_top_p (nullptr , &candidates_p, top_p,1 );
311- llama_sample_temperature (nullptr , &candidates_p, temp);
312- id = sample_token (&candidates_p, rng);
313303 }
304+ id = sample_token (&candidates_p, rng);
314305 }
315306
316307 return id;
@@ -952,6 +943,28 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
952943 std::mt19937 rng (params.seed );
953944 concat_output = " " ;
954945
946+ // prepare sampler order
947+ std::vector<samplers> sampler_order;
948+ if (inputs.sampler_len <=0 ) // list by value
949+ {
950+ sampler_order = {
951+ KCPP_SAMPLER_REP_PEN,
952+ KCPP_SAMPLER_TOP_K,
953+ KCPP_SAMPLER_TOP_A,
954+ KCPP_SAMPLER_TFS,
955+ KCPP_SAMPLER_TYP,
956+ KCPP_SAMPLER_TOP_P,
957+ KCPP_SAMPLER_TEMP
958+ };
959+ }
960+ else
961+ {
962+ for (int i=0 ;i<inputs.sampler_len ;++i)
963+ {
964+ sampler_order.push_back (inputs.sampler_order [i]);
965+ }
966+ }
967+
955968 bool startedsampling = false ;
956969 bool use_scratch = true ; // for normal inference always use scratch
957970
@@ -1274,8 +1287,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
12741287
12751288 id = SampleLogits (logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty,
12761289 top_k, top_a, top_p, typical_p, tfs_z, temp, rng,
1277- params.mirostat , params.mirostat_tau , params.mirostat_eta ,
1278- inputs.sampler_len , inputs.sampler_order );
1290+ params.mirostat , params.mirostat_tau , params.mirostat_eta , sampler_order);
12791291
12801292 last_n_tokens.erase (last_n_tokens.begin ());
12811293 last_n_tokens.push_back (id);
0 commit comments