Skip to content

Commit ed28dfe

Browse files
committed
add decorator for custom cache
1 parent 4095634 commit ed28dfe

File tree

1 file changed

+98
-4
lines changed

1 file changed

+98
-4
lines changed

chebifier/ensemble/_custom_cache.py

Lines changed: 98 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
import pickle
33
import threading
44
from collections import OrderedDict
5-
from typing import Any
5+
from collections.abc import Iterable
6+
from functools import wraps
7+
from typing import Any, Callable
68

79

810
class PerSmilesPerModelLRUCache:
@@ -30,6 +32,7 @@ def get(self, smiles: str, model_name: str) -> Any | None:
3032
return None
3133

3234
def set(self, smiles: str, model_name: str, value: Any) -> None:
35+
assert value is not None, "Value must not be None"
3336
key = (smiles, model_name)
3437
with self._lock:
3538
if key in self._cache:
@@ -50,9 +53,65 @@ def clear(self) -> None:
5053
def stats(self) -> dict:
5154
return {"hits": self.hits, "misses": self.misses}
5255

56+
def batch_decorator(self, func: Callable) -> Callable:
    """Decorator for class methods that accept a batch of SMILES as a tuple,
    and want caching per (smiles, model_name) combination.

    The wrapped method is invoked only for the SMILES that are not already
    cached; cached and freshly computed results are merged back into the
    original input order. The decorated instance must expose a ``model_name``
    attribute, which becomes part of the cache key.
    """

    @wraps(func)
    def wrapper(instance, smiles_list: list[str]):
        assert isinstance(smiles_list, list), "smiles_list must be a list."
        model_name = getattr(instance, "model_name", None)
        assert model_name is not None, "Instance must have a model_name attribute."

        results = []
        missing_smiles = []
        missing_indices = []

        # First pass: serve whatever we can from the cache. A stored value is
        # never None (set() rejects None), so None reliably signals a miss.
        for i, smiles in enumerate(smiles_list):
            cached = self.get(smiles=smiles, model_name=model_name)
            if cached is not None:
                results.append((i, cached))  # keep original index for reordering
            else:
                missing_smiles.append(smiles)
                missing_indices.append(i)

        # Second pass: compute only the misses with the original function.
        if missing_smiles:
            new_results = func(instance, tuple(missing_smiles))
            assert isinstance(
                new_results, Iterable
            ), "Function must return an Iterable."
            new_results = list(new_results)
            # Guard against silent truncation: a bare zip() would otherwise
            # drop trailing inputs if the function returned too few results,
            # yielding an output shorter than smiles_list.
            assert len(new_results) == len(
                missing_smiles
            ), "Function must return one result per input SMILES."
            # Pair each prediction with its original position directly instead
            # of popping missing_indices (pop(0) is O(n) and mutates mid-loop).
            for idx, smiles, prediction in zip(
                missing_indices, missing_smiles, new_results
            ):
                if prediction is not None:
                    self.set(smiles, model_name, prediction)
                # Append even when the prediction is None so the output length
                # always matches the input length.
                results.append((idx, prediction))

        # Restore the original input order.
        results.sort(key=lambda item: item[0])
        return [value for _, value in results]

    return wrapper
98+
99+
def __len__(self):
100+
with self._lock:
101+
return len(self._cache)
102+
103+
def __repr__(self):
104+
return self._cache.__repr__()
105+
106+
def save(self):
    """Public entry point for persisting the cache to disk."""
    self._save_cache()
108+
109+
def load(self):
    """Public entry point for restoring the cache from disk."""
    self._load_cache()
111+
53112
def _save_cache(self) -> None:
54113
"""Serialize the cache to disk."""
55-
if not self._persist_path:
114+
if self._persist_path:
56115
try:
57116
with open(self._persist_path, "wb") as f:
58117
pickle.dump(self._cache, f)
@@ -72,5 +131,40 @@ def _load_cache(self) -> None:
72131

73132

74133
if __name__ == "__main__":
    # Demo of per-(smiles, model) LRU caching. No persist_path is given, so
    # this cache lives only for the duration of the run.
    cache = PerSmilesPerModelLRUCache(max_size=50)

    class ExamplePredictor:
        model_name = "example_model"

        @cache.batch_decorator
        def predict(self, smiles_list: tuple[str]) -> list[dict]:
            # Stand-in for a real model: derive a fake score per SMILES.
            return [{"prediction": hash(smiles) % 100} for smiles in smiles_list]

    predictor = ExamplePredictor()

    # Round 1 — fresh model name, so every SMILES is a cache miss.
    predictor.model_name = "example_model"
    predictor.predict(["CCC", "C", "CCO", "CCN"])  # MISS × 4
    print("Cache Stats:", cache.stats())

    # Round 2 — same model, mixed hits and misses.
    predictor.model_name = "example_model"
    predictor.predict(["CCC", "CO", "CCO", "CN"])  # HIT: CCC, CCO — MISS: CO, CN
    print("Cache Stats:", cache.stats())

    # Round 3 — different model, same SMILES: all misses (caching is per model).
    predictor.model_name = "example_model_2"
    predictor.predict(["CCC", "C", "CO", "CN"])  # MISS × 4 (new model)
    print("Cache Stats:", cache.stats())

    # Round 4 — yet another model name.
    predictor.model_name = "example_model_3"
    predictor.predict(["CCCC", "CCCl", "CCBr", "C(C)C"])  # MISS × 4
    print("Cache Stats:", cache.stats())

    from pprint import pprint

    pprint(cache)

0 commit comments

Comments (0)