Skip to content

Commit ad14325

Browse files
committed
decorate each predict method with cache
1 parent ed28dfe commit ad14325

File tree

7 files changed

+34
-30
lines changed

7 files changed

+34
-30
lines changed

chebifier/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Note: The top-level package __init__.py runs only once,
2+
# even if multiple subpackages are imported later.
3+
4+
from ._custom_cache import PerSmilesPerModelLRUCache
5+
6+
modelwise_smiles_lru_cache = PerSmilesPerModelLRUCache(max_size=100)

chebifier/ensemble/_custom_cache.py renamed to chebifier/_custom_cache.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ def wrapper(instance, smiles_list: list[str]):
9292
# Reorder results to match original indices
9393
results.sort(key=lambda x: x[0]) # sort by index
9494
ordered = [result for _, result in results]
95+
assert len(ordered) == len(
96+
smiles_list
97+
), "Result length does not match input length."
9598
return ordered
9699

97100
return wrapper

chebifier/prediction_models/base_predictor.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
from abc import ABC
33

4-
from functools import lru_cache
4+
from chebifier import modelwise_smiles_lru_cache
55

66

77
class BasePredictor(ABC):
@@ -23,17 +23,13 @@ def __init__(
2323

2424
self._description = kwargs.get("description", None)
2525

26+
@modelwise_smiles_lru_cache.batch_decorator
2627
def predict_smiles_list(self, smiles_list: list[str]) -> dict:
27-
# list is not hashable, so we convert it to a tuple (useful for caching)
28-
return self.predict_smiles_tuple(tuple(smiles_list))
29-
30-
@lru_cache(maxsize=100)
31-
def predict_smiles_tuple(self, smiles_tuple: tuple[str]) -> dict:
3228
raise NotImplementedError()
3329

3430
def predict_smiles(self, smiles: str) -> dict:
3531
# by default, use list-based prediction
36-
return self.predict_smiles_tuple((smiles,))[0]
32+
return self.predict_smiles_list([smiles])[0]
3733

3834
@property
3935
def info_text(self):

chebifier/prediction_models/c3p_predictor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from functools import lru_cache
2-
from typing import Optional, List
31
from pathlib import Path
2+
from typing import List, Optional
43

54
from c3p import classifier as c3p_classifier
65

6+
from chebifier import modelwise_smiles_lru_cache
77
from chebifier.prediction_models import BasePredictor
88

99

@@ -24,8 +24,8 @@ def __init__(
2424
self.chemical_classes = chemical_classes
2525
self.chebi_graph = kwargs.get("chebi_graph", None)
2626

27-
@lru_cache(maxsize=100)
28-
def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
27+
@modelwise_smiles_lru_cache.batch_decorator
28+
def predict_smiles_list(self, smiles_list: list[str]) -> list:
2929
result_list = c3p_classifier.classify(
3030
list(smiles_list),
3131
self.program_directory,

chebifier/prediction_models/chebi_lookup.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1-
from functools import lru_cache
1+
import json
2+
import os
23
from typing import Optional
34

4-
from chebifier.prediction_models import BasePredictor
5-
import os
65
import networkx as nx
76
from rdkit import Chem
8-
import json
7+
8+
from chebifier import modelwise_smiles_lru_cache
9+
from chebifier.prediction_models import BasePredictor
910
from chebifier.utils import load_chebi_graph
1011

1112

1213
class ChEBILookupPredictor(BasePredictor):
13-
1414
def __init__(
1515
self,
1616
model_name: str,
@@ -67,7 +67,6 @@ def build_smiles_lookup(self):
6767
)
6868
return smiles_lookup
6969

70-
@lru_cache(maxsize=100)
7170
def predict_smiles(self, smiles: str) -> Optional[dict]:
7271
if not smiles:
7372
return None
@@ -94,7 +93,8 @@ def predict_smiles(self, smiles: str) -> Optional[dict]:
9493
else:
9594
return None
9695

97-
def predict_smiles_tuple(self, smiles_list: list[str]) -> list:
96+
@modelwise_smiles_lru_cache.batch_decorator
97+
def predict_smiles_list(self, smiles_list: list[str]) -> list:
9898
predictions = []
9999
for smiles in smiles_list:
100100
predictions.append(self.predict_smiles(smiles))
@@ -145,7 +145,8 @@ def explain_smiles(self, smiles: str) -> dict:
145145
# Example usage
146146
smiles_list = [
147147
"CCO",
148-
"C1=CC=CC=C1" "*C(=O)OC[C@H](COP(=O)([O-])OCC[N+](C)(C)C)OC(*)=O",
148+
"C1=CC=CC=C1",
149+
"*C(=O)OC[C@H](COP(=O)([O-])OCC[N+](C)(C)C)OC(*)=O",
149150
] # SMILES with 251 matches in ChEBI
150151
predictions = predictor.predict_smiles_list(smiles_list)
151152
print(predictions)

chebifier/prediction_models/chemlog_predictor.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
)
1313
from chemlog.cli import CLASSIFIERS, _smiles_to_mol, strategy_call
1414
from chemlog_extra.alg_classification.by_element_classification import (
15-
XMolecularEntityClassifier,
1615
OrganoXCompoundClassifier,
16+
XMolecularEntityClassifier,
1717
)
18-
from functools import lru_cache
18+
19+
from chebifier import modelwise_smiles_lru_cache
1920

2021
from .base_predictor import BasePredictor
2122

@@ -47,7 +48,6 @@
4748

4849

4950
class ChemlogExtraPredictor(BasePredictor):
50-
5151
CHEMLOG_CLASSIFIER = None
5252

5353
def __init__(self, model_name: str, **kwargs):
@@ -72,12 +72,10 @@ def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
7272

7373

7474
class ChemlogXMolecularEntityPredictor(ChemlogExtraPredictor):
75-
7675
CHEMLOG_CLASSIFIER = XMolecularEntityClassifier
7776

7877

7978
class ChemlogOrganoXCompoundPredictor(ChemlogExtraPredictor):
80-
8179
CHEMLOG_CLASSIFIER = OrganoXCompoundClassifier
8280

8381

@@ -97,7 +95,6 @@ def __init__(self, model_name: str, **kwargs):
9795
# fmt: on
9896
print(f"Initialised ChemLog model {self.model_name}")
9997

100-
@lru_cache(maxsize=100)
10198
def predict_smiles(self, smiles: str) -> Optional[dict]:
10299
mol = _smiles_to_mol(smiles)
103100
if mol is None:
@@ -122,7 +119,8 @@ def predict_smiles(self, smiles: str) -> Optional[dict]:
122119
for label in self.peptide_labels + pos_labels
123120
}
124121

125-
def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
122+
@modelwise_smiles_lru_cache.batch_decorator
123+
def predict_smiles_list(self, smiles_list: list[str]) -> list:
126124
results = []
127125
for i, smiles in tqdm.tqdm(enumerate(smiles_list)):
128126
results.append(self.predict_smiles(smiles))

chebifier/prediction_models/nn_predictor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from functools import lru_cache
2-
31
import numpy as np
42
import torch
53
import tqdm
64
from rdkit import Chem
75

6+
from chebifier import modelwise_smiles_lru_cache
7+
88
from .base_predictor import BasePredictor
99

1010

@@ -52,8 +52,8 @@ def read_smiles(self, smiles):
5252
d = reader.to_data(dict(features=smiles, labels=None))
5353
return d
5454

55-
@lru_cache(maxsize=100)
56-
def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
55+
@modelwise_smiles_lru_cache.batch_decorator
56+
def predict_smiles_list(self, smiles_list: list[str]) -> list:
5757
"""Returns a list with the length of smiles_list, each element is either None (=failure) or a dictionary
5858
of classes and predicted values."""
5959
token_dicts = []

0 commit comments

Comments
 (0)