3 files changed (+29 −4 lines)
File 1 of 3: llama_kv_cache implementation

@@ -136,7 +136,10 @@ class llama_kv_cache_unified_state : public llama_kv_cache_unified_state_i {
     std::vector<uint32_t> heads;
     std::vector<llama_ubatch> ubatches;

+    //
     // data needed for building the compute graph for the current ubatch:
+    //
+
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // as the cache gets filled, the benefit from this heuristic disappears
     int32_t n_kv;
@@ -1876,7 +1879,10 @@ class llama_kv_cache_unified_iswa_state : public llama_kv_cache_unified_iswa_state_i {

     std::vector<llama_ubatch> ubatches;

+    //
     // data needed for building the compute graph for the current ubatch:
+    //
+
     int32_t n_kv_base;
     int32_t head_base;

@@ -2123,7 +2129,7 @@ class llama_kv_cache_recurrent_state_t : public llama_kv_cache_recurrent_state_i {
         return kv->s_copy(i);
     }

-    float s_mask(int i) const override {
+    float s_mask(int i) const override {
         return kv->s_mask(i);
     }

@@ -2132,13 +2138,18 @@ class llama_kv_cache_recurrent_state_t : public llama_kv_cache_recurrent_state_i {

     llama_kv_cache_recurrent * kv;

-    const bool is_full = false;
-
     llama_sbatch sbatch;

     size_t i_next = 0;

     std::vector<llama_ubatch> ubatches;
+
+    //
+    // data needed for building the compute graph for the current ubatch:
+    // TODO: extract all the state like `head` and `n` here
+    //
+
+    const bool is_full = false;
 };

 llama_kv_cache_recurrent::llama_kv_cache_recurrent(
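
The comment markers added in these hunks all delimit the same thing: the part of a *_state object that is a per-ubatch snapshot consumed by graph construction, as opposed to bookkeeping about the cache itself. A minimal sketch of the pattern, using hypothetical types (simple_cache and simple_cache_state are illustrations, not llama.cpp classes):

#include <cstdint>

// hypothetical stand-ins for the real llama.cpp types
struct simple_cache {
    int32_t head = 0; // next write position in the cache
    int32_t used = 0; // number of cells currently occupied
};

struct simple_cache_state {
    const simple_cache * kv;

    //
    // data needed for building the compute graph for the current ubatch:
    //

    // snapshot taken when the state is created; graph construction reads
    // these fields instead of reaching into the (possibly changing) cache
    int32_t n_kv;
    int32_t head;

    explicit simple_cache_state(const simple_cache & c)
        : kv(&c), n_kv(c.used), head(c.head) {}
};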
File 2 of 3: llama_kv_cache header

@@ -40,6 +40,9 @@ struct llama_kv_cache : public llama_memory_i {
     virtual bool update(llama_context & lctx) = 0;

     // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
+    // TODO: change to
+    //   llama_memory_state_ptr init_defrag(float thold) = 0;
+    //
     virtual void defrag_sched(float thold) = 0;

     // getters
@@ -253,7 +256,7 @@ class llama_kv_cache_unified_state_i : public llama_memory_state_i {
     virtual ggml_tensor * get_k(ggml_context * ctx, int32_t il) const = 0;
     virtual ggml_tensor * get_v(ggml_context * ctx, int32_t il) const = 0;

-    // store k_cur and v_cur in the cache based on the current head location
+    // store k_cur and v_cur in the cache based on the provided head location
     virtual ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const = 0;
     virtual ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const = 0;

@@ -359,6 +362,8 @@ class llama_kv_cache_unified_iswa_state_i : public llama_memory_state_i {
 // llama_kv_cache_recurrent
 //

+// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
+// see the implementation of llama_kv_cache_unified_state_i for an example of how to do it
 class llama_kv_cache_recurrent : public llama_kv_cache {
 public:
     llama_kv_cache_recurrent(
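
For context on the TODO in the first hunk above: init_defrag(thold) would return a memory state object, letting defragmentation go through the same prepare-then-apply flow as batch processing instead of being a fire-and-forget side effect of defrag_sched(). A hedged sketch of a possible call site under that proposal (init_defrag is not implemented in this diff; the get_status() accessor and the LLAMA_MEMORY_STATUS_SUCCESS value are assumed names):

// sketch only: init_defrag() does not exist yet (see the TODO above);
// get_status() and LLAMA_MEMORY_STATUS_SUCCESS are assumed names
llama_memory_state_ptr state = kv->init_defrag(thold);

if (state && state->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
    state->apply(); // per the documented contract, the only call that may mutate the memory
}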
File 3 of 3: llama_memory header

@@ -42,6 +42,15 @@ enum llama_memory_status {
     LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
 };

+// the interface for managing the memory state during batch processing
+// this interface is extended per memory type with specific methods used for constructing the compute graphs. see:
+//   - llama_kv_cache_unified_state_i
+//   - llama_kv_cache_unified_iswa_state_i
+//   ...
+//
+// these extended interfaces should mutate neither the memory nor the current memory state
+// the only method that can mutate the memory and the memory state is llama_memory_i::apply()
+//
 class llama_memory_state_i {
 public:
     virtual ~llama_memory_state_i() = default;
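
To make the documented contract concrete: a per-memory-type extension adds only const accessors for use during graph construction, and all mutation stays behind apply(). A minimal self-contained sketch with hypothetical names (memory_state_i and unified_state_i here are illustrations; the real examples are the llama_kv_cache_*_state_i interfaces above):

#include <cstdint>

// hypothetical base, mirroring the role of llama_memory_state_i
class memory_state_i {
public:
    virtual ~memory_state_i() = default;

    // the only method allowed to mutate the memory / the memory state
    virtual bool apply() = 0;
};

// hypothetical per-memory-type extension: read-only helpers used
// while constructing the compute graph for the current ubatch
class unified_state_i : public memory_state_i {
public:
    virtual int32_t get_n_kv() const = 0; // const: must not mutate anything
};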