 import threading
 import uuid
 from datetime import datetime
+from hashlib import sha256
 from typing import Any, Dict, List, Literal, Optional
 
 import litellm
+import pydantic
 import ujson
+from cachetools import LRUCache, cached
 
 from dspy.adapters.base import Adapter
 from dspy.clients.openai import OpenAIProvider
@@ -92,7 +95,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         completion = cached_litellm_text_completion if cache else litellm_text_completion
 
         response = completion(
-            request=ujson.dumps(dict(model=self.model, messages=messages, **kwargs)),
+            request=dict(model=self.model, messages=messages, **kwargs),
             num_retries=self.num_retries,
         )
         outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
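
(Not part of the upstream change.) A minimal sketch of the calling convention this hunk introduces: the request is now passed as a plain dict, and JSON serialization happens only inside the caching layer when the cache key is computed. The model name and message below are placeholders.

```python
# Hypothetical request; model name, message, and temperature are illustrative only.
request = dict(
    model="openai/gpt-4o-mini",
    messages=[{"role": "user", "content": "Hello"}],
    temperature=0.0,
)

# The cached path keys an in-memory LRU cache on a sha256 of the sorted-JSON request;
# the uncached path forwards the dict to litellm directly. Both would issue real
# LiteLLM calls, so they are left commented out here.
# response = cached_litellm_completion(request, num_retries=3)
# response = litellm_completion(request, num_retries=3)
```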
@@ -153,7 +156,11 @@ def thread_function_wrapper():
         thread = threading.Thread(target=thread_function_wrapper)
         model_to_finetune = self.finetuning_model or self.model
         job = self.provider.TrainingJob(
-            thread=thread, model=model_to_finetune, train_data=train_data, train_kwargs=train_kwargs, data_format=data_format
+            thread=thread,
+            model=model_to_finetune,
+            train_data=train_data,
+            train_kwargs=train_kwargs,
+            data_format=data_format,
         )
         thread.start()
 
@@ -212,47 +219,81 @@ def copy(self, **kwargs):
         return new_instance
 
 
-@functools.lru_cache(maxsize=None)
-def cached_litellm_completion(request, num_retries: int):
+def request_cache(maxsize: Optional[int] = None):
+    """
+    A threadsafe decorator to create an in-memory LRU cache for LM inference functions that accept
+    a dictionary-like LM request. An in-memory cache for LM calls is critical for ensuring
+    good performance when optimizing and evaluating DSPy LMs (disk caching alone is too slow).
+
+    Args:
+        maxsize: The maximum size of the cache. If unspecified, no max size is enforced (cache is unbounded).
+
+    Returns:
+        A decorator that wraps the target function with caching.
+    """
+
+    def cache_key(request: Dict[str, Any]) -> str:
+        # Transform Pydantic models into JSON-convertible format and exclude unhashable objects
+        params = {k: (v.dict() if isinstance(v, pydantic.BaseModel) else v) for k, v in request.items()}
+        params = {k: v for k, v in params.items() if not callable(v)}
+        return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest()
+
+    def decorator(func):
+        @cached(
+            # NB: cachetools doesn't support maxsize=None; it recommends using float("inf") instead
+            cache=LRUCache(maxsize=maxsize or float("inf")),
+            key=lambda request, *args, **kwargs: cache_key(request),
+            # Use a lock to ensure thread safety for the cache when DSPy LMs are queried
+            # concurrently, e.g. during optimization and evaluation
+            lock=threading.Lock(),
+        )
+        @functools.wraps(func)
+        def wrapper(request: dict, *args, **kwargs):
+            return func(request, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+@request_cache(maxsize=None)
+def cached_litellm_completion(request: Dict[str, Any], num_retries: int):
     return litellm_completion(
         request,
         cache={"no-cache": False, "no-store": False},
         num_retries=num_retries,
     )
 
 
-def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
-    kwargs = ujson.loads(request)
+def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
     return litellm.completion(
         num_retries=num_retries,
         cache=cache,
-        **kwargs,
+        **request,
     )
 
 
-@functools.lru_cache(maxsize=None)
-def cached_litellm_text_completion(request, num_retries: int):
+@request_cache(maxsize=None)
+def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int):
     return litellm_text_completion(
         request,
         num_retries=num_retries,
         cache={"no-cache": False, "no-store": False},
     )
 
 
-def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
-    kwargs = ujson.loads(request)
-
+def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
     # Extract the provider and model from the model string.
     # TODO: Not all the models are in the format of "provider/model"
-    model = kwargs.pop("model").split("/", 1)
+    model = request.pop("model").split("/", 1)
     provider, model = model[0] if len(model) > 1 else "openai", model[-1]
 
-    # Use the API key and base from the kwargs, or from the environment.
-    api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
-    api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
+    # Use the API key and base from the request, or from the environment.
+    api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
+    api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
 
     # Build the prompt from the messages.
-    prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])
+    prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"])
 
     return litellm.text_completion(
         cache=cache,
@@ -261,5 +302,5 @@ def litellm_text_completion(request, num_retries: int, cache={"no-cache": True,
         api_base=api_base,
         prompt=prompt,
         num_retries=num_retries,
-        **kwargs,
+        **request,
     )
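
(Also not part of the upstream change.) A hedged, standalone sketch of how the new `request_cache` decorator behaves, assuming the definition added above is in scope; `fake_completion`, the call counter, and the request contents are hypothetical stand-ins, not DSPy API. Because the key is a sha256 over a sorted-key ujson dump, requests with identical contents hit the same cache entry regardless of dict key order.

```python
from typing import Any, Dict

call_count = 0

@request_cache(maxsize=1000)
def fake_completion(request: Dict[str, Any], num_retries: int):
    # Stand-in for litellm_completion that just counts real invocations.
    global call_count
    call_count += 1
    return {"choices": [{"text": request["messages"][-1]["content"]}]}

req_a = {"model": "openai/gpt-4o-mini", "messages": [{"role": "user", "content": "Hi"}]}
req_b = {"messages": [{"role": "user", "content": "Hi"}], "model": "openai/gpt-4o-mini"}

fake_completion(req_a, num_retries=3)
fake_completion(req_b, num_retries=3)  # same key: sort_keys=True makes dict ordering irrelevant
assert call_count == 1  # the second call is served from the in-memory LRU cache
```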