@@ -36,14 +36,27 @@ int main(int argc, char ** argv) {
     llama_context * ctx_tgt = NULL;
     llama_context * ctx_dft = NULL;
 
+    bool self_speculation = false;
+
     // load the target model
     params.logits_all = true;
     std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
 
     // load the draft model
-    params.model = params.model_draft;
-    params.n_gpu_layers = params.n_gpu_layers_draft;
-    std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
+    if (params.model != params.model_draft) {
+        params.model = params.model_draft;
+        params.n_gpu_layers = params.n_gpu_layers_draft;
+        std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
+    } else {
+        self_speculation = true;
+        model_dft = model_tgt;
+        ctx_dft = ctx_tgt;
+    }
+
+    // the 2 models should have the same vocab
+    const int n_ctx = llama_n_ctx(ctx_tgt);
+    const int n_vocab = llama_n_vocab(model_tgt);
+    GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
 
     // tokenize the prompt
     std::vector<llama_token> inp;
@@ -68,6 +81,7 @@ int main(int argc, char ** argv) {
     const int n_input = inp.size();
 
     llama_batch batch_dft = llama_batch_get_one(NULL, 0, 0, 0);
+    llama_batch batch_tgt = llama_batch_get_one(NULL, 0, 0, 1);
     std::vector<int32_t> run_layers_dft = {
         0, 0, 2, 0, 2, 0, 0, 0, 0, 2, 3, 1, 0, 3, 3, 0, 3, 0, 1, 1,
         3, 3, 3, 0, 2, 3, 2, 3, 3, 3, 1, 3, 0, 0, 2, 1, 0, 2, 0, 0,
@@ -76,25 +90,39 @@ int main(int argc, char ** argv) {
 
     const auto t_enc_start = ggml_time_us();
 
-    // eval the prompt with both models
-    llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
-    llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
+    std::vector<float> logits_tgt, logits_dft;
 
-    batch_dft.n_tokens = n_input;
-    batch_dft.token = inp.data();
-    batch_dft.all_pos_0 = 0;
-    llama_decode(ctx_dft, batch_dft);
+    if (self_speculation) {
+        logits_tgt.resize(n_vocab * 30);
+        logits_dft.resize(n_vocab);
+    }
+
+    // eval the prompt with both models
+    batch_tgt.n_tokens = n_input - 1;
+    batch_tgt.token = inp.data();
+    batch_tgt.all_pos_0 = 0;
+    llama_decode(ctx_tgt, batch_tgt);
+    batch_tgt.n_tokens = 1;
+    batch_tgt.token = &inp.back();
+    batch_tgt.all_pos_0 = n_input - 1;
+    llama_decode(ctx_tgt, batch_tgt);
+
+    if (!self_speculation) {
+        batch_dft.n_tokens = n_input;
+        batch_dft.token = inp.data();
+        batch_dft.all_pos_0 = 0;
+        llama_decode(ctx_dft, batch_dft);
+    } else {
+        memcpy(logits_tgt.data(), llama_get_logits(ctx_tgt), sizeof(float) * n_vocab);
+        memcpy(logits_dft.data(), llama_get_logits(ctx_tgt), sizeof(float) * n_vocab);
+        llama_kv_cache_seq_cp(ctx_dft, 1, 0, 0, -1);
+    }
 
     const auto t_enc_end = ggml_time_us();
 
     // Don't skip layers until after prompt eval.
     batch_dft.run_layers = run_layers_dft.data();
 
-    // the 2 models should have the same vocab
-    const int n_ctx = llama_n_ctx(ctx_tgt);
-    const int n_vocab = llama_n_vocab(model_tgt);
-    // GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
-
     // how many tokens to draft each time
     int n_draft = params.n_draft;
 
@@ -150,7 +178,15 @@ int main(int argc, char ** argv) {
 
         while (true) {
             // sample from the target model
-            llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, i_dft);
+            llama_token id;
+            if (!self_speculation) {
+                id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, i_dft, 1);
+            } else {
+                memcpy(llama_get_logits(ctx_tgt),
+                       logits_tgt.data() + i_dft * n_vocab,
+                       sizeof(float) * size_t(n_vocab));
+                id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, 0, 1);
+            }
 
             // remember which tokens were sampled - used for repetition penalties during sampling
             last_tokens.erase(last_tokens.begin());
@@ -193,6 +229,11 @@ int main(int argc, char ** argv) {
             batch_dft.n_tokens = 1;
             batch_dft.all_pos_0 = n_past_dft;
             llama_decode(ctx_dft, batch_dft);
+
+            if (self_speculation) {
+                memcpy(logits_dft.data(), llama_get_logits(ctx_dft), sizeof(float) * n_vocab);
+            }
+
             ++n_past_dft;
 
             // heuristic for n_draft
@@ -212,7 +253,7 @@ int main(int argc, char ** argv) {
                 LOG(" - partially drafted tokens accepted - no change\n");
             } else {
                 LOG(" - drafted token rejected - n_draft -= 1\n");
-                n_draft = std::max(2, n_draft - 1);
+                n_draft = std::max(6, n_draft - 1);
             }
         }
 
@@ -244,6 +285,10 @@ int main(int argc, char ** argv) {
         // sample n_draft tokens from the draft model using greedy decoding
         int n_past_cur = n_past_dft;
         for (int i = 0; i < n_draft; ++i) {
+
+            if (self_speculation) {
+                memcpy(llama_get_logits(ctx_dft), logits_dft.data(), sizeof(float) * n_vocab);
+            }
             float * logits = llama_get_logits(ctx_dft);
 
             candidates.clear();
@@ -265,7 +310,13 @@ int main(int argc, char ** argv) {
             }
 
             // TODO: better logic?
-            if (cur_p.data[0].p < 2*cur_p.data[1].p) {
+
+            // const float skip_scale = 1.25f + std::min(2.0f, 0.25f * float(i)); // 46.6
+            // const float skip_scale = 1.35f + std::min(2.5f, 0.15f * float(i)); // 48.48
+            // const float skip_scale = 1.35f + std::min(2.0f, 0.15f * float(i)); // 48.98
+            // const float skip_scale = 1.50f + std::min(2.0f, 0.10f * float(i)); // 51.64
+            const float skip_scale = 1.50f + std::min(2.0f, 0.75f * float(i)); // 61.76
+            if (cur_p.data[0].p < skip_scale*cur_p.data[1].p) {
                 LOG("stopping drafting, probability too low: %.3f < 2*%.3f\n", cur_p.data[0].p, cur_p.data[1].p);
                 break;
             }
@@ -287,6 +338,11 @@ int main(int argc, char ** argv) {
             batch_dft.n_tokens = 1;
             batch_dft.all_pos_0 = n_past_cur;
             llama_decode(ctx_dft, batch_dft);
+
+            if (self_speculation) {
+                memcpy(logits_dft.data(), llama_get_logits(ctx_dft), sizeof(float) * n_vocab);
+            }
+
             ++n_past_cur;
 
             if (grammar_dft != NULL) {
@@ -295,8 +351,16 @@ int main(int argc, char ** argv) {
         }
 
         // evaluate the target model on the drafted tokens
-        llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, -1);
-        llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0));
+        llama_kv_cache_seq_rm(ctx_tgt, 1, n_past_tgt, -1);
+        batch_tgt.n_tokens = drafted.size();
+        batch_tgt.token = drafted.data();
+        batch_tgt.all_pos_0 = n_past_tgt;
+        llama_decode(ctx_tgt, batch_tgt);
+
+        if (self_speculation) {
+            memcpy(logits_tgt.data(), llama_get_logits(ctx_tgt),
+                   sizeof(float) * n_vocab * size_t(batch_tgt.n_tokens));
+        }
         ++n_past_tgt;
 
         // the first token is always proposed by the traget model before the speculation loop
@@ -327,8 +391,10 @@ int main(int argc, char ** argv) {
     llama_free(ctx_tgt);
     llama_free_model(model_tgt);
 
-    llama_free(ctx_dft);
-    llama_free_model(model_dft);
+    if (!self_speculation) {
+        llama_free(ctx_dft);
+        llama_free_model(model_dft);
+    }
 
     if (grammar_dft != NULL) {
         llama_grammar_free(grammar_dft);
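
Note (not part of the commit): in the self-speculation path above, the draft and target roles share a single llama_context, so the diff keeps their KV-cache entries apart by sequence id (seq 0 for the draft pass, seq 1 for the target batch created with llama_batch_get_one(NULL, 0, 0, 1)) and snapshots logits before the other role overwrites them. The sketch below factors out that snapshot/restore pattern for illustration only; save_logits/restore_logits are hypothetical helpers that do not exist in the diff, and the diff's logits_tgt sizing of n_vocab * 30 is assumed to cover the maximum number of drafted positions.

#include "llama.h"

#include <cstring>
#include <vector>

// Hypothetical helper: copy the last n_rows rows of logits out of the shared context
// before the other role (draft vs. target) clobbers them with its next llama_decode().
static void save_logits(llama_context * ctx, std::vector<float> & dst, int n_vocab, int n_rows = 1) {
    dst.resize((size_t) n_vocab * n_rows);
    std::memcpy(dst.data(), llama_get_logits(ctx), sizeof(float) * (size_t) n_vocab * n_rows);
}

// Hypothetical helper: write one saved row back into the context's logits buffer so the
// next llama_sampling_sample() call sees the distribution that was computed earlier.
static void restore_logits(llama_context * ctx, const std::vector<float> & src, int n_vocab, int i_row = 0) {
    std::memcpy(llama_get_logits(ctx), src.data() + (size_t) i_row * n_vocab, sizeof(float) * (size_t) n_vocab);
}

In the diff this is done inline with memcpy: logits_tgt holds one row per drafted position (restored with i_dft as the row index before sampling from the target), while logits_dft holds only the single most recent draft distribution.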