Commit a8d1107

dbczumar authored
Integrate cachetools for in-memory LM caching, including unhashable types & pydantic (#1896)
* Impl
* Cachetools add
* Inline
* tweak
* fix
* fix
* Update lm.py

Signed-off-by: dbczumar <[email protected]>
1 parent 5365fe9 commit a8d1107

File tree

6 files changed: +142 −25 lines changed

dspy/clients/lm.py
poetry.lock
pyproject.toml
requirements.txt
tests/caching/test_caching.py
tests/clients/test_lm.py
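Taken together with the tests added below, the change means LM kwargs that are unhashable (callables) or pydantic models no longer break the cached inference path. A condensed caller-side sketch, assuming an OpenAI-compatible endpoint; the model name, api_base, and api_key are placeholder values mirroring the test fixtures:

import pydantic

import dspy


class ResponseFormat(pydantic.BaseModel):
    response: str


lm = dspy.LM(
    model="openai/dspy-test-model",     # placeholder test model
    api_base="http://localhost:12345",  # assumed: any OpenAI-compatible endpoint
    api_key="fakekey",
    response_format=ResponseFormat,                         # pydantic model kwarg
    azure_ad_token_provider=lambda *args, **kwargs: None,   # unhashable callable kwarg
)
lm("Example query")  # first call reaches the endpoint
lm("Example query")  # repeated call can be answered from the in-memory LRU cache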

dspy/clients/lm.py

Lines changed: 59 additions & 18 deletions
@@ -4,10 +4,13 @@
 import threading
 import uuid
 from datetime import datetime
+from hashlib import sha256
 from typing import Any, Dict, List, Literal, Optional
 
 import litellm
+import pydantic
 import ujson
+from cachetools import LRUCache, cached
 
 from dspy.adapters.base import Adapter
 from dspy.clients.openai import OpenAIProvider
@@ -92,7 +95,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         completion = cached_litellm_text_completion if cache else litellm_text_completion
 
         response = completion(
-            request=ujson.dumps(dict(model=self.model, messages=messages, **kwargs)),
+            request=dict(model=self.model, messages=messages, **kwargs),
             num_retries=self.num_retries,
         )
         outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
@@ -153,7 +156,11 @@ def thread_function_wrapper():
         thread = threading.Thread(target=thread_function_wrapper)
         model_to_finetune = self.finetuning_model or self.model
         job = self.provider.TrainingJob(
-            thread=thread, model=model_to_finetune, train_data=train_data, train_kwargs=train_kwargs, data_format=data_format
+            thread=thread,
+            model=model_to_finetune,
+            train_data=train_data,
+            train_kwargs=train_kwargs,
+            data_format=data_format,
         )
         thread.start()
 
@@ -212,47 +219,81 @@ def copy(self, **kwargs):
         return new_instance
 
 
-@functools.lru_cache(maxsize=None)
-def cached_litellm_completion(request, num_retries: int):
+def request_cache(maxsize: Optional[int] = None):
+    """
+    A threadsafe decorator to create an in-memory LRU cache for LM inference functions that accept
+    a dictionary-like LM request. An in-memory cache for LM calls is critical for ensuring
+    good performance when optimizing and evaluating DSPy LMs (disk caching alone is too slow).
+
+    Args:
+        maxsize: The maximum size of the cache. If unspecified, no max size is enforced (cache is unbounded).
+
+    Returns:
+        A decorator that wraps the target function with caching.
+    """
+
+    def cache_key(request: Dict[str, Any]) -> str:
+        # Transform Pydantic models into JSON-convertible format and exclude unhashable objects
+        params = {k: (v.dict() if isinstance(v, pydantic.BaseModel) else v) for k, v in request.items()}
+        params = {k: v for k, v in params.items() if not callable(v)}
+        return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest()
+
+    def decorator(func):
+        @cached(
+            # NB: cachetools doesn't support maxsize=None; it recommends using float("inf") instead
+            cache=LRUCache(maxsize=maxsize or float("inf")),
+            key=lambda request, *args, **kwargs: cache_key(request),
+            # Use a lock to ensure thread safety for the cache when DSPy LMs are queried
+            # concurrently, e.g. during optimization and evaluation
+            lock=threading.Lock(),
+        )
+        @functools.wraps(func)
+        def wrapper(request: dict, *args, **kwargs):
+            return func(request, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+@request_cache(maxsize=None)
+def cached_litellm_completion(request: Dict[str, Any], num_retries: int):
     return litellm_completion(
         request,
         cache={"no-cache": False, "no-store": False},
         num_retries=num_retries,
     )
 
 
-def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
-    kwargs = ujson.loads(request)
+def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
     return litellm.completion(
         num_retries=num_retries,
         cache=cache,
-        **kwargs,
+        **request,
     )
 
 
-@functools.lru_cache(maxsize=None)
-def cached_litellm_text_completion(request, num_retries: int):
+@request_cache(maxsize=None)
+def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int):
     return litellm_text_completion(
         request,
         num_retries=num_retries,
         cache={"no-cache": False, "no-store": False},
     )
 
 
-def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
-    kwargs = ujson.loads(request)
-
+def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
     # Extract the provider and model from the model string.
     # TODO: Not all the models are in the format of "provider/model"
-    model = kwargs.pop("model").split("/", 1)
+    model = request.pop("model").split("/", 1)
     provider, model = model[0] if len(model) > 1 else "openai", model[-1]
 
-    # Use the API key and base from the kwargs, or from the environment.
-    api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
-    api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
+    # Use the API key and base from the request, or from the environment.
+    api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
+    api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
 
     # Build the prompt from the messages.
-    prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])
+    prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"])
 
     return litellm.text_completion(
         cache=cache,
@@ -261,5 +302,5 @@ def litellm_text_completion(request, num_retries: int, cache={"no-cache": True,
         api_base=api_base,
         prompt=prompt,
         num_retries=num_retries,
-        **kwargs,
+        **request,
     )
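A brief usage sketch of the new request_cache decorator (the fake_completion function and request contents here are hypothetical, not part of the commit): the cache key is derived from the request dict alone, with pydantic values serialized via .dict() and callable values dropped, so calls with equal request contents share one cache entry even across distinct dict objects.

from dspy.clients.lm import request_cache  # module path per the diff above

calls = {"n": 0}


@request_cache(maxsize=None)
def fake_completion(request, num_retries: int):
    # Hypothetical stand-in for an LM call; counts how often the body actually runs.
    calls["n"] += 1
    return {"echo": request["messages"][-1]["content"]}


req = {"model": "openai/dspy-test-model", "messages": [{"role": "user", "content": "hi"}]}
fake_completion(req, num_retries=3)
fake_completion(dict(req), num_retries=3)  # equal contents, new dict object -> cache hit
assert calls["n"] == 1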

poetry.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,7 @@ dependencies = [
     "tenacity>=8.2.3",
     "anyio",
     "asyncer==0.0.8",
+    "cachetools",
 ]
 
 [project.optional-dependencies]
@@ -138,6 +139,7 @@ falkordb = "^1.0.9"
 json-repair = "^0.30.0"
 tenacity = ">=8.2.3"
 asyncer = "0.0.8"
+cachetools = "^5.5.0"
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^8.3.3"

requirements.txt

Lines changed: 3 additions & 2 deletions
@@ -1,4 +1,7 @@
+anyio
+asyncer==0.0.8
 backoff
+cachetools
 datasets
 diskcache
 httpx
@@ -15,5 +18,3 @@ requests
 tenacity>=8.2.3
 tqdm
 ujson
-anyio
-asyncer==0.0.8

tests/caching/test_caching.py

Lines changed: 20 additions & 0 deletions
@@ -88,3 +88,23 @@ def test_lm_calls_are_cached_across_interpreter_sessions(litellm_test_server, te
 
     request_logs = read_litellm_test_server_request_logs(server_log_file_path)
     assert len(request_logs) == 0
+
+
+def test_lm_calls_are_cached_in_memory_when_expected(litellm_test_server, temporary_blank_cache_dir):
+    api_base, server_log_file_path = litellm_test_server
+
+    lm1 = dspy.LM(
+        model="openai/dspy-test-model",
+        api_base=api_base,
+        api_key="fakekey",
+    )
+    lm1("Example query")
+    # Remove the disk cache, after which the LM must rely on in-memory caching
+    shutil.rmtree(temporary_blank_cache_dir)
+    lm1("Example query2")
+    lm1("Example query2")
+    lm1("Example query2")
+    lm1("Example query2")
+
+    request_logs = read_litellm_test_server_request_logs(server_log_file_path)
+    assert len(request_logs) == 2

tests/clients/test_lm.py

Lines changed: 56 additions & 3 deletions
@@ -1,24 +1,77 @@
 from unittest import mock
 
+import pydantic
 import pytest
 
 import dspy
 from tests.test_utils.server import litellm_test_server
 
 
-def test_lms_can_be_queried(litellm_test_server):
+def test_chat_lms_can_be_queried(litellm_test_server):
     api_base, _ = litellm_test_server
+    expected_response = ["Hi!"]
 
     openai_lm = dspy.LM(
         model="openai/dspy-test-model",
         api_base=api_base,
         api_key="fakekey",
+        model_type="chat",
     )
-    openai_lm("openai query")
+    assert openai_lm("openai query") == expected_response
 
     azure_openai_lm = dspy.LM(
         model="azure/dspy-test-model",
         api_base=api_base,
         api_key="fakekey",
+        model_type="chat",
     )
-    azure_openai_lm("azure openai query")
+    assert azure_openai_lm("azure openai query") == expected_response
+
+
+def test_text_lms_can_be_queried(litellm_test_server):
+    api_base, _ = litellm_test_server
+    expected_response = ["Hi!"]
+
+    openai_lm = dspy.LM(
+        model="openai/dspy-test-model",
+        api_base=api_base,
+        api_key="fakekey",
+        model_type="text",
+    )
+    assert openai_lm("openai query") == expected_response
+
+    azure_openai_lm = dspy.LM(
+        model="azure/dspy-test-model",
+        api_base=api_base,
+        api_key="fakekey",
+        model_type="text",
+    )
+    assert azure_openai_lm("azure openai query") == expected_response
+
+
+def test_lm_calls_support_unhashable_types(litellm_test_server):
+    api_base, server_log_file_path = litellm_test_server
+
+    lm_with_unhashable_callable = dspy.LM(
+        model="openai/dspy-test-model",
+        api_base=api_base,
+        api_key="fakekey",
+        # Define a callable kwarg for the LM to use during inference
+        azure_ad_token_provider=lambda *args, **kwargs: None,
+    )
+    lm_with_unhashable_callable("Query")
+
+
+def test_lm_calls_support_pydantic_models(litellm_test_server):
+    api_base, server_log_file_path = litellm_test_server
+
+    class ResponseFormat(pydantic.BaseModel):
+        response: str
+
+    lm = dspy.LM(
+        model="openai/dspy-test-model",
+        api_base=api_base,
+        api_key="fakekey",
+        response_format=ResponseFormat,
+    )
+    lm("Query")
