@@ -76,7 +76,6 @@ class FairChemModel(torch.nn.Module, ModelInterface):
 
     Attributes:
         neighbor_list_fn (Callable | None): Function to compute neighbor lists
-        r_max (float): Maximum cutoff radius for atomic interactions in Ångström
         config (dict): Complete model configuration dictionary
         trainer: FairChem trainer object that contains the model
         data_object (Batch): Data object containing system information
@@ -108,9 +107,10 @@ def __init__( # noqa: C901, PLR0915
         trainer: str | None = None,
         cpu: bool = False,
         seed: int | None = None,
-        r_max: float | None = None,  # noqa: ARG002
         dtype: torch.dtype | None = None,
         compute_stress: bool = False,
+        pbc: bool = True,
+        disable_amp: bool = True,
     ) -> None:
         """Initialize the FairChemModel with specified configuration.
 
@@ -128,10 +128,10 @@ def __init__( # noqa: C901, PLR0915
             trainer (str | None): Name of trainer class to use
             cpu (bool): Whether to use CPU instead of GPU for computation
             seed (int | None): Random seed for reproducibility
-            r_max (float | None): Maximum cutoff radius (overrides model default)
             dtype (torch.dtype | None): Data type to use for computation
             compute_stress (bool): Whether to compute stress tensor
-
+            pbc (bool): Whether to use periodic boundary conditions
+            disable_amp (bool): Whether to disable AMP
         Raises:
             RuntimeError: If both model_name and model are specified
             NotImplementedError: If local_cache is not set when model_name is used
@@ -150,6 +150,7 @@ def __init__( # noqa: C901, PLR0915
         self._compute_stress = compute_stress
         self._compute_forces = True
         self._memory_scales_with = "n_atoms"
+        self.pbc = pbc
 
         if model_name is not None:
             if model is not None:
@@ -215,6 +216,7 @@ def __init__( # noqa: C901, PLR0915
             )
 
         if "backbone" in config["model"]:
+            config["model"]["backbone"]["use_pbc"] = pbc
             config["model"]["backbone"]["use_pbc_single"] = False
             if dtype is not None:
                 try:
@@ -224,14 +226,19 @@ def __init__( # noqa: C901, PLR0915
224226 {"dtype" : _DTYPE_DICT [dtype ]}
225227 )
226228 except KeyError :
227- print ("dtype not found in backbone, using default float32" )
229+ print (
230+ "WARNING: dtype not found in backbone, using default model dtype"
231+ )
228232 else :
233+ config ["model" ]["use_pbc" ] = pbc
229234 config ["model" ]["use_pbc_single" ] = False
230235 if dtype is not None :
231236 try :
232237 config ["model" ].update ({"dtype" : _DTYPE_DICT [dtype ]})
233238 except KeyError :
234- print ("dtype not found in backbone, using default dtype" )
239+ print (
240+ "WARNING: dtype not found in backbone, using default model dtype"
241+ )
235242
236243 ### backwards compatibility with OCP v<2.0
237244 config = update_config (config )
@@ -257,8 +264,6 @@ def __init__( # noqa: C901, PLR0915
             inference_only=True,
         )
 
-        self.trainer.model = self.trainer.model.eval()
-
         if dtype is not None:
             # Convert model parameters to specified dtype
             self.trainer.model = self.trainer.model.to(dtype=self.dtype)
@@ -275,6 +280,9 @@ def __init__( # noqa: C901, PLR0915
         else:
             self.trainer.set_seed(seed)
 
+        if disable_amp:
+            self.trainer.scaler = None
+
         self.implemented_properties = list(self.config["outputs"])
 
         self._device = self.trainer.device
@@ -335,6 +343,12 @@ def forward(self, state: SimState | StateDict) -> dict:
         if state.batch is None:
             state.batch = torch.zeros(state.positions.shape[0], dtype=torch.int)
 
+        if self.pbc != state.pbc:
+            raise ValueError(
+                "PBC mismatch between model and state. "
+                "For FairChemModel PBC needs to be defined in the model class."
+            )
+
         natoms = torch.bincount(state.batch)
         pbc = torch.tensor(
             [state.pbc, state.pbc, state.pbc] * len(natoms), dtype=torch.bool
@@ -350,9 +364,9 @@ def forward(self, state: SimState | StateDict) -> dict:
             pbc=pbc,
         )
 
-        if self._dtype is not None:
-            self.data_object.pos = self.data_object.pos.to(self._dtype)
-            self.data_object.cell = self.data_object.cell.to(self._dtype)
+        if self.dtype is not None:
+            self.data_object.pos = self.data_object.pos.to(self.dtype)
+            self.data_object.cell = self.data_object.cell.to(self.dtype)
 
         predictions = self.trainer.predict(
             self.data_object, per_image=False, disable_tqdm=True
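
For downstream callers, here is a minimal usage sketch of the updated constructor. The `torch_sim.models.fairchem` import path and the checkpoint path are illustrative assumptions; only the keyword arguments (`pbc`, `disable_amp`, `dtype`, `compute_stress`, `cpu`) come from the diff above.

```python
# Minimal sketch only: module path and checkpoint location are assumptions,
# not verified against the repository layout.
import torch

from torch_sim.models.fairchem import FairChemModel  # assumed import path

model = FairChemModel(
    model="path/to/checkpoint.pt",  # hypothetical checkpoint path
    cpu=True,
    dtype=torch.float32,
    compute_stress=False,
    pbc=True,          # new: periodicity is fixed at model construction
    disable_amp=True,  # new: clears the trainer's AMP grad scaler
)

# forward() now raises ValueError if state.pbc does not match model.pbc,
# so any SimState passed to this model must be built with pbc=True as well.
# results = model(state)
```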