1212from torchmdnet .module import LNNP
1313from torchmdnet .data import DataModule
1414from torchmdnet import priors
15+ import os
1516
1617from utils import load_example_args , DummyDataset
1718
@@ -29,7 +30,20 @@ def test_load_model():
@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize("use_atomref", [True, False])
@mark.parametrize("precision", [32, 64])
@mark.skipif(
    os.getenv("LONG_TRAIN", "false") == "false", reason="Skipping long train test"
)
def test_train(model_name, use_atomref, precision, tmpdir):
    """Full training smoke test for every registered model.

    Opt-in only: runs when the ``LONG_TRAIN`` env var is set to anything
    other than ``"false"``. Fits and tests a small model for 10 steps on
    the dummy dataset, optionally with an atomref prior.

    Parameters
    ----------
    model_name : str
        One of ``models.__all_models__`` (parametrized).
    use_atomref : bool
        Whether to attach the atomref prior model (parametrized).
    precision : int
        Floating-point precision, 32 or 64 (parametrized).
    tmpdir : py.path.local
        pytest-provided scratch directory for logs/checkpoints.
    """
    import torch

    # Keep CI runners from oversubscribing cores.
    torch.set_num_threads(1)

    accelerator = "auto"
    if os.getenv("CPU_TRAIN", "false") == "true":
        # OSX MPS backend runs out of memory on Github Actions
        torch.set_default_device("cpu")
        accelerator = "cpu"

    args = load_example_args(
        model_name,
        remove_prior=not use_atomref,
        # NOTE(review): the following seven keyword values were elided by the
        # diff hunk (new lines 50-56) — inferred from the repository's test
        # file; confirm against the full source before relying on them.
        train_size=0.05,
        val_size=0.01,
        test_size=0.01,
        log_dir=tmpdir,
        derivative=True,
        embedding_dimension=16,
        num_layers=2,
        num_rbf=16,
        batch_size=8,
        precision=precision,
        num_workers=0,
    )
    datamodule = DataModule(args, DummyDataset(has_atomref=use_atomref))

    # Build the prior from the dataset only when atomref is requested.
    prior = None
    if use_atomref:
        prior = getattr(priors, args["prior_model"])(dataset=datamodule.dataset)
        args["prior_args"] = prior.get_init_args()

    module = LNNP(args, prior_model=prior)

    trainer = pl.Trainer(
        max_steps=10,
        default_root_dir=tmpdir,
        precision=args["precision"],
        # inference_mode=False so trainer.test can still compute gradients
        # (the model predicts forces via autograd when derivative=True).
        inference_mode=False,
        accelerator=accelerator,
        num_nodes=1,
        devices=1,
        use_distributed_sampler=False,
    )
    trainer.fit(module, datamodule)
    trainer.test(module, datamodule)
83+
84+
@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize("use_atomref", [True, False])
@mark.parametrize("precision", [32, 64])
def test_dummy_train(model_name, use_atomref, precision, tmpdir):
    """Fast CI training smoke test: a tiny model trained for 10 steps.

    Unlike :func:`test_train` this always runs — the model is shrunk
    (2-dim embedding, 1 layer, 4 RBFs, batch size 2) so the fit/test
    cycle completes quickly on CPU-only runners.

    Parameters
    ----------
    model_name : str
        One of ``models.__all_models__`` (parametrized).
    use_atomref : bool
        Whether to attach the atomref prior model (parametrized).
    precision : int
        Floating-point precision, 32 or 64 (parametrized).
    tmpdir : py.path.local
        pytest-provided scratch directory for logs/checkpoints.
    """
    import torch

    # Keep CI runners from oversubscribing cores.
    torch.set_num_threads(1)

    accelerator = "auto"
    if os.getenv("CPU_TRAIN", "false") == "true":
        # OSX MPS backend runs out of memory on Github Actions
        torch.set_default_device("cpu")
        accelerator = "cpu"

    # tensornet has no attention heads; every other model takes num_heads.
    extra_args = {}
    if model_name != "tensornet":
        extra_args["num_heads"] = 2

    args = load_example_args(
        model_name,
        remove_prior=not use_atomref,
        train_size=0.05,
        val_size=0.01,
        test_size=0.01,
        log_dir=tmpdir,
        derivative=True,
        embedding_dimension=2,
        num_layers=1,
        num_rbf=4,
        batch_size=2,
        precision=precision,
        num_workers=0,
        **extra_args,
    )
    datamodule = DataModule(args, DummyDataset(has_atomref=use_atomref))

    # Build the prior from the dataset only when atomref is requested.
    # (These four lines were diff context elided between hunks; they are
    # identical to the visible block in test_train.)
    prior = None
    if use_atomref:
        prior = getattr(priors, args["prior_model"])(dataset=datamodule.dataset)
        args["prior_args"] = prior.get_init_args()

    module = LNNP(args, prior_model=prior)

    trainer = pl.Trainer(
        max_steps=10,
        default_root_dir=tmpdir,
        precision=args["precision"],
        # inference_mode=False so trainer.test can still compute gradients
        # (the model predicts forces via autograd when derivative=True).
        inference_mode=False,
        accelerator=accelerator,
        num_nodes=1,
        devices=1,
        use_distributed_sampler=False,
    )
    trainer.fit(module, datamodule)
    trainer.test(module, datamodule)
0 commit comments