Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions torch_sim/models/fairchem.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@
Attributes:
neighbor_list_fn (Callable | None): Function to compute neighbor lists
r_max (float): Maximum cutoff radius for atomic interactions in Ångström
config (dict): Complete model configuration dictionary
trainer: FairChem trainer object that contains the model
data_object (Batch): Data object containing system information
Expand Down Expand Up @@ -108,9 +107,9 @@
trainer: str | None = None,
cpu: bool = False,
seed: int | None = None,
r_max: float | None = None, # noqa: ARG002
dtype: torch.dtype | None = None,
compute_stress: bool = False,
pbc: bool = True,
) -> None:
"""Initialize the FairChemModel with specified configuration.
Expand All @@ -128,9 +127,9 @@
trainer (str | None): Name of trainer class to use
cpu (bool): Whether to use CPU instead of GPU for computation
seed (int | None): Random seed for reproducibility
r_max (float | None): Maximum cutoff radius (overrides model default)
dtype (torch.dtype | None): Data type to use for computation
compute_stress (bool): Whether to compute stress tensor
pbc (bool): Whether to use periodic boundary conditions
Raises:
RuntimeError: If both model_name and model are specified
Expand All @@ -150,6 +149,7 @@
self._compute_stress = compute_stress
self._compute_forces = True
self._memory_scales_with = "n_atoms"
self.pbc = pbc

Check warning on line 152 in torch_sim/models/fairchem.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/fairchem.py#L152

Added line #L152 was not covered by tests

if model_name is not None:
if model is not None:
Expand Down Expand Up @@ -215,6 +215,14 @@
)

if "backbone" in config["model"]:
if config["model"]["backbone"]["use_pbc"] != pbc:
print(

Check warning on line 219 in torch_sim/models/fairchem.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/fairchem.py#L218-L219

Added lines #L218 - L219 were not covered by tests
f"WARNING: PBC mismatch between model and state. "
"The model loaded was trained with"
f"PBC={config['model']['backbone']['use_pbc']} "
f"and you are using PBC={pbc}."
)
config["model"]["backbone"]["use_pbc"] = pbc

Check warning on line 225 in torch_sim/models/fairchem.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/fairchem.py#L225

Added line #L225 was not covered by tests
config["model"]["backbone"]["use_pbc_single"] = False
if dtype is not None:
try:
Expand All @@ -224,14 +232,26 @@
{"dtype": _DTYPE_DICT[dtype]}
)
except KeyError:
print("dtype not found in backbone, using default float32")
print(

Check warning on line 235 in torch_sim/models/fairchem.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/fairchem.py#L235

Added line #L235 was not covered by tests
"WARNING: dtype not found in backbone, using default model dtype"
)
else:
if config["model"]["use_pbc"] != pbc:
print(

Check warning on line 240 in torch_sim/models/fairchem.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/fairchem.py#L239-L240

Added lines #L239 - L240 were not covered by tests
f"WARNING: PBC mismatch between model and state. "
f"The model loaded was trained with"
f"PBC={config['model']['use_pbc']} "
f"and you are using PBC={pbc}."
)
config["model"]["use_pbc"] = pbc

Check warning on line 246 in torch_sim/models/fairchem.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/fairchem.py#L246

Added line #L246 was not covered by tests
config["model"]["use_pbc_single"] = False
if dtype is not None:
try:
config["model"].update({"dtype": _DTYPE_DICT[dtype]})
except KeyError:
print("dtype not found in backbone, using default dtype")
print(

Check warning on line 252 in torch_sim/models/fairchem.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/fairchem.py#L252

Added line #L252 was not covered by tests
"WARNING: dtype not found in backbone, using default model dtype"
)

### backwards compatibility with OCP v<2.0
config = update_config(config)
Expand All @@ -257,11 +277,9 @@
inference_only=True,
)

self.trainer.model = self.trainer.model.eval()

if dtype is not None:
# Convert model parameters to specified dtype
self.trainer.model = self.trainer.model.to(dtype=self.dtype)
self.trainer.model = self.trainer.model.to(dtype=self._dtype)

Check warning on line 282 in torch_sim/models/fairchem.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/fairchem.py#L282

Added line #L282 was not covered by tests

if model is not None:
self.load_checkpoint(checkpoint_path=model, checkpoint=checkpoint)
Expand Down Expand Up @@ -338,6 +356,12 @@
if state.batch is None:
state.batch = torch.zeros(state.positions.shape[0], dtype=torch.int)

if self.pbc != state.pbc:
raise ValueError(
"PBC mismatch between model and state. "
"For FairChemModel PBC needs to be defined in the model class."
)

Check warning on line 363 in torch_sim/models/fairchem.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/fairchem.py#L362-L363

Added lines #L362 - L363 were not covered by tests

natoms = torch.bincount(state.batch)
pbc = torch.tensor(
[state.pbc, state.pbc, state.pbc] * len(natoms), dtype=torch.bool
Expand Down
Loading