22import pickle
33import threading
44from collections import OrderedDict
5- from typing import Any
5+ from collections .abc import Iterable
6+ from functools import wraps
7+ from typing import Any , Callable
68
79
810class PerSmilesPerModelLRUCache :
@@ -30,6 +32,7 @@ def get(self, smiles: str, model_name: str) -> Any | None:
3032 return None
3133
3234 def set (self , smiles : str , model_name : str , value : Any ) -> None :
35+ assert value is not None , "Value must not be None"
3336 key = (smiles , model_name )
3437 with self ._lock :
3538 if key in self ._cache :
@@ -50,9 +53,65 @@ def clear(self) -> None:
def stats(self) -> dict:
    """Report the cache counters.

    Returns:
        A dictionary with two entries: ``"hits"`` holding ``self.hits`` and
        ``"misses"`` holding ``self.misses``.
    """
    counters = dict(hits=self.hits, misses=self.misses)
    return counters
5255
def batch_decorator(self, func: Callable) -> Callable:
    """Wrap a batch prediction method with per-(smiles, model_name) caching.

    The decorated callable is invoked as ``wrapper(instance, smiles_list)``
    where ``smiles_list`` is a ``list`` of SMILES strings and ``instance``
    exposes a non-None ``model_name`` attribute. Entries already present in
    the cache are served from it; the remaining SMILES are forwarded to
    ``func`` as a single ``tuple`` (in their original relative order) and
    every non-None prediction is stored in the cache.

    Args:
        func: The method to wrap. It is called as
            ``func(instance, missing_smiles_tuple)`` and must return an
            iterable yielding one prediction per SMILES, in the same order.

    Returns:
        The wrapping function. It returns a list with one entry per input
        SMILES, in input order. A prediction of ``None`` is returned as
        ``None`` and is not cached.

    Raises:
        AssertionError: If ``smiles_list`` is not a list, the instance has no
            ``model_name``, or ``func`` does not return an iterable.
        ValueError: If ``func`` yields fewer predictions than the number of
            SMILES it was given.
    """

    @wraps(func)
    def wrapper(instance, smiles_list: list[str]):
        # NOTE: the assert-based checks are kept so that the exception type
        # seen by existing callers (AssertionError) does not change.
        assert isinstance(smiles_list, list), "smiles_list must be a list."
        model_name = getattr(instance, "model_name", None)
        assert model_name is not None, "Instance must have a model_name attribute."

        # One slot per input; filled by index so no re-sorting is needed.
        ordered = [None] * len(smiles_list)
        missing_smiles = []
        missing_indices = []

        # First pass: serve whatever is already cached.
        for index, smiles in enumerate(smiles_list):
            cached = self.get(smiles=smiles, model_name=model_name)
            if cached is not None:
                ordered[index] = cached
            else:
                missing_smiles.append(smiles)
                missing_indices.append(index)

        # Second pass: compute only the cache misses with the wrapped function.
        if missing_smiles:
            new_results = func(instance, tuple(missing_smiles))
            assert isinstance(
                new_results, Iterable
            ), "Function must return an Iterable."

            filled = 0
            for index, smiles, prediction in zip(
                missing_indices, missing_smiles, new_results
            ):
                if prediction is not None:
                    self.set(smiles, model_name, prediction)
                ordered[index] = prediction
                filled += 1

            # zip() stops at the shortest input; a short result would
            # otherwise leave trailing SMILES without a prediction.
            if filled != len(missing_smiles):
                raise ValueError(
                    f"Function returned {filled} results for "
                    f"{len(missing_smiles)} SMILES."
                )

        return ordered

    return wrapper
98+
def __len__(self):
    """Return the number of entries in ``self._cache``.

    The length is read while holding ``self._lock``.
    """
    with self._lock:
        entry_count = len(self._cache)
    return entry_count
102+
def __repr__(self):
    """Return the string representation of the underlying ``self._cache``."""
    return repr(self._cache)
105+
def save(self) -> None:
    """Public entry point for saving; delegates to ``self._save_cache()``."""
    self._save_cache()
108+
def load(self) -> None:
    """Public entry point for loading; delegates to ``self._load_cache()``."""
    self._load_cache()
111+
53112 def _save_cache (self ) -> None :
54113 """Serialize the cache to disk."""
55- if not self ._persist_path :
114+ if self ._persist_path :
56115 try :
57116 with open (self ._persist_path , "wb" ) as f :
58117 pickle .dump (self ._cache , f )
@@ -72,5 +131,40 @@ def _load_cache(self) -> None:
72131
73132
if __name__ == "__main__":
    from pprint import pprint

    # Demo cache. No persist_path is passed to the constructor here.
    cache = PerSmilesPerModelLRUCache(max_size=50)

    class ExamplePredictor:
        """Toy predictor whose batch method is wrapped by the cache decorator."""

        model_name = "example_model"

        @cache.batch_decorator
        def predict(self, smiles_list: tuple[str]) -> list[dict]:
            # Stand-in for a real model: derive a number from each SMILES.
            return [{"prediction": hash(smiles) % 100} for smiles in smiles_list]

    predictor = ExamplePredictor()

    # Each scenario is (model name to use, batch of SMILES to predict).
    scenarios = [
        # Set 1 - first use of the model: MISS x 4.
        ("example_model", ["CCC", "C", "CCO", "CCN"]),
        # Set 2 - same model, partial overlap: HIT CCC, CCO - MISS CO, CN.
        ("example_model", ["CCC", "CO", "CCO", "CN"]),
        # Set 3 - different model, overlapping SMILES: MISS x 4, because
        # entries are keyed per (smiles, model_name).
        ("example_model_2", ["CCC", "C", "CO", "CN"]),
        # Set 4 - yet another model: MISS x 4.
        ("example_model_3", ["CCCC", "CCCl", "CCBr", "C(C)C"]),
    ]

    for scenario_model, batch in scenarios:
        predictor.model_name = scenario_model
        predictor.predict(batch)
        print("Cache Stats:", cache.stats())

    pprint(cache)
0 commit comments