Skip to content

Commit cdd7de9

Browse files
committed
tests for custom cache
1 parent ad14325 commit cdd7de9

File tree

2 files changed

+132
-42
lines changed

2 files changed

+132
-42
lines changed

chebifier/_custom_cache.py

Lines changed: 4 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,12 @@ def wrapper(instance, smiles_list: list[str]):
8484
new_results, Iterable
8585
), "Function must return an Iterable."
8686
# Save to cache and append
87-
for smiles, prediction in zip(missing_smiles, new_results):
87+
for smiles, prediction, missing_idx in zip(
88+
missing_smiles, new_results, missing_indices
89+
):
8890
if prediction is not None:
8991
self.set(smiles, model_name, prediction)
90-
results.append((missing_indices.pop(0), prediction))
92+
results.append((missing_idx, prediction))
9193

9294
# Reorder results to match original indices
9395
results.sort(key=lambda x: x[0]) # sort by index
@@ -131,43 +133,3 @@ def _load_cache(self) -> None:
131133
self._cache = loaded
132134
except Exception as e:
133135
print(f"[Cache Load Error] {e}")
134-
135-
136-
if __name__ == "__main__":
137-
# cache will persist across runs in "cache.pkl"
138-
cache = PerSmilesPerModelLRUCache(max_size=50)
139-
140-
class ExamplePredictor:
141-
model_name = "example_model"
142-
143-
@cache.batch_decorator
144-
def predict(self, smiles_list: tuple[str]) -> list[dict]:
145-
# Simulate a prediction function
146-
return [{"prediction": hash(smiles) % 100} for smiles in smiles_list]
147-
148-
# Create an instance of the predictor
149-
predictor = ExamplePredictor()
150-
151-
# Prediction set 1 — new model, all should be cache misses
152-
predictor.model_name = "example_model"
153-
predictor.predict(["CCC", "C", "CCO", "CCN"]) # MISS × 4
154-
print("Cache Stats:", cache.stats())
155-
156-
# Prediction set 2 — same model, partial hit/miss
157-
predictor.model_name = "example_model"
158-
predictor.predict(["CCC", "CO", "CCO", "CN"]) # HIT: CCC, CCO — MISS: CO, CN
159-
print("Cache Stats:", cache.stats())
160-
161-
# Prediction set 3 — new model, same SMILES — should all be misses (per-model caching)
162-
predictor.model_name = "example_model_2"
163-
predictor.predict(["CCC", "C", "CO", "CN"]) # MISS × 4 (new model)
164-
print("Cache Stats:", cache.stats())
165-
166-
# Prediction set 4 — another model
167-
predictor.model_name = "example_model_3"
168-
predictor.predict(["CCCC", "CCCl", "CCBr", "C(C)C"]) # MISS × 4
169-
print("Cache Stats:", cache.stats())
170-
171-
from pprint import pprint
172-
173-
pprint(cache)

tests/test_cache.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import os
2+
import tempfile
3+
import unittest
4+
5+
from chebifier import PerSmilesPerModelLRUCache
6+
7+
g_cache = PerSmilesPerModelLRUCache(max_size=3)
8+
9+
10+
class DummyPredictor:
11+
def __init__(self, model_name):
12+
self.model_name = model_name
13+
14+
@g_cache.batch_decorator
15+
def predict(self, smiles_list: tuple[str]):
16+
# Simple predictable dummy function for tests
17+
return [f"{self.model_name}{i}" for i in range(len(smiles_list))]
18+
19+
20+
class TestPerSmilesPerModelLRUCache(unittest.TestCase):
21+
def setUp(self):
22+
# Create temp file for persistence tests
23+
self.temp_file = tempfile.NamedTemporaryFile(delete=False)
24+
self.temp_file.close()
25+
self.cache = PerSmilesPerModelLRUCache(
26+
max_size=3, persist_path=self.temp_file.name
27+
)
28+
29+
def tearDown(self):
30+
if os.path.exists(self.temp_file.name):
31+
os.remove(self.temp_file.name)
32+
33+
def test_cache_miss_and_set_get(self):
34+
# Initially empty
35+
self.assertEqual(len(self.cache), 0)
36+
self.assertIsNone(self.cache.get("CCC", "model1"))
37+
38+
# Set and get
39+
self.cache.set("CCC", "model1", "result1")
40+
self.assertEqual(self.cache.get("CCC", "model1"), "result1")
41+
self.assertEqual(self.cache.hits, 1)
42+
self.assertEqual(self.cache.misses, 1) # One miss from first get
43+
44+
def test_cache_eviction(self):
45+
self.cache.set("a", "m", "v1")
46+
self.cache.set("b", "m", "v2")
47+
self.cache.set("c", "m", "v3")
48+
self.assertEqual(len(self.cache), 3)
49+
# Adding one more triggers eviction of oldest
50+
self.cache.set("d", "m", "v4")
51+
self.assertEqual(len(self.cache), 3)
52+
self.assertIsNone(self.cache.get("a", "m")) # 'a' evicted
53+
self.assertIsNotNone(self.cache.get("d", "m")) # 'd' present
54+
55+
def test_batch_decorator_hits_and_misses(self):
56+
predictor = DummyPredictor("modelA")
57+
predictor2 = DummyPredictor("modelB")
58+
59+
# Clear cache before starting the test
60+
g_cache.clear()
61+
62+
smiles = ["AAA", "BBB", "CCC", "DDD", "EEE"]
63+
# First call all misses
64+
results1 = predictor.predict(smiles)
65+
results1_model2 = predictor2.predict(smiles)
66+
67+
# all predictions are retrieved from the actual prediction function and not from the cache
68+
self.assertListEqual(
69+
results1, ["modelA_P0", "modelA_P1", "modelA_P2", "modelA_P3", "modelA_P4"]
70+
)
71+
self.assertListEqual(
72+
results1_model2,
73+
["modelB_P0", "modelB_P1", "modelB_P2", "modelB_P3", "modelB_P4"],
74+
)
75+
stats_after_first = g_cache.stats()
76+
self.assertEqual(stats_after_first["misses"], 3)
77+
78+
# cache = {("AAA", "modelA"): "modelA_P0", ("BBB", "modelA"): "modelA_P1", ("CCC", "modelA"): "modelA_P2"}
79+
# Second call with some hits and some misses
80+
results2 = predictor.predict(["FFF", "DDD"])
81+
# DDD from cache (it was populated by the first batch)
82+
# FFF is not in the cache, so it is predicted; it gets P0 since it is the only SMILES passed to the prediction function,
83+
# and the dummy predictor iterates over the smiles list and returns P{idx} corresponding to each index
84+
self.assertListEqual(results2, ["P3", "P0"])
85+
stats_after_second = g_cache.stats()
86+
self.assertEqual(stats_after_second["hits"], 1)
87+
self.assertEqual(stats_after_second["misses"], 4)
88+
89+
# cache = {("AAA", "modelA"): "P0", ("BBB", "modelA"): "P1", ("CCC", "modelA"): "P2",
90+
# ("DDD", "modelA"): "P3", ("EEE", "modelA"): "P4", ("FFF", "modelA"): "P0"}
91+
92+
# Third call with some hits and some misses
93+
results3 = predictor.predict(["EEE", "GGG", "DDD", "HHH", "BBB", "ZZZ"])
94+
# Here, predictions for [EEE, DDD, BBB] are retrieved from the cache,
95+
# while [GGG, HHH, ZZZ] are not in the cache and hence are passed to the prediction function
96+
self.assertListEqual(results3, ["P4", "P0", "P3", "P0", "P1", "P0"])
97+
stats_after_third = g_cache.stats()
98+
self.assertEqual(stats_after_third["hits"], 1)
99+
self.assertEqual(stats_after_third["misses"], 4)
100+
101+
def test_persistence_save_and_load(self):
102+
# Set some values
103+
self.cache.set("sm1", "modelX", "val1")
104+
self.cache.set("sm2", "modelX", "val2")
105+
106+
# Save cache to file
107+
self.cache.save()
108+
109+
# Create new cache instance loading from file
110+
new_cache = PerSmilesPerModelLRUCache(
111+
max_size=3, persist_path=self.temp_file.name
112+
)
113+
new_cache.load()
114+
115+
self.assertEqual(new_cache.get("sm1", "modelX"), "val1")
116+
self.assertEqual(new_cache.get("sm2", "modelX"), "val2")
117+
118+
def test_clear_cache(self):
119+
self.cache.set("x", "m", "v")
120+
self.cache.save()
121+
self.assertTrue(os.path.exists(self.temp_file.name))
122+
self.cache.clear()
123+
self.assertEqual(len(self.cache), 0)
124+
self.assertFalse(os.path.exists(self.temp_file.name))
125+
126+
127+
if __name__ == "__main__":
128+
unittest.main()

0 commit comments

Comments
 (0)