Skip to content

Commit 05ba850

Browse files
committed
server : improve prompt caching logic
1 parent: cf7dd4b — commit: 05ba850

File tree

1 file changed

+45
-35
lines changed

1 file changed

+45
-35
lines changed

tools/server/server.cpp

Lines changed: 45 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99
#include "sampling.h"
1010
#include "speculative.h"
1111
#include "mtmd.h"
12-
#include "mtmd-helper.h"
1312

1413
// mime type for sending response
1514
#define MIMETYPE_JSON "application/json; charset=utf-8"
@@ -1439,6 +1438,9 @@ struct server_prompt_cache {
14391438
// in bytes, 0 = no limit
14401439
size_t limit_size = 2ull*1024*1024*1024;
14411440

1441+
// in tokens, 0 = no limit
1442+
size_t limit_tokens = 0;
1443+
14421444
size_t size() const {
14431445
size_t res = 0;
14441446

@@ -1449,15 +1451,51 @@ struct server_prompt_cache {
14491451
return res;
14501452
}
14511453

1452-
int n_tokens() const {
1453-
int res = 0;
1454+
size_t n_tokens() const {
1455+
size_t res = 0;
14541456

14551457
for (const auto & state : states) {
14561458
res += state.n_tokens();
14571459
}
14581460

14591461
return res;
14601462
}
1463+
1464+
void update() {
1465+
// always keep at least one state, regardless of the limits
1466+
if (states.size() > 1) {
1467+
if (limit_size > 0) {
1468+
while (size() > limit_size) {
1469+
if (states.empty()) {
1470+
break;
1471+
}
1472+
1473+
SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));
1474+
1475+
states.pop_front();
1476+
}
1477+
}
1478+
1479+
if (limit_tokens > 0) {
1480+
while (n_tokens() > limit_tokens) {
1481+
if (states.empty()) {
1482+
break;
1483+
}
1484+
1485+
SRV_WRN(" - cache token limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));
1486+
1487+
states.pop_front();
1488+
}
1489+
}
1490+
}
1491+
1492+
SRV_WRN(" - cache state: %zu prompts, %.3f MiB, limits: %.3f MiB, %zu tokens\n",
1493+
states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens);
1494+
1495+
for (const auto & state : states) {
1496+
SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %.3f MiB\n", (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
1497+
}
1498+
}
14611499
};
14621500

14631501
struct server_slot {
@@ -1805,7 +1843,7 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) {
18051843
const int len = cached_prompt.get_common_prefix(prompt.tokens);
18061844

18071845
if (len == (int) cached_prompt.size()) {
1808-
SRV_WRN(" - removing cached prompt with length %d\n", len);
1846+
SRV_WRN(" - removing obsolete cached prompt with length %d\n", len);
18091847

18101848
it = states.erase(it);
18111849
} else {
@@ -1815,33 +1853,9 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) {
18151853

18161854
const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0);
18171855

1818-
SRV_WRN(" - saving prompt with length %d, total cache size = %.3f MiB\n",
1856+
SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n",
18191857
(int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
18201858

1821-
// if there is a limit, remove the oldest entries to make room
1822-
if (prompt_cache.limit_size > 0) {
1823-
while (prompt_cache.size() + cur_size > prompt_cache.limit_size) {
1824-
if (states.empty()) {
1825-
break;
1826-
}
1827-
1828-
SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));
1829-
1830-
states.pop_front();
1831-
}
1832-
} else {
1833-
// else, make sure the number of cached tokens doesn't exceed the context size of the slot
1834-
while (prompt_cache.n_tokens() + (int) prompt.tokens.size() > n_ctx) {
1835-
if (states.empty()) {
1836-
break;
1837-
}
1838-
1839-
SRV_WRN(" - cache token limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));
1840-
1841-
states.pop_front();
1842-
}
1843-
}
1844-
18451859
// TODO: for some reason we can't copy server_tokens, so we have to do this workaround
18461860
auto & cur = states.emplace_back();
18471861
cur = {
@@ -1851,12 +1865,6 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) {
18511865
};
18521866

18531867
llama_state_seq_get_data_ext(ctx, cur.data.data(), cur_size, id, 0);
1854-
1855-
SRV_WRN(" - cache state: %zu prompts, %.3f MiB\n", states.size(), prompt_cache.size() / (1024.0 * 1024.0));
1856-
1857-
for (const auto & state : states) {
1858-
SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %.3f MiB\n", (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
1859-
}
18601868
}
18611869

18621870
void server_slot::prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
@@ -2611,6 +2619,8 @@ struct server_context {
26112619
ret->prompt_save(prompt_cache);
26122620
ret->prompt_load(prompt_cache, task.tokens);
26132621

2622+
prompt_cache.update();
2623+
26142624
SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
26152625
}
26162626
}

0 commit comments

Comments (0)