Skip to content

Commit a2934a1

Browse files
committed
perf: add caching for embeddings
Add caching mechanism for embeddings to avoid re-computing them for the same content. This improves performance by storing and reusing previously computed embeddings based on filename and content hash. Signed-off-by: Tomas Slusny <[email protected]>
1 parent d177720 commit a2934a1

File tree

2 files changed

+43
-18
lines changed

2 files changed

+43
-18
lines changed

lua/CopilotChat/context.lua

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ local off_side_rule_languages = {
4141
}
4242

4343
local big_file_threshold = 500
44-
local multi_file_threshold = 3
44+
local multi_file_threshold = 2
4545

4646
local function spatial_distance_cosine(a, b)
4747
local dot_product = 0
@@ -269,25 +269,19 @@ end
269269
---@param embeddings table<CopilotChat.copilot.embed>
270270
---@return table<CopilotChat.copilot.embed>
271271
function M.filter_embeddings(copilot, embeddings)
272-
-- If there is only query embedding or we are under the threshold, return embeddings without query
272+
-- If we dont need to embed anything, just return the embeddings without query
273273
if #embeddings <= (1 + multi_file_threshold) then
274274
table.remove(embeddings, 1)
275275
return embeddings
276276
end
277277

278+
-- Get embeddings
278279
local out = copilot:embed(embeddings)
279-
if #out <= 1 then
280-
return {}
281-
end
282-
283280
log.debug(string.format('Got %s embeddings', #out))
284281

285-
local query = table.remove(out, 1)
286-
log.debug('Query Prompt:', query.prompt)
287-
288-
local data = data_ranked_by_relatedness(query, out, 20)
282+
-- Rate embeddings by relatedness to the query
283+
local data = data_ranked_by_relatedness(table.remove(out, 1), out, 20)
289284
log.debug('Ranked data:', #data)
290-
291285
for i, item in ipairs(data) do
292286
log.debug(string.format('%s: %s - %s', i, item.score, item.filename))
293287
end

lua/CopilotChat/copilot.lua

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,10 @@ local function machine_id()
104104
return hex
105105
end
106106

107+
local function quick_hash(str)
108+
return #str .. str:sub(1, 32) .. str:sub(-32)
109+
end
110+
107111
local function find_config_path()
108112
local config = vim.fn.expand('$XDG_CONFIG_HOME')
109113
if config and vim.fn.isdirectory(config) > 0 then
@@ -346,6 +350,7 @@ end
346350

347351
local Copilot = class(function(self, proxy, allow_insecure)
348352
self.history = {}
353+
self.embedding_cache = {}
349354
self.github_token = nil
350355
self.token = nil
351356
self.sessionid = nil
@@ -881,18 +886,34 @@ end
881886
---@param inputs table<CopilotChat.copilot.embed>: The inputs to embed
882887
---@param opts CopilotChat.copilot.embed.opts: Options for the request
883888
function Copilot:embed(inputs, opts)
884-
opts = opts or {}
885-
local model = opts.model or 'text-embedding-3-small'
886-
local chunk_size = opts.chunk_size or 15
887-
888889
if not inputs or #inputs == 0 then
889890
return {}
890891
end
891892

893+
-- Check which embeddings need to be fetched
894+
local cached_embeddings = {}
895+
local uncached_embeddings = {}
896+
for _, embed in ipairs(inputs) do
897+
if embed.content then
898+
local key = embed.filename .. quick_hash(embed.content)
899+
if self.embedding_cache[key] then
900+
table.insert(cached_embeddings, self.embedding_cache[key])
901+
else
902+
table.insert(uncached_embeddings, embed)
903+
end
904+
else
905+
table.insert(uncached_embeddings, embed)
906+
end
907+
end
908+
909+
opts = opts or {}
910+
local model = opts.model or 'text-embedding-3-small'
911+
local chunk_size = opts.chunk_size or 15
912+
892913
local out = {}
893914

894-
for i = 1, #inputs, chunk_size do
895-
local chunk = vim.list_slice(inputs, i, i + chunk_size - 1)
915+
for i = 1, #uncached_embeddings, chunk_size do
916+
local chunk = vim.list_slice(uncached_embeddings, i, i + chunk_size - 1)
896917
local body = vim.json.encode(generate_embedding_request(chunk, model))
897918
local response, err = curl_post(
898919
'https://api.githubcopilot.com/embeddings',
@@ -934,7 +955,16 @@ function Copilot:embed(inputs, opts)
934955
end
935956
end
936957

937-
return out
958+
-- Cache embeddings
959+
for _, embedding in ipairs(out) do
960+
if embedding.content then
961+
local key = embedding.filename .. quick_hash(embedding.content)
962+
self.embedding_cache[key] = embedding
963+
end
964+
end
965+
966+
-- Merge cached embeddings and newly fetched embeddings and return
967+
return vim.list_extend(out, cached_embeddings)
938968
end
939969

940970
--- Stop the running job
@@ -951,6 +981,7 @@ end
951981
function Copilot:reset()
952982
local stopped = self:stop()
953983
self.history = {}
984+
self.embedding_cache = {}
954985
return stopped
955986
end
956987

0 commit comments

Comments
 (0)