From c0e5c5d10bf44ff4a05e11673e619f99177a4b17 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Fri, 6 Dec 2024 00:59:04 -0800 Subject: [PATCH 1/7] Impl Signed-off-by: dbczumar --- dspy/clients/lm.py | 80 +++++++++++++++++++++++++++-------- tests/caching/test_caching.py | 49 +++++++++++++++++++++ 2 files changed, 111 insertions(+), 18 deletions(-) diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 4dffc17e1c..203b7c3aba 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -4,10 +4,13 @@ import threading import uuid from datetime import datetime +from hashlib import sha256 from typing import Any, Dict, List, Literal, Optional import litellm +import pydantic import ujson +from cachetools import LRUCache, cached from dspy.adapters.base import Adapter from dspy.clients.openai import OpenAIProvider @@ -92,7 +95,7 @@ def __call__(self, prompt=None, messages=None, **kwargs): completion = cached_litellm_text_completion if cache else litellm_text_completion response = completion( - request=ujson.dumps(dict(model=self.model, messages=messages, **kwargs)), + request=dict(model=self.model, messages=messages, **kwargs), num_retries=self.num_retries, ) outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]] @@ -153,7 +156,11 @@ def thread_function_wrapper(): thread = threading.Thread(target=thread_function_wrapper) model_to_finetune = self.finetuning_model or self.model job = self.provider.TrainingJob( - thread=thread, model=model_to_finetune, train_data=train_data, train_kwargs=train_kwargs, data_format=data_format + thread=thread, + model=model_to_finetune, + train_data=train_data, + train_kwargs=train_kwargs, + data_format=data_format, ) thread.start() @@ -212,8 +219,48 @@ def copy(self, **kwargs): return new_instance -@functools.lru_cache(maxsize=None) -def cached_litellm_completion(request, num_retries: int): +def request_cache(maxsize: Optional[int] = None): + """ + A threadsafe decorator to create an in-memory LRU cache for LM inference functions that accept + a dictionary-like LM request. An in-memory cache for LM calls is critical for ensuring + good performance when optimizing and evaluating DSPy LMs (disk caching alone is too slow). + + Args: + maxsize: The maximum size of the cache. + + Returns: + A decorator that wraps the target function with caching. + """ + # NB: cachetools doesn't support maxsize=None; it recommends using float("inf") instead + cache = LRUCache(maxsize=maxsize or float("inf")) + cache_lock = threading.Lock() + + def cache_key(request: Dict[str, Any]) -> str: + # Transform Pydantic models into JSON-convertible format and exclude unhashable objects + params = { + k: (v.dict() if isinstance(v, pydantic.BaseModel) else v) for k, v in request.items() if not callable(v) + } + return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest() + + def decorator(func): + @cached( + cache=cache, + key=lambda request, *args, **kwargs: cache_key(request), + # Use a lock to ensure thread safety for the cache when DSPy LMs are queried + # concurrently, e.g. 
during optimization and evaluation + lock=cache_lock, + ) + @functools.wraps(func) + def wrapper(request: dict, *args, **kwargs): + return func(request, *args, **kwargs) + + return wrapper + + return decorator + + +@request_cache(maxsize=None) +def cached_litellm_completion(request: Dict[str, Any], num_retries: int): return litellm_completion( request, cache={"no-cache": False, "no-store": False}, @@ -221,17 +268,16 @@ def cached_litellm_completion(request, num_retries: int): ) -def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}): - kwargs = ujson.loads(request) +def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}): return litellm.completion( num_retries=num_retries, cache=cache, - **kwargs, + **request, ) -@functools.lru_cache(maxsize=None) -def cached_litellm_text_completion(request, num_retries: int): +@request_cache(maxsize=None) +def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int): return litellm_text_completion( request, num_retries=num_retries, @@ -239,20 +285,18 @@ def cached_litellm_text_completion(request, num_retries: int): ) -def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}): - kwargs = ujson.loads(request) - +def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}): # Extract the provider and model from the model string. # TODO: Not all the models are in the format of "provider/model" - model = kwargs.pop("model").split("/", 1) + model = request.pop("model").split("/", 1) provider, model = model[0] if len(model) > 1 else "openai", model[-1] - # Use the API key and base from the kwargs, or from the environment. - api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY") - api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE") + # Use the API key and base from the request, or from the environment. + api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY") + api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE") # Build the prompt from the messages. 
- prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"]) + prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"]) return litellm.text_completion( cache=cache, @@ -261,5 +305,5 @@ def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, api_base=api_base, prompt=prompt, num_retries=num_retries, - **kwargs, + **request, ) diff --git a/tests/caching/test_caching.py b/tests/caching/test_caching.py index d5c3933742..30bc7dd278 100644 --- a/tests/caching/test_caching.py +++ b/tests/caching/test_caching.py @@ -3,6 +3,7 @@ import shutil import tempfile +import pydantic import pytest import dspy @@ -88,3 +89,51 @@ def test_lm_calls_are_cached_across_interpreter_sessions(litellm_test_server, te request_logs = read_litellm_test_server_request_logs(server_log_file_path) assert len(request_logs) == 0 + + +def test_lm_calls_are_cached_in_memory_when_expected(litellm_test_server, temporary_blank_cache_dir): + api_base, server_log_file_path = litellm_test_server + + lm1 = dspy.LM( + model="openai/dspy-test-model", + api_base=api_base, + api_key="fakekey", + ) + lm1("Example query") + # Remove the disk cache, after which the LM must rely on in-memory caching + shutil.rmtree(temporary_blank_cache_dir) + lm1("Example query2") + lm1("Example query2") + lm1("Example query2") + lm1("Example query2") + + request_logs = read_litellm_test_server_request_logs(server_log_file_path) + assert len(request_logs) == 2 + + +def test_lm_calls_support_unhashable_types(litellm_test_server, temporary_blank_cache_dir): + api_base, server_log_file_path = litellm_test_server + + lm_with_unhashable_callable = dspy.LM( + model="openai/dspy-test-model", + api_base=api_base, + api_key="fakekey", + # Define a callable kwarg for the LM to use during inference + azure_ad_token_provider=lambda *args, **kwargs: None, + ) + lm_with_unhashable_callable("Query") + + +def test_lm_calls_support_pydantic_models(litellm_test_server, temporary_blank_cache_dir): + api_base, server_log_file_path = litellm_test_server + + class ResponseFormat(pydantic.BaseModel): + response: str + + lm = dspy.LM( + model="openai/dspy-test-model", + api_base=api_base, + api_key="fakekey", + response_format=ResponseFormat, + ) + lm("Query") From 92f1525456b9722e4f4af3faf333abc178d7c7c7 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Fri, 6 Dec 2024 01:03:09 -0800 Subject: [PATCH 2/7] Cachetools add Signed-off-by: dbczumar --- poetry.lock | 4 ++-- pyproject.toml | 2 ++ requirements.txt | 5 +++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/poetry.lock b/poetry.lock index af67dfb630..2e13c532bb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -579,7 +579,7 @@ virtualenv = ["virtualenv (>=20.0.35)"] name = "cachetools" version = "5.5.0" description = "Extensible memoizing collections and decorators" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "cachetools-5.5.0-py3-none-any.whl", hash = "sha256:02134e8439cdc2ffb62023ce1debca2944c3f289d66bb17ead3ab3dede74b292"}, @@ -8523,4 +8523,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "d7c6fe24c44afc0e694de8270035ac4cd9831f9bf9d6253f1cd1dbfe03ec2835" +content-hash = "f54a12cf2b3bc3d2811e8875561b1262c8b7a546d0fb333f075a1127bc2ffcb0" diff --git a/pyproject.toml b/pyproject.toml index 18cc52f728..c0fe14f192 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ "tenacity>=8.2.3", "anyio", 
"asyncer==0.0.8", + "cachetools", ] [project.optional-dependencies] @@ -138,6 +139,7 @@ falkordb = "^1.0.9" json-repair = "^0.30.0" tenacity = ">=8.2.3" asyncer = "0.0.8" +cachetools = "^5.5.0" [tool.poetry.group.dev.dependencies] pytest = "^8.3.3" diff --git a/requirements.txt b/requirements.txt index 31874b1c9a..218ca9ed7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,7 @@ +anyio +asyncer==0.0.8 backoff +cachetools datasets diskcache httpx @@ -15,5 +18,3 @@ requests tenacity>=8.2.3 tqdm ujson -anyio -asyncer==0.0.8 From 2c311a8f92db7868b3d395462d6b0a9c99f312b2 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Fri, 6 Dec 2024 01:05:31 -0800 Subject: [PATCH 3/7] Inline Signed-off-by: dbczumar --- dspy/clients/lm.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 203b7c3aba..55f29f16b0 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -231,9 +231,6 @@ def request_cache(maxsize: Optional[int] = None): Returns: A decorator that wraps the target function with caching. """ - # NB: cachetools doesn't support maxsize=None; it recommends using float("inf") instead - cache = LRUCache(maxsize=maxsize or float("inf")) - cache_lock = threading.Lock() def cache_key(request: Dict[str, Any]) -> str: # Transform Pydantic models into JSON-convertible format and exclude unhashable objects @@ -244,11 +241,12 @@ def cache_key(request: Dict[str, Any]) -> str: def decorator(func): @cached( - cache=cache, + # NB: cachetools doesn't support maxsize=None; it recommends using float("inf") instead + cache=LRUCache(maxsize=maxsize or float("inf")), key=lambda request, *args, **kwargs: cache_key(request), # Use a lock to ensure thread safety for the cache when DSPy LMs are queried # concurrently, e.g. 
during optimization and evaluation - lock=cache_lock, + lock=threading.Lock(), ) @functools.wraps(func) def wrapper(request: dict, *args, **kwargs): From c543086f44ee03ca3a0fb8457396574066eb00fa Mon Sep 17 00:00:00 2001 From: dbczumar Date: Fri, 6 Dec 2024 01:26:57 -0800 Subject: [PATCH 4/7] tweak Signed-off-by: dbczumar --- dspy/clients/lm.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 55f29f16b0..8b1fadfc7d 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -234,9 +234,8 @@ def request_cache(maxsize: Optional[int] = None): def cache_key(request: Dict[str, Any]) -> str: # Transform Pydantic models into JSON-convertible format and exclude unhashable objects - params = { - k: (v.dict() if isinstance(v, pydantic.BaseModel) else v) for k, v in request.items() if not callable(v) - } + params = {k: (v.dict() if isinstance(v, pydantic.BaseModel) else v) for k, v in request.items()} + params = {k: v for k, v in params.items() if not callable(v)} return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest() def decorator(func): From 954a773f2455f51fea4bc203a31bb43d0a62f0a6 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Fri, 6 Dec 2024 02:09:12 -0800 Subject: [PATCH 5/7] fix Signed-off-by: dbczumar --- tests/caching/test_caching.py | 29 ----------------------------- tests/clients/test_lm.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/tests/caching/test_caching.py b/tests/caching/test_caching.py index 30bc7dd278..468a81352a 100644 --- a/tests/caching/test_caching.py +++ b/tests/caching/test_caching.py @@ -3,7 +3,6 @@ import shutil import tempfile -import pydantic import pytest import dspy @@ -109,31 +108,3 @@ def test_lm_calls_are_cached_in_memory_when_expected(litellm_test_server, tempor request_logs = read_litellm_test_server_request_logs(server_log_file_path) assert len(request_logs) == 2 - - -def test_lm_calls_support_unhashable_types(litellm_test_server, temporary_blank_cache_dir): - api_base, server_log_file_path = litellm_test_server - - lm_with_unhashable_callable = dspy.LM( - model="openai/dspy-test-model", - api_base=api_base, - api_key="fakekey", - # Define a callable kwarg for the LM to use during inference - azure_ad_token_provider=lambda *args, **kwargs: None, - ) - lm_with_unhashable_callable("Query") - - -def test_lm_calls_support_pydantic_models(litellm_test_server, temporary_blank_cache_dir): - api_base, server_log_file_path = litellm_test_server - - class ResponseFormat(pydantic.BaseModel): - response: str - - lm = dspy.LM( - model="openai/dspy-test-model", - api_base=api_base, - api_key="fakekey", - response_format=ResponseFormat, - ) - lm("Query") diff --git a/tests/clients/test_lm.py b/tests/clients/test_lm.py index 18c134601b..77c86512f1 100644 --- a/tests/clients/test_lm.py +++ b/tests/clients/test_lm.py @@ -1,5 +1,6 @@ from unittest import mock +import pydantic import pytest import dspy @@ -22,3 +23,31 @@ def test_lms_can_be_queried(litellm_test_server): api_key="fakekey", ) azure_openai_lm("azure openai query") + + +def test_lm_calls_support_unhashable_types(litellm_test_server): + api_base, server_log_file_path = litellm_test_server + + lm_with_unhashable_callable = dspy.LM( + model="openai/dspy-test-model", + api_base=api_base, + api_key="fakekey", + # Define a callable kwarg for the LM to use during inference + azure_ad_token_provider=lambda *args, **kwargs: None, + ) + lm_with_unhashable_callable("Query") + + +def 
test_lm_calls_support_pydantic_models(litellm_test_server): + api_base, server_log_file_path = litellm_test_server + + class ResponseFormat(pydantic.BaseModel): + response: str + + lm = dspy.LM( + model="openai/dspy-test-model", + api_base=api_base, + api_key="fakekey", + response_format=ResponseFormat, + ) + lm("Query") From 935ce1ed6a6bfb30731b301af56fae20b2831f22 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Fri, 6 Dec 2024 02:14:09 -0800 Subject: [PATCH 6/7] fix Signed-off-by: dbczumar --- tests/clients/test_lm.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/tests/clients/test_lm.py b/tests/clients/test_lm.py index 77c86512f1..519c1e24e7 100644 --- a/tests/clients/test_lm.py +++ b/tests/clients/test_lm.py @@ -7,22 +7,46 @@ from tests.test_utils.server import litellm_test_server -def test_lms_can_be_queried(litellm_test_server): +def test_chat_lms_can_be_queried(litellm_test_server): api_base, _ = litellm_test_server + expected_response = ["Hi!"] openai_lm = dspy.LM( model="openai/dspy-test-model", api_base=api_base, api_key="fakekey", + model_type="chat", ) - openai_lm("openai query") + assert openai_lm("openai query") == expected_response azure_openai_lm = dspy.LM( model="azure/dspy-test-model", api_base=api_base, api_key="fakekey", + model_type="chat", ) - azure_openai_lm("azure openai query") + assert azure_openai_lm("azure openai query") == expected_response + + +def test_text_lms_can_be_queried(litellm_test_server): + api_base, _ = litellm_test_server + expected_response = ["Hi!"] + + openai_lm = dspy.LM( + model="openai/dspy-test-model", + api_base=api_base, + api_key="fakekey", + model_type="text", + ) + assert openai_lm("openai query") == expected_response + + azure_openai_lm = dspy.LM( + model="azure/dspy-test-model", + api_base=api_base, + api_key="fakekey", + model_type="text", + ) + assert azure_openai_lm("azure openai query") == expected_response def test_lm_calls_support_unhashable_types(litellm_test_server): From a98c7fe6dd34dcf4559e10384c0f1af2d160d01a Mon Sep 17 00:00:00 2001 From: Corey Zumar <39497902+dbczumar@users.noreply.github.com> Date: Fri, 6 Dec 2024 13:30:50 -0800 Subject: [PATCH 7/7] Update lm.py --- dspy/clients/lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 8b1fadfc7d..737de772ab 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -226,7 +226,7 @@ def request_cache(maxsize: Optional[int] = None): good performance when optimizing and evaluating DSPy LMs (disk caching alone is too slow). Args: - maxsize: The maximum size of the cache. + maxsize: The maximum size of the cache. If unspecified, no max size is enforced (cache is unbounded). Returns: A decorator that wraps the target function with caching.
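
For reviewers skimming the series, below is a minimal, self-contained sketch of the in-memory request-caching pattern that PATCH 1/7 through 4/7 introduce in dspy/clients/lm.py: a cachetools LRUCache guarded by a threading.Lock, keyed on a SHA-256 digest of the JSON-serialized request, with Pydantic models converted to plain dicts and callable values (e.g. azure_ad_token_provider) dropped before hashing. Hashing a serialized copy of the request, rather than requiring a pre-serialized JSON string as the earlier functools.lru_cache approach did, lets callers pass plain dicts directly and keeps unhashable kwargs from breaking the cache key. The fake_completion function and the example request here are illustrative stand-ins, not part of the patch.

# Minimal, self-contained sketch of the in-memory request-caching pattern from this
# series. `fake_completion` and the example request are illustrative stand-ins only.
import threading
from hashlib import sha256
from typing import Any, Dict, Optional

import pydantic
import ujson
from cachetools import LRUCache, cached


def request_cache(maxsize: Optional[int] = None):
    """Thread-safe in-memory LRU cache for functions that take a dict-like LM request."""

    def cache_key(request: Dict[str, Any]) -> str:
        # Convert Pydantic models to plain dicts and drop callables (which are not
        # JSON-serializable) so the request can be hashed deterministically.
        params = {k: (v.dict() if isinstance(v, pydantic.BaseModel) else v) for k, v in request.items()}
        params = {k: v for k, v in params.items() if not callable(v)}
        return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest()

    def decorator(func):
        return cached(
            # cachetools has no maxsize=None; float("inf") yields an unbounded cache.
            cache=LRUCache(maxsize=maxsize or float("inf")),
            key=lambda request, *args, **kwargs: cache_key(request),
            # The lock keeps the cache consistent when LMs are queried concurrently.
            lock=threading.Lock(),
        )(func)

    return decorator


@request_cache(maxsize=None)
def fake_completion(request: Dict[str, Any]) -> str:
    # Stand-in for a real LM call; identical requests after the first hit the cache.
    return f"response to: {request['messages'][-1]['content']}"


if __name__ == "__main__":
    req = {"model": "openai/dspy-test-model", "messages": [{"role": "user", "content": "Hi"}]}
    print(fake_completion(req))
    print(fake_completion(dict(req)))  # same digest -> served from the in-memory cache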