diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index 4dffc17e1c..737de772ab 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -4,10 +4,13 @@
 import threading
 import uuid
 from datetime import datetime
+from hashlib import sha256
 from typing import Any, Dict, List, Literal, Optional
 
 import litellm
+import pydantic
 import ujson
+from cachetools import LRUCache, cached
 
 from dspy.adapters.base import Adapter
 from dspy.clients.openai import OpenAIProvider
@@ -92,7 +95,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
             completion = cached_litellm_text_completion if cache else litellm_text_completion
 
         response = completion(
-            request=ujson.dumps(dict(model=self.model, messages=messages, **kwargs)),
+            request=dict(model=self.model, messages=messages, **kwargs),
             num_retries=self.num_retries,
         )
         outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
@@ -153,7 +156,11 @@ def thread_function_wrapper():
         thread = threading.Thread(target=thread_function_wrapper)
         model_to_finetune = self.finetuning_model or self.model
         job = self.provider.TrainingJob(
-            thread=thread, model=model_to_finetune, train_data=train_data, train_kwargs=train_kwargs, data_format=data_format
+            thread=thread,
+            model=model_to_finetune,
+            train_data=train_data,
+            train_kwargs=train_kwargs,
+            data_format=data_format,
         )
         thread.start()
 
@@ -212,8 +219,45 @@ def copy(self, **kwargs):
         return new_instance
 
 
-@functools.lru_cache(maxsize=None)
-def cached_litellm_completion(request, num_retries: int):
+def request_cache(maxsize: Optional[int] = None):
+    """
+    A threadsafe decorator to create an in-memory LRU cache for LM inference functions that accept
+    a dictionary-like LM request. An in-memory cache for LM calls is critical for ensuring
+    good performance when optimizing and evaluating DSPy LMs (disk caching alone is too slow).
+
+    Args:
+        maxsize: The maximum size of the cache. If unspecified, no max size is enforced (cache is unbounded).
+
+    Returns:
+        A decorator that wraps the target function with caching.
+    """
+
+    def cache_key(request: Dict[str, Any]) -> str:
+        # Transform Pydantic models into JSON-convertible format and exclude unhashable objects
+        params = {k: (v.dict() if isinstance(v, pydantic.BaseModel) else v) for k, v in request.items()}
+        params = {k: v for k, v in params.items() if not callable(v)}
+        return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest()
+
+    def decorator(func):
+        @cached(
+            # NB: cachetools doesn't support maxsize=None; it recommends using float("inf") instead
+            cache=LRUCache(maxsize=maxsize or float("inf")),
+            key=lambda request, *args, **kwargs: cache_key(request),
+            # Use a lock to ensure thread safety for the cache when DSPy LMs are queried
+            # concurrently, e.g. during optimization and evaluation
+            lock=threading.Lock(),
+        )
+        @functools.wraps(func)
+        def wrapper(request: dict, *args, **kwargs):
+            return func(request, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+@request_cache(maxsize=None)
+def cached_litellm_completion(request: Dict[str, Any], num_retries: int):
     return litellm_completion(
         request,
         cache={"no-cache": False, "no-store": False},
@@ -221,17 +265,16 @@ def cached_litellm_completion(request: Dict[str, Any], num_retries: int):
     )
 
 
-def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
-    kwargs = ujson.loads(request)
+def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
     return litellm.completion(
         num_retries=num_retries,
         cache=cache,
-        **kwargs,
+        **request,
     )
 
 
-@functools.lru_cache(maxsize=None)
-def cached_litellm_text_completion(request, num_retries: int):
+@request_cache(maxsize=None)
+def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int):
     return litellm_text_completion(
         request,
         num_retries=num_retries,
@@ -239,20 +282,18 @@ def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int):
     )
 
 
-def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
-    kwargs = ujson.loads(request)
-
+def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
     # Extract the provider and model from the model string.
     # TODO: Not all the models are in the format of "provider/model"
-    model = kwargs.pop("model").split("/", 1)
+    model = request.pop("model").split("/", 1)
     provider, model = model[0] if len(model) > 1 else "openai", model[-1]
 
-    # Use the API key and base from the kwargs, or from the environment.
-    api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
-    api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
+    # Use the API key and base from the request, or from the environment.
+    api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
+    api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
 
     # Build the prompt from the messages.
-    prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])
+    prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"])
 
     return litellm.text_completion(
         cache=cache,
@@ -261,5 +302,5 @@ def litellm_text_completion(request, num_retries: int, cache={"no-cache": True,
         api_base=api_base,
         prompt=prompt,
         num_retries=num_retries,
-        **kwargs,
+        **request,
     )
diff --git a/poetry.lock b/poetry.lock
index af67dfb630..2e13c532bb 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -579,7 +579,7 @@ virtualenv = ["virtualenv (>=20.0.35)"]
 name = "cachetools"
 version = "5.5.0"
 description = "Extensible memoizing collections and decorators"
-optional = true
+optional = false
 python-versions = ">=3.7"
 files = [
     {file = "cachetools-5.5.0-py3-none-any.whl", hash = "sha256:02134e8439cdc2ffb62023ce1debca2944c3f289d66bb17ead3ab3dede74b292"},
@@ -8523,4 +8523,4 @@ weaviate = ["weaviate-client"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<3.13"
-content-hash = "d7c6fe24c44afc0e694de8270035ac4cd9831f9bf9d6253f1cd1dbfe03ec2835"
+content-hash = "f54a12cf2b3bc3d2811e8875561b1262c8b7a546d0fb333f075a1127bc2ffcb0"
diff --git a/pyproject.toml b/pyproject.toml
index 18cc52f728..c0fe14f192 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,6 +44,7 @@ dependencies = [
     "tenacity>=8.2.3",
     "anyio",
     "asyncer==0.0.8",
+    "cachetools",
 ]
 
 [project.optional-dependencies]
@@ -138,6 +139,7 @@ falkordb = "^1.0.9"
 json-repair = "^0.30.0"
 tenacity = ">=8.2.3"
 asyncer = "0.0.8"
+cachetools = "^5.5.0"
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^8.3.3"
diff --git a/requirements.txt b/requirements.txt
index 31874b1c9a..218ca9ed7a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,7 @@
+anyio
+asyncer==0.0.8
 backoff
+cachetools
 datasets
 diskcache
 httpx
@@ -15,5 +18,3 @@ requests
 tenacity>=8.2.3
 tqdm
 ujson
-anyio
-asyncer==0.0.8
diff --git a/tests/caching/test_caching.py b/tests/caching/test_caching.py
index d5c3933742..468a81352a 100644
--- a/tests/caching/test_caching.py
+++ b/tests/caching/test_caching.py
@@ -88,3 +88,23 @@ def test_lm_calls_are_cached_across_interpreter_sessions(litellm_test_server, te
 
     request_logs = read_litellm_test_server_request_logs(server_log_file_path)
     assert len(request_logs) == 0
+
+
+def test_lm_calls_are_cached_in_memory_when_expected(litellm_test_server, temporary_blank_cache_dir):
+    api_base, server_log_file_path = litellm_test_server
+
+    lm1 = dspy.LM(
+        model="openai/dspy-test-model",
+        api_base=api_base,
+        api_key="fakekey",
+    )
+    lm1("Example query")
+    # Remove the disk cache, after which the LM must rely on in-memory caching
+    shutil.rmtree(temporary_blank_cache_dir)
+    lm1("Example query2")
+    lm1("Example query2")
+    lm1("Example query2")
+    lm1("Example query2")
+
+    request_logs = read_litellm_test_server_request_logs(server_log_file_path)
+    assert len(request_logs) == 2
diff --git a/tests/clients/test_lm.py b/tests/clients/test_lm.py
index 18c134601b..519c1e24e7 100644
--- a/tests/clients/test_lm.py
+++ b/tests/clients/test_lm.py
@@ -1,24 +1,77 @@
 from unittest import mock
 
+import pydantic
 import pytest
 
 import dspy
 from tests.test_utils.server import litellm_test_server
 
 
-def test_lms_can_be_queried(litellm_test_server):
+def test_chat_lms_can_be_queried(litellm_test_server):
     api_base, _ = litellm_test_server
+    expected_response = ["Hi!"]
 
     openai_lm = dspy.LM(
         model="openai/dspy-test-model",
         api_base=api_base,
         api_key="fakekey",
+        model_type="chat",
     )
-    openai_lm("openai query")
+    assert openai_lm("openai query") == expected_response
 
     azure_openai_lm = dspy.LM(
         model="azure/dspy-test-model",
         api_base=api_base,
         api_key="fakekey",
+        model_type="chat",
     )
-    azure_openai_lm("azure openai query")
+    assert azure_openai_lm("azure openai query") == expected_response
+
+
+def test_text_lms_can_be_queried(litellm_test_server):
+    api_base, _ = litellm_test_server
+    expected_response = ["Hi!"]
+
+    openai_lm = dspy.LM(
+        model="openai/dspy-test-model",
+        api_base=api_base,
+        api_key="fakekey",
+        model_type="text",
+    )
+    assert openai_lm("openai query") == expected_response
+
+    azure_openai_lm = dspy.LM(
+        model="azure/dspy-test-model",
+        api_base=api_base,
+        api_key="fakekey",
+        model_type="text",
+    )
+    assert azure_openai_lm("azure openai query") == expected_response
+
+
+def test_lm_calls_support_unhashable_types(litellm_test_server):
+    api_base, server_log_file_path = litellm_test_server
+
+    lm_with_unhashable_callable = dspy.LM(
+        model="openai/dspy-test-model",
+        api_base=api_base,
+        api_key="fakekey",
+        # Define a callable kwarg for the LM to use during inference
+        azure_ad_token_provider=lambda *args, **kwargs: None,
+    )
+    lm_with_unhashable_callable("Query")
+
+
+def test_lm_calls_support_pydantic_models(litellm_test_server):
+    api_base, server_log_file_path = litellm_test_server
+
+    class ResponseFormat(pydantic.BaseModel):
+        response: str
+
+    lm = dspy.LM(
+        model="openai/dspy-test-model",
+        api_base=api_base,
+        api_key="fakekey",
+        response_format=ResponseFormat,
+    )
+    lm("Query")
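Reviewer note (not part of the patch): previously the request was serialized with `ujson.dumps` before hitting `functools.lru_cache`, so callable kwargs (e.g. `azure_ad_token_provider`) or a pydantic `response_format` could not be cached at all. The new `request_cache` decorator instead keys a thread-safe in-memory LRU on a SHA-256 digest of a JSON-serializable view of the request dict. The sketch below mirrors the `cache_key` helper from the diff to show that behavior; the request values are invented examples, and `ResponseFormat` / `azure_ad_token_provider` simply echo the names used in the new tests.

```python
from hashlib import sha256
from typing import Any, Dict

import pydantic
import ujson


def cache_key(request: Dict[str, Any]) -> str:
    """Mirror of the helper inside request_cache: hash a canonical JSON view of the request."""
    # Pydantic model instances become plain dicts; anything callable (lambdas, classes) is dropped.
    params = {k: (v.dict() if isinstance(v, pydantic.BaseModel) else v) for k, v in request.items()}
    params = {k: v for k, v in params.items() if not callable(v)}
    return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest()


class ResponseFormat(pydantic.BaseModel):
    response: str


base = {"model": "openai/dspy-test-model", "messages": [{"role": "user", "content": "Query"}]}

# A callable kwarg (e.g. azure_ad_token_provider) is excluded from the key, so the request
# maps to the same in-memory cache entry as the equivalent request without it.
assert cache_key({**base, "azure_ad_token_provider": lambda *a, **kw: None}) == cache_key(base)

# A pydantic response_format class is itself callable, so it is likewise excluded from the key.
assert cache_key({**base, "response_format": ResponseFormat}) == cache_key(base)

# Genuinely different requests still produce different keys.
assert cache_key({**base, "temperature": 0.7}) != cache_key(base)
```

Because the key is a digest of this serializable view, repeated identical requests resolve to one LRU entry even when the disk cache is unavailable, which is the behavior `test_lm_calls_are_cached_in_memory_when_expected` exercises above.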