Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 8 additions & 15 deletions tests/models/test_fairchem.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@ def model_path(tmp_path_factory: pytest.TempPathFactory) -> str:


@pytest.fixture
def fairchem_model(model_path: str, device: torch.device) -> FairChemModel:
def fairchem_model_pbc(model_path: str, device: torch.device) -> FairChemModel:
cpu = device.type == "cpu"
return FairChemModel(
model=model_path,
cpu=cpu,
seed=0,
pbc=True,
)


Expand All @@ -41,28 +42,20 @@ def ocp_calculator(model_path: str) -> OCPCalculator:
return OCPCalculator(checkpoint_path=model_path, cpu=False, seed=0)


test_fairchem_ocp_consistency = make_model_calculator_consistency_test(
test_fairchem_ocp_consistency_pbc = make_model_calculator_consistency_test(
test_name="fairchem_ocp",
model_fixture_name="fairchem_model",
model_fixture_name="fairchem_model_pbc",
calculator_fixture_name="ocp_calculator",
sim_state_names=consistency_test_simstate_fixtures,
sim_state_names=consistency_test_simstate_fixtures[:-1],
rtol=5e-4, # NOTE: fairchem doesn't pass at the 1e-5 level used for other models
atol=5e-4,
)

# TODO: add test for non-PBC model

# fairchem batching is broken on CPU, do not replicate this skipping
# logic in other models tests
# @pytest.mark.skipif(
# not torch.cuda.is_available(),
# reason="Batching does not work properly on CPU for FAIRchem",
# )
# def test_validate_model_outputs(
# fairchem_model: FairChemModel, device: torch.device
# ) -> None:
# validate_model_outputs(fairchem_model, device, torch.float32)


# logic in other models tests. This is due to issues with how the models
# handle supercells (see related issue here: https://github.com/FAIR-Chem/fairchem/issues/428)
test_fairchem_ocp_model_outputs = pytest.mark.skipif(
not torch.cuda.is_available(),
reason="Batching does not work properly on CPU for FAIRchem",
Expand Down
40 changes: 32 additions & 8 deletions torch_sim/models/fairchem.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@

Attributes:
neighbor_list_fn (Callable | None): Function to compute neighbor lists
r_max (float): Maximum cutoff radius for atomic interactions in Ångström
config (dict): Complete model configuration dictionary
trainer: FairChem trainer object that contains the model
data_object (Batch): Data object containing system information
Expand Down Expand Up @@ -108,9 +107,9 @@
trainer: str | None = None,
cpu: bool = False,
seed: int | None = None,
r_max: float | None = None, # noqa: ARG002
dtype: torch.dtype | None = None,
compute_stress: bool = False,
pbc: bool = True,
) -> None:
"""Initialize the FairChemModel with specified configuration.

Expand All @@ -128,9 +127,9 @@
trainer (str | None): Name of trainer class to use
cpu (bool): Whether to use CPU instead of GPU for computation
seed (int | None): Random seed for reproducibility
r_max (float | None): Maximum cutoff radius (overrides model default)
dtype (torch.dtype | None): Data type to use for computation
compute_stress (bool): Whether to compute stress tensor
pbc (bool): Whether to use periodic boundary conditions

Raises:
RuntimeError: If both model_name and model are specified
Expand All @@ -150,6 +149,7 @@
self._compute_stress = compute_stress
self._compute_forces = True
self._memory_scales_with = "n_atoms"
self.pbc = pbc

if model_name is not None:
if model is not None:
Expand Down Expand Up @@ -215,6 +215,14 @@
)

if "backbone" in config["model"]:
if config["model"]["backbone"]["use_pbc"] != pbc:
print(

Check warning on line 219 in torch_sim/models/fairchem.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/fairchem.py#L218-L219

Added lines #L218 - L219 were not covered by tests
f"WARNING: PBC mismatch between model and state. "
"The model loaded was trained with"
f"PBC={config['model']['backbone']['use_pbc']} "
f"and you are using PBC={pbc}."
)
config["model"]["backbone"]["use_pbc"] = pbc

Check warning on line 225 in torch_sim/models/fairchem.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/fairchem.py#L225

Added line #L225 was not covered by tests
config["model"]["backbone"]["use_pbc_single"] = False
if dtype is not None:
try:
Expand All @@ -224,14 +232,26 @@
{"dtype": _DTYPE_DICT[dtype]}
)
except KeyError:
print("dtype not found in backbone, using default float32")
print(

Check warning on line 235 in torch_sim/models/fairchem.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/fairchem.py#L235

Added line #L235 was not covered by tests
"WARNING: dtype not found in backbone, using default model dtype"
)
else:
if config["model"]["use_pbc"] != pbc:
print(

Check warning on line 240 in torch_sim/models/fairchem.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/fairchem.py#L240

Added line #L240 was not covered by tests
f"WARNING: PBC mismatch between model and state. "
f"The model loaded was trained with"
f"PBC={config['model']['use_pbc']} "
f"and you are using PBC={pbc}."
)
config["model"]["use_pbc"] = pbc
config["model"]["use_pbc_single"] = False
if dtype is not None:
try:
config["model"].update({"dtype": _DTYPE_DICT[dtype]})
except KeyError:
print("dtype not found in backbone, using default dtype")
print(

Check warning on line 252 in torch_sim/models/fairchem.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/fairchem.py#L252

Added line #L252 was not covered by tests
"WARNING: dtype not found in backbone, using default model dtype"
)

### backwards compatibility with OCP v<2.0
config = update_config(config)
Expand All @@ -257,11 +277,9 @@
inference_only=True,
)

self.trainer.model = self.trainer.model.eval()

if dtype is not None:
# Convert model parameters to specified dtype
self.trainer.model = self.trainer.model.to(dtype=self.dtype)
self.trainer.model = self.trainer.model.to(dtype=self._dtype)

Check warning on line 282 in torch_sim/models/fairchem.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/fairchem.py#L282

Added line #L282 was not covered by tests

if model is not None:
self.load_checkpoint(checkpoint_path=model, checkpoint=checkpoint)
Expand Down Expand Up @@ -335,6 +353,12 @@
if state.batch is None:
state.batch = torch.zeros(state.positions.shape[0], dtype=torch.int)

if self.pbc != state.pbc:
raise ValueError(

Check warning on line 357 in torch_sim/models/fairchem.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/fairchem.py#L357

Added line #L357 was not covered by tests
"PBC mismatch between model and state. "
"For FairChemModel PBC needs to be defined in the model class."
)

natoms = torch.bincount(state.batch)
pbc = torch.tensor(
[state.pbc, state.pbc, state.pbc] * len(natoms), dtype=torch.bool
Expand Down