Skip to content

Commit 3f1aaf1

Browse files
committed
dynamic imports for nn models
1 parent 955ab83 commit 3f1aaf1

File tree

3 files changed

+31
-15
lines changed

3 files changed

+31
-15
lines changed

chebifier/prediction_models/electra_predictor.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
import numpy as np
2-
from chebai.models.electra import Electra
3-
from chebai.preprocessing.reader import EMBEDDING_OFFSET, ChemDataReader
42

53
from .nn_predictor import NNPredictor
64

@@ -37,10 +35,14 @@ def build_graph_from_attention(att, node_labels, token_labels, threshold=0.0):
3735

3836
class ElectraPredictor(NNPredictor):
3937
def __init__(self, model_name: str, ckpt_path: str, **kwargs):
38+
from chebai.preprocessing.reader import ChemDataReader
39+
4040
super().__init__(model_name, ckpt_path, reader_cls=ChemDataReader, **kwargs)
4141
print(f"Initialised Electra model {self.model_name} (device: {self.device})")
4242

43-
def init_model(self, ckpt_path: str, **kwargs) -> Electra:
43+
def init_model(self, ckpt_path: str, **kwargs) -> "Electra": # noqa: F821
44+
from chebai.models.electra import Electra
45+
4446
model = Electra.load_from_checkpoint(
4547
ckpt_path,
4648
map_location=self.device,
@@ -53,6 +55,8 @@ def init_model(self, ckpt_path: str, **kwargs) -> Electra:
5355
return model
5456

5557
def explain_smiles(self, smiles) -> dict:
58+
from chebai.preprocessing.reader import EMBEDDING_OFFSET
59+
5660
reader = self.reader_cls()
5761
token_dict = reader.to_data(dict(features=smiles, labels=None))
5862
tokens = np.array(token_dict["features"]).astype(int).tolist()

chebifier/prediction_models/gnn_predictor.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,11 @@
1-
import chebai_graph.preprocessing.properties as p
2-
import torch
3-
from chebai_graph.models.graph import ResGatedGraphConvNetGraphPred
4-
from chebai_graph.preprocessing.property_encoder import IndexEncoder, OneHotEncoder
5-
from chebai_graph.preprocessing.reader import GraphPropertyReader
6-
from torch_geometric.data.data import Data as GeomData
7-
81
from .nn_predictor import NNPredictor
92

103

114
class ResGatedPredictor(NNPredictor):
125
def __init__(self, model_name: str, ckpt_path: str, molecular_properties, **kwargs):
6+
from chebai_graph.preprocessing.properties import MolecularProperty
7+
from chebai_graph.preprocessing.reader import GraphPropertyReader
8+
139
super().__init__(
1410
model_name, ckpt_path, reader_cls=GraphPropertyReader, **kwargs
1511
)
@@ -23,7 +19,7 @@ def __init__(self, model_name: str, ckpt_path: str, molecular_properties, **kwar
2319
properties = []
2420
self.molecular_properties = properties
2521
assert isinstance(self.molecular_properties, list) and all(
26-
isinstance(prop, p.MolecularProperty) for prop in self.molecular_properties
22+
isinstance(prop, MolecularProperty) for prop in self.molecular_properties
2723
)
2824
print(f"Initialised GNN model {self.model_name} (device: {self.device})")
2925

@@ -32,7 +28,12 @@ def load_class(self, class_path: str):
3228
module = __import__(module_path, fromlist=[class_name])
3329
return getattr(module, class_name)
3430

35-
def init_model(self, ckpt_path: str, **kwargs) -> ResGatedGraphConvNetGraphPred:
31+
def init_model(
32+
self, ckpt_path: str, **kwargs
33+
) -> "ResGatedGraphConvNetGraphPred": # noqa: F821
34+
import torch
35+
from chebai_graph.models.graph import ResGatedGraphConvNetGraphPred
36+
3637
model = ResGatedGraphConvNetGraphPred.load_from_checkpoint(
3738
ckpt_path,
3839
map_location=torch.device(self.device),
@@ -45,6 +46,14 @@ def init_model(self, ckpt_path: str, **kwargs) -> ResGatedGraphConvNetGraphPred:
4546
return model
4647

4748
def read_smiles(self, smiles):
49+
import torch
50+
from chebai_graph.preprocessing.properties import AtomProperty, BondProperty
51+
from chebai_graph.preprocessing.property_encoder import (
52+
IndexEncoder,
53+
OneHotEncoder,
54+
)
55+
from torch_geometric.data.data import Data as GeomData
56+
4857
reader = self.reader_cls()
4958
d = reader.to_data(dict(features=smiles, labels=None))
5059
geom_data = d["features"]
@@ -87,9 +96,9 @@ def read_smiles(self, smiles):
8796
encoded_values = encoded_values.unsqueeze(1)
8897
else:
8998
encoded_values = torch.zeros((0, prop.encoder.get_encoding_length()))
90-
if isinstance(prop, p.AtomProperty):
99+
if isinstance(prop, AtomProperty):
91100
x = torch.cat([x, encoded_values], dim=1)
92-
elif isinstance(prop, p.BondProperty):
101+
elif isinstance(prop, BondProperty):
93102
edge_attr = torch.cat([edge_attr, encoded_values], dim=1)
94103
else:
95104
molecule_attr = torch.cat([molecule_attr, encoded_values[0]], dim=1)

chebifier/prediction_models/nn_predictor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from functools import lru_cache
22

33
import numpy as np
4-
import torch
54
import tqdm
65
from rdkit import Chem
76

@@ -17,6 +16,8 @@ def __init__(
1716
target_labels_path: str,
1817
**kwargs,
1918
):
19+
import torch
20+
2021
super().__init__(model_name, **kwargs)
2122
self.reader_cls = reader_cls
2223

@@ -56,6 +57,8 @@ def read_smiles(self, smiles):
5657
def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
5758
"""Returns a list with the length of smiles_list, each element is either None (=failure) or a dictionary
5859
Of classes and predicted values."""
60+
import torch
61+
5962
token_dicts = []
6063
could_not_parse = []
6164
index_map = dict()

0 commit comments

Comments
 (0)