#include "sampling.h"
#include "speculative.h"
#include "mtmd.h"

// mime type for sending response
#define MIMETYPE_JSON "application/json; charset=utf-8"
@@ -1439,6 +1438,9 @@ struct server_prompt_cache {
14391438 // in bytes, 0 = no limit
14401439 size_t limit_size = 2ull *1024 *1024 *1024 ;
14411440
1441+ // in tokens, 0 = no limit
1442+ size_t limit_tokens = 0 ;
1443+
14421444 size_t size () const {
14431445 size_t res = 0 ;
14441446
@@ -1449,15 +1451,51 @@ struct server_prompt_cache {
14491451 return res;
14501452 }
14511453
1452- int n_tokens () const {
1453- int res = 0 ;
1454+ size_t n_tokens () const {
1455+ size_t res = 0 ;
14541456
14551457 for (const auto & state : states) {
14561458 res += state.n_tokens ();
14571459 }
14581460
14591461 return res;
14601462 }
1463+
1464+ void update () {
1465+ // always keep at least one state, regardless of the limits
1466+ if (states.size () > 1 ) {
1467+ if (limit_size > 0 ) {
1468+ while (size () > limit_size) {
1469+ if (states.empty ()) {
1470+ break ;
1471+ }
1472+
1473+ SRV_WRN (" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n " , states.front ().size () / (1024.0 * 1024.0 ));
1474+
1475+ states.pop_front ();
1476+ }
1477+ }
1478+
1479+ if (limit_tokens > 0 ) {
1480+ while (n_tokens () > limit_tokens) {
1481+ if (states.empty ()) {
1482+ break ;
1483+ }
1484+
1485+ SRV_WRN (" - cache token limit reached, removing oldest entry (size = %.3f MiB)\n " , states.front ().size () / (1024.0 * 1024.0 ));
1486+
1487+ states.pop_front ();
1488+ }
1489+ }
1490+ }
1491+
1492+ SRV_WRN (" - cache state: %zu prompts, %.3f MiB, limits: %.3f MiB, %zu tokens\n " ,
1493+ states.size (), size () / (1024.0 * 1024.0 ), limit_size / (1024.0 * 1024.0 ), limit_tokens);
1494+
1495+ for (const auto & state : states) {
1496+ SRV_WRN (" - prompt %p: %7d tokens, checkpoints: %2zu, %.3f MiB\n " , (const void *)&state, state.n_tokens (), state.checkpoints .size (), state.size () / (1024.0 * 1024.0 ));
1497+ }
1498+ }
14611499};
14621500
14631501struct server_slot {
@@ -1805,7 +1843,7 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) {
18051843 const int len = cached_prompt.get_common_prefix (prompt.tokens );
18061844
18071845 if (len == (int ) cached_prompt.size ()) {
1808- SRV_WRN (" - removing cached prompt with length %d\n " , len);
1846+ SRV_WRN (" - removing obsolete cached prompt with length %d\n " , len);
18091847
18101848 it = states.erase (it);
18111849 } else {
@@ -1815,33 +1853,9 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) {
18151853
18161854 const size_t cur_size = llama_state_seq_get_size_ext (ctx, id, 0 );
18171855
1818- SRV_WRN (" - saving prompt with length %d, total cache size = %.3f MiB\n " ,
1856+ SRV_WRN (" - saving prompt with length %d, total state size = %.3f MiB\n " ,
18191857 (int ) prompt.tokens .size (), cur_size / (1024.0 * 1024.0 ));
18201858
1821- // if there is a limit, remove the oldest entries to make room
1822- if (prompt_cache.limit_size > 0 ) {
1823- while (prompt_cache.size () + cur_size > prompt_cache.limit_size ) {
1824- if (states.empty ()) {
1825- break ;
1826- }
1827-
1828- SRV_WRN (" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n " , states.front ().size () / (1024.0 * 1024.0 ));
1829-
1830- states.pop_front ();
1831- }
1832- } else {
1833- // else, make sure the number of cached tokens doesn't exceed the context size of the slot
1834- while (prompt_cache.n_tokens () + (int ) prompt.tokens .size () > n_ctx) {
1835- if (states.empty ()) {
1836- break ;
1837- }
1838-
1839- SRV_WRN (" - cache token limit reached, removing oldest entry (size = %.3f MiB)\n " , states.front ().size () / (1024.0 * 1024.0 ));
1840-
1841- states.pop_front ();
1842- }
1843- }
1844-
18451859 // TODO: for some reason we can't copy server_tokens, so we have to do this workaround
18461860 auto & cur = states.emplace_back ();
18471861 cur = {
@@ -1851,12 +1865,6 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) {
18511865 };
18521866
18531867 llama_state_seq_get_data_ext (ctx, cur.data .data (), cur_size, id, 0 );
1854-
1855- SRV_WRN (" - cache state: %zu prompts, %.3f MiB\n " , states.size (), prompt_cache.size () / (1024.0 * 1024.0 ));
1856-
1857- for (const auto & state : states) {
1858- SRV_WRN (" - prompt %p: %7d tokens, checkpoints: %2zu, %.3f MiB\n " , (const void *)&state, state.n_tokens (), state.checkpoints .size (), state.size () / (1024.0 * 1024.0 ));
1859- }
18601868}
18611869
18621870void server_slot::prompt_load (server_prompt_cache & prompt_cache, const server_tokens & tokens) {
@@ -2611,6 +2619,8 @@ struct server_context {
26112619 ret->prompt_save (prompt_cache);
26122620 ret->prompt_load (prompt_cache, task.tokens );
26132621
2622+ prompt_cache.update ();
2623+
26142624 SRV_WRN (" prompt cache update took %.2f ms\n " , (ggml_time_us () - t_start) / 1000.0 );
26152625 }
26162626 }
0 commit comments