Commit 317910c

Self-speculate with only one context
Possible improvements to speculation logic
1 parent 4e38013 commit 317910c

File tree

1 file changed: +88 additions, -22 deletions

examples/speculative/speculative.cpp

Lines changed: 88 additions & 22 deletions
@@ -36,14 +36,27 @@ int main(int argc, char ** argv) {
     llama_context * ctx_tgt = NULL;
     llama_context * ctx_dft = NULL;
 
+    bool self_speculation = false;
+
     // load the target model
     params.logits_all = true;
     std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
 
     // load the draft model
-    params.model = params.model_draft;
-    params.n_gpu_layers = params.n_gpu_layers_draft;
-    std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
+    if (params.model != params.model_draft) {
+        params.model = params.model_draft;
+        params.n_gpu_layers = params.n_gpu_layers_draft;
+        std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
+    } else {
+        self_speculation = true;
+        model_dft = model_tgt;
+        ctx_dft = ctx_tgt;
+    }
+
+    // the 2 models should have the same vocab
+    const int n_ctx = llama_n_ctx(ctx_tgt);
+    const int n_vocab = llama_n_vocab(model_tgt);
+    GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
 
     // tokenize the prompt
     std::vector<llama_token> inp;
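
Note: the self-speculation path makes the draft pointers plain aliases of the target ones, so a single set of weights and one KV cache serve both roles; the rest of the diff keeps the roles from trampling each other by giving them separate KV-cache sequences and by snapshotting logits into side buffers. A rough sketch of that extra state, assuming the llama.cpp API of this period and a hypothetical `n_keep` parameter (the diff hard-codes 30):

```cpp
#include "llama.h"

#include <vector>

// Sketch (not part of the commit): the extra state self-speculation keeps
// around when the target and the draft share one llama_context. n_keep is a
// hypothetical cap on how many drafted positions' target logits are stored.
struct self_spec_buffers {
    std::vector<float> logits_tgt; // one row of target logits per drafted position
    std::vector<float> logits_dft; // single row, refreshed after each draft decode
};

static self_spec_buffers make_self_spec_buffers(const llama_model * model, int n_keep) {
    self_spec_buffers buf;
    const int n_vocab = llama_n_vocab(model);
    buf.logits_tgt.resize((size_t) n_vocab * n_keep);
    buf.logits_dft.resize((size_t) n_vocab);
    return buf;
}
```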
@@ -68,6 +81,7 @@ int main(int argc, char ** argv) {
     const int n_input = inp.size();
 
     llama_batch batch_dft = llama_batch_get_one(NULL, 0, 0, 0);
+    llama_batch batch_tgt = llama_batch_get_one(NULL, 0, 0, 1);
     std::vector<int32_t> run_layers_dft = {
         0, 0, 2, 0, 2, 0, 0, 0, 0, 2, 3, 1, 0, 3, 3, 0, 3, 0, 1, 1,
         3, 3, 3, 0, 2, 3, 2, 3, 3, 3, 1, 3, 0, 0, 2, 1, 0, 2, 0, 0,
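
Note: the last argument of `llama_batch_get_one` is the sequence id, so `batch_dft` writes its cache entries into sequence 0 while the new `batch_tgt` uses sequence 1; that separation is what lets one shared context hold draft and target state side by side. A hedged sketch of the routing, with a made-up helper name:

```cpp
#include "llama.h"

#include <vector>

// Sketch (helper name made up): decode `tokens` starting at position `pos0`
// into KV-cache sequence `seq`. With a shared context, the draft side passes
// seq = 0 and the target side seq = 1, keeping the two roles' cache entries apart.
static int decode_into_seq(llama_context * ctx, std::vector<llama_token> & tokens,
                           llama_pos pos0, llama_seq_id seq) {
    llama_batch batch = llama_batch_get_one(tokens.data(), (int) tokens.size(), pos0, seq);
    return llama_decode(ctx, batch); // 0 on success
}
```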
@@ -76,25 +90,39 @@ int main(int argc, char ** argv) {
 
     const auto t_enc_start = ggml_time_us();
 
-    // eval the prompt with both models
-    llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
-    llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
+    std::vector<float> logits_tgt, logits_dft;
 
-    batch_dft.n_tokens = n_input;
-    batch_dft.token = inp.data();
-    batch_dft.all_pos_0 = 0;
-    llama_decode(ctx_dft, batch_dft);
+    if (self_speculation) {
+        logits_tgt.resize(n_vocab * 30);
+        logits_dft.resize(n_vocab);
+    }
+
+    // eval the prompt with both models
+    batch_tgt.n_tokens = n_input - 1;
+    batch_tgt.token = inp.data();
+    batch_tgt.all_pos_0 = 0;
+    llama_decode(ctx_tgt, batch_tgt);
+    batch_tgt.n_tokens = 1;
+    batch_tgt.token = &inp.back();
+    batch_tgt.all_pos_0 = n_input - 1;
+    llama_decode(ctx_tgt, batch_tgt);
+
+    if (!self_speculation) {
+        batch_dft.n_tokens = n_input;
+        batch_dft.token = inp.data();
+        batch_dft.all_pos_0 = 0;
+        llama_decode(ctx_dft, batch_dft);
+    } else {
+        memcpy(logits_tgt.data(), llama_get_logits(ctx_tgt), sizeof(float) * n_vocab);
+        memcpy(logits_dft.data(), llama_get_logits(ctx_tgt), sizeof(float) * n_vocab);
+        llama_kv_cache_seq_cp(ctx_dft, 1, 0, 0, -1);
+    }
 
     const auto t_enc_end = ggml_time_us();
 
     // Don't skip layers until after prompt eval.
     batch_dft.run_layers = run_layers_dft.data();
 
-    // the 2 models should have the same vocab
-    const int n_ctx = llama_n_ctx(ctx_tgt);
-    const int n_vocab = llama_n_vocab(model_tgt);
-    //GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
-
     // how many tokens to draft each time
     int n_draft = params.n_draft;
 
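
Note: with self-speculation the prompt is only evaluated once, through the target sequence; `llama_kv_cache_seq_cp` then mirrors those cache cells into the draft sequence, and the last-token logits are copied into both side buffers. `logits_tgt` is sized `n_vocab * 30`, presumably enough rows for the longest draft the example expects. A condensed sketch of that prompt pass (helper name made up, buffers assumed pre-sized):

```cpp
#include "llama.h"

#include <cstring>
#include <vector>

// Sketch of the self-speculation prompt pass: the prompt is decoded once into
// the target sequence (1), the resulting KV cells are mirrored into the draft
// sequence (0), and the last-token logits are snapshotted for both roles so
// neither needs a second forward pass over the prompt.
static void eval_prompt_shared(llama_context * ctx, std::vector<llama_token> & inp,
                               int n_vocab,
                               std::vector<float> & logits_tgt,
                               std::vector<float> & logits_dft) {
    const int n_input = (int) inp.size();

    // all but the last prompt token, then the last token on its own
    llama_decode(ctx, llama_batch_get_one(inp.data(), n_input - 1, 0, 1));
    llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 1));

    std::memcpy(logits_tgt.data(), llama_get_logits(ctx), sizeof(float) * n_vocab);
    std::memcpy(logits_dft.data(), llama_get_logits(ctx), sizeof(float) * n_vocab);

    // duplicate the prompt's cache entries from sequence 1 into sequence 0
    llama_kv_cache_seq_cp(ctx, 1, 0, 0, -1);
}
```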
@@ -150,7 +178,15 @@ int main(int argc, char ** argv) {
 
         while (true) {
             // sample from the target model
-            llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, i_dft);
+            llama_token id;
+            if (!self_speculation) {
+                id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, i_dft, 1);
+            } else {
+                memcpy(llama_get_logits(ctx_tgt),
+                       logits_tgt.data() + i_dft * n_vocab,
+                       sizeof(float) * size_t(n_vocab));
+                id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, 0, 1);
+            }
 
             // remember which tokens were sampled - used for repetition penalties during sampling
             last_tokens.erase(last_tokens.begin());
@@ -193,6 +229,11 @@ int main(int argc, char ** argv) {
         batch_dft.n_tokens = 1;
         batch_dft.all_pos_0 = n_past_dft;
         llama_decode(ctx_dft, batch_dft);
+
+        if (self_speculation) {
+            memcpy(logits_dft.data(), llama_get_logits(ctx_dft), sizeof(float) * n_vocab);
+        }
+
         ++n_past_dft;
 
         // heuristic for n_draft
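
Note: because one context serves both roles, every `llama_decode` overwrites the same logits buffer; the example therefore snapshots the draft logits right after each draft decode (here and again inside the drafting loop) and copies the saved target row for position `i_dft` back in place just before sampling. A sketch of that save/restore pattern with made-up helper names:

```cpp
#include "llama.h"

#include <cstring>
#include <vector>

// Sketch of the save/restore pattern around the shared logits buffer.
// llama_get_logits returns a writable pointer into the context, so a
// previously saved row can be copied back in place before sampling from it.
static void save_logits(llama_context * ctx, std::vector<float> & dst, int n_vocab) {
    const float * src = llama_get_logits(ctx);
    dst.assign(src, src + n_vocab);
}

static void restore_logits(llama_context * ctx, const std::vector<float> & src,
                           int n_vocab, int row) {
    std::memcpy(llama_get_logits(ctx), src.data() + (size_t) row * n_vocab,
                sizeof(float) * (size_t) n_vocab);
}
```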
@@ -212,7 +253,7 @@ int main(int argc, char ** argv) {
                 LOG(" - partially drafted tokens accepted - no change\n");
             } else {
                 LOG(" - drafted token rejected - n_draft -= 1\n");
-                n_draft = std::max(2, n_draft - 1);
+                n_draft = std::max(6, n_draft - 1);
             }
         }
 
@@ -244,6 +285,10 @@ int main(int argc, char ** argv) {
         // sample n_draft tokens from the draft model using greedy decoding
         int n_past_cur = n_past_dft;
         for (int i = 0; i < n_draft; ++i) {
+
+            if (self_speculation) {
+                memcpy(llama_get_logits(ctx_dft), logits_dft.data(), sizeof(float) * n_vocab);
+            }
             float * logits = llama_get_logits(ctx_dft);
 
             candidates.clear();
@@ -265,7 +310,13 @@ int main(int argc, char ** argv) {
             }
 
             // TODO: better logic?
-            if (cur_p.data[0].p < 2*cur_p.data[1].p) {
+
+            // const float skip_scale = 1.25f + std::min(2.0f, 0.25f * float(i)); // 46.6
+            // const float skip_scale = 1.35f + std::min(2.5f, 0.15f * float(i)); // 48.48
+            // const float skip_scale = 1.35f + std::min(2.0f, 0.15f * float(i)); // 48.98
+            // const float skip_scale = 1.50f + std::min(2.0f, 0.10f * float(i)); // 51.64
+            const float skip_scale = 1.50f + std::min(2.0f, 0.75f * float(i)); // 61.76
+            if (cur_p.data[0].p < skip_scale*cur_p.data[1].p) {
                 LOG("stopping drafting, probability too low: %.3f < 2*%.3f\n", cur_p.data[0].p, cur_p.data[1].p);
                 break;
             }
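
Note: the stop condition for drafting now uses a position-dependent factor instead of the fixed 2x: drafting continues only while the top candidate is at least `skip_scale` times as probable as the runner-up, and `skip_scale` grows with the draft position `i`, so later draft tokens must be proposed with increasing confidence. The commented-out lines read as alternative schedules, each annotated with what appears to be a measured score (units not stated in the diff). A standalone example of how the chosen schedule evolves:

```cpp
#include <algorithm>
#include <cstdio>

// Illustration only: the threshold starts at 1.5x and rises by 0.75 per
// drafted position, capped two above the base.
int main() {
    for (int i = 0; i < 6; ++i) {
        const float skip_scale = 1.50f + std::min(2.0f, 0.75f * float(i));
        std::printf("i = %d -> skip_scale = %.2f\n", i, skip_scale);
    }
    return 0;
}
// prints 1.50, 2.25, 3.00, 3.50, 3.50, 3.50
```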
@@ -287,6 +338,11 @@ int main(int argc, char ** argv) {
             batch_dft.n_tokens = 1;
             batch_dft.all_pos_0 = n_past_cur;
             llama_decode(ctx_dft, batch_dft);
+
+            if (self_speculation) {
+                memcpy(logits_dft.data(), llama_get_logits(ctx_dft), sizeof(float) * n_vocab);
+            }
+
             ++n_past_cur;
 
             if (grammar_dft != NULL) {
@@ -295,8 +351,16 @@ int main(int argc, char ** argv) {
         }
 
         // evaluate the target model on the drafted tokens
-        llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, -1);
-        llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0));
+        llama_kv_cache_seq_rm(ctx_tgt, 1, n_past_tgt, -1);
+        batch_tgt.n_tokens = drafted.size();
+        batch_tgt.token = drafted.data();
+        batch_tgt.all_pos_0 = n_past_tgt;
+        llama_decode(ctx_tgt, batch_tgt);
+
+        if (self_speculation) {
+            memcpy(logits_tgt.data(), llama_get_logits(ctx_tgt),
+                   sizeof(float) * n_vocab * size_t(batch_tgt.n_tokens));
+        }
         ++n_past_tgt;
 
         // the first token is always proposed by the traget model before the speculation loop
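
Note: since `params.logits_all` is set, decoding the drafted tokens on the target sequence leaves one row of logits per position in the context, and the self-speculation branch copies all `n_vocab * batch_tgt.n_tokens` of them into `logits_tgt` so they survive the next draft-side decode of the shared context. A sketch of that snapshot (helper name made up):

```cpp
#include "llama.h"

#include <cstring>
#include <vector>

// Sketch: with params.logits_all enabled, the context holds one row of logits
// per token of the last batch, so after decoding n_drafted tokens on the
// target sequence all rows can be copied out in one go.
static void save_target_logits(llama_context * ctx, int n_vocab, int n_drafted,
                               std::vector<float> & out) {
    out.resize((size_t) n_vocab * n_drafted);
    std::memcpy(out.data(), llama_get_logits(ctx), sizeof(float) * out.size());
}
```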
@@ -327,8 +391,10 @@ int main(int argc, char ** argv) {
     llama_free(ctx_tgt);
     llama_free_model(model_tgt);
 
-    llama_free(ctx_dft);
-    llama_free_model(model_dft);
+    if (!self_speculation) {
+        llama_free(ctx_dft);
+        llama_free_model(model_dft);
+    }
 
     if (grammar_dft != NULL) {
         llama_grammar_free(grammar_dft);
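
Note: teardown only has to free the draft objects when they are not aliases of the target ones; a minimal sketch of the same guard (helper name made up):

```cpp
#include "llama.h"

// Sketch: free the draft objects only when they are distinct from the target
// ones, mirroring the guard at the end of the diff.
static void free_tgt_dft(llama_model * model_tgt, llama_context * ctx_tgt,
                         llama_model * model_dft, llama_context * ctx_dft) {
    if (ctx_dft != ctx_tgt) {
        llama_free(ctx_dft);
        llama_free_model(model_dft);
    }
    llama_free(ctx_tgt);
    llama_free_model(model_tgt);
}
```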
