1- import chebai_graph .preprocessing .properties as p
2- import torch
3- from chebai_graph .models .graph import ResGatedGraphConvNetGraphPred
4- from chebai_graph .preprocessing .property_encoder import IndexEncoder , OneHotEncoder
5- from chebai_graph .preprocessing .reader import GraphPropertyReader
6- from torch_geometric .data .data import Data as GeomData
7-
81from .nn_predictor import NNPredictor
92
103
114class ResGatedPredictor (NNPredictor ):
125 def __init__ (self , model_name : str , ckpt_path : str , molecular_properties , ** kwargs ):
6+ from chebai_graph .preprocessing .properties import MolecularProperty
7+ from chebai_graph .preprocessing .reader import GraphPropertyReader
8+
139 super ().__init__ (
1410 model_name , ckpt_path , reader_cls = GraphPropertyReader , ** kwargs
1511 )
@@ -23,7 +19,7 @@ def __init__(self, model_name: str, ckpt_path: str, molecular_properties, **kwar
2319 properties = []
2420 self .molecular_properties = properties
2521 assert isinstance (self .molecular_properties , list ) and all (
26- isinstance (prop , p . MolecularProperty ) for prop in self .molecular_properties
22+ isinstance (prop , MolecularProperty ) for prop in self .molecular_properties
2723 )
2824 print (f"Initialised GNN model { self .model_name } (device: { self .device } )" )
2925
@@ -32,7 +28,12 @@ def load_class(self, class_path: str):
3228 module = __import__ (module_path , fromlist = [class_name ])
3329 return getattr (module , class_name )
3430
35- def init_model (self , ckpt_path : str , ** kwargs ) -> ResGatedGraphConvNetGraphPred :
31+ def init_model (
32+ self , ckpt_path : str , ** kwargs
33+ ) -> "ResGatedGraphConvNetGraphPred" : # noqa: F821
34+ import torch
35+ from chebai_graph .models .graph import ResGatedGraphConvNetGraphPred
36+
3637 model = ResGatedGraphConvNetGraphPred .load_from_checkpoint (
3738 ckpt_path ,
3839 map_location = torch .device (self .device ),
@@ -45,6 +46,14 @@ def init_model(self, ckpt_path: str, **kwargs) -> ResGatedGraphConvNetGraphPred:
4546 return model
4647
4748 def read_smiles (self , smiles ):
49+ import torch
50+ from chebai_graph .preprocessing .properties import AtomProperty , BondProperty
51+ from chebai_graph .preprocessing .property_encoder import (
52+ IndexEncoder ,
53+ OneHotEncoder ,
54+ )
55+ from torch_geometric .data .data import Data as GeomData
56+
4857 reader = self .reader_cls ()
4958 d = reader .to_data (dict (features = smiles , labels = None ))
5059 geom_data = d ["features" ]
@@ -87,9 +96,9 @@ def read_smiles(self, smiles):
8796 encoded_values = encoded_values .unsqueeze (1 )
8897 else :
8998 encoded_values = torch .zeros ((0 , prop .encoder .get_encoding_length ()))
90- if isinstance (prop , p . AtomProperty ):
99+ if isinstance (prop , AtomProperty ):
91100 x = torch .cat ([x , encoded_values ], dim = 1 )
92- elif isinstance (prop , p . BondProperty ):
101+ elif isinstance (prop , BondProperty ):
93102 edge_attr = torch .cat ([edge_attr , encoded_values ], dim = 1 )
94103 else :
95104 molecule_attr = torch .cat ([molecule_attr , encoded_values [0 ]], dim = 1 )
0 commit comments