Skip to content

Commit 66c5feb

Browse files
FairChemModel Updates: PBC handling, test on OMat24 pre-trained model (#126)
* fix: revert a few changes and add a proper fix for pbc handling with caution * test: remove benzene from fairchem tests. * nit: be consistent about using dtype vs _dtype and device vs _device in models * nit: be consistent about using compute_stress and compute_forces vs _compute_stress and _compute_forces * adds disable_amp, fixes previous use_pbc, adds omat24 model test * download huggingface_hub in tests * increase the tolerance * increase the tolerance * test larger machine for CI * revert to smaller machine, use omat24 model only for batch test --------- Co-authored-by: Rhys Goodall <[email protected]>
1 parent caa5423 commit 66c5feb

File tree

15 files changed

+172
-114
lines changed

15 files changed

+172
-114
lines changed

.github/workflows/test.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,19 @@ jobs:
8888
- name: Set up uv
8989
uses: astral-sh/setup-uv@v2
9090

91+
- name: Install HuggingFace Hub CLI
92+
run: uv pip install huggingface_hub --system
93+
94+
- name: HuggingFace Hub Login
95+
env:
96+
HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}
97+
run: |
98+
if [ -n "$HF_TOKEN" ]; then
99+
huggingface-cli login --token "$HF_TOKEN"
100+
else
101+
echo "HF_TOKEN is not set. Skipping login."
102+
fi
103+
91104
- name: Install fairchem repository and dependencies
92105
if: ${{ matrix.model.name == 'fairchem' }}
93106
run: |

tests/models/test_fairchem.py

Lines changed: 59 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
import pytest
24
import torch
35

@@ -11,6 +13,7 @@
1113
try:
1214
from fairchem.core import OCPCalculator
1315
from fairchem.core.models.model_registry import model_name_to_local_file
16+
from huggingface_hub.utils._auth import get_token
1417

1518
from torch_sim.models.fairchem import FairChemModel
1619

@@ -19,55 +22,87 @@
1922

2023

2124
@pytest.fixture(scope="session")
22-
def model_path(tmp_path_factory: pytest.TempPathFactory) -> str:
25+
def model_path_oc20(tmp_path_factory: pytest.TempPathFactory) -> str:
2326
tmp_path = tmp_path_factory.mktemp("fairchem_checkpoints")
24-
return model_name_to_local_file(
25-
"EquiformerV2-31M-S2EF-OC20-All+MD", local_cache=str(tmp_path)
27+
model_name = "EquiformerV2-31M-S2EF-OC20-All+MD"
28+
return model_name_to_local_file(model_name, local_cache=str(tmp_path))
29+
30+
31+
@pytest.fixture
32+
def eqv2_oc20_model_pbc(model_path_oc20: str, device: torch.device) -> FairChemModel:
33+
cpu = device.type == "cpu"
34+
return FairChemModel(
35+
model=model_path_oc20,
36+
cpu=cpu,
37+
seed=0,
38+
pbc=True,
2639
)
2740

2841

2942
@pytest.fixture
30-
def fairchem_model(model_path: str, device: torch.device) -> FairChemModel:
43+
def eqv2_oc20_model_non_pbc(model_path_oc20: str, device: torch.device) -> FairChemModel:
3144
cpu = device.type == "cpu"
3245
return FairChemModel(
33-
model=model_path,
46+
model=model_path_oc20,
3447
cpu=cpu,
3548
seed=0,
49+
pbc=False,
3650
)
3751

3852

53+
if get_token():
54+
55+
@pytest.fixture(scope="session")
56+
def model_path_omat24(tmp_path_factory: pytest.TempPathFactory) -> str:
57+
tmp_path = tmp_path_factory.mktemp("fairchem_checkpoints")
58+
model_name = "EquiformerV2-31M-OMAT24-MP-sAlex"
59+
return model_name_to_local_file(model_name, local_cache=str(tmp_path))
60+
61+
@pytest.fixture
62+
def eqv2_omat24_model_pbc(
63+
model_path_omat24: str, device: torch.device
64+
) -> FairChemModel:
65+
cpu = device.type == "cpu"
66+
return FairChemModel(
67+
model=model_path_omat24,
68+
cpu=cpu,
69+
seed=0,
70+
pbc=True,
71+
)
72+
73+
3974
@pytest.fixture
40-
def ocp_calculator(model_path: str) -> OCPCalculator:
41-
return OCPCalculator(checkpoint_path=model_path, cpu=False, seed=0)
75+
def ocp_calculator(model_path_oc20: str) -> OCPCalculator:
76+
return OCPCalculator(checkpoint_path=model_path_oc20, cpu=False, seed=0)
4277

4378

44-
test_fairchem_ocp_consistency = make_model_calculator_consistency_test(
79+
test_fairchem_ocp_consistency_pbc = make_model_calculator_consistency_test(
4580
test_name="fairchem_ocp",
46-
model_fixture_name="fairchem_model",
81+
model_fixture_name="eqv2_oc20_model_pbc",
4782
calculator_fixture_name="ocp_calculator",
48-
sim_state_names=consistency_test_simstate_fixtures,
49-
rtol=5e-4, # NOTE: fairchem doesn't pass at the 1e-5 level used for other models
83+
sim_state_names=consistency_test_simstate_fixtures[:-1],
84+
rtol=5e-4, # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models
5085
atol=5e-4,
5186
)
5287

88+
test_fairchem_non_pbc_benzene = make_model_calculator_consistency_test(
89+
test_name="fairchem_non_pbc_benzene",
90+
model_fixture_name="eqv2_oc20_model_non_pbc",
91+
calculator_fixture_name="ocp_calculator",
92+
sim_state_names=["benzene_sim_state"],
93+
rtol=5e-4, # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models
94+
atol=5e-4,
95+
)
5396

54-
# fairchem batching is broken on CPU, do not replicate this skipping
55-
# logic in other models tests
56-
# @pytest.mark.skipif(
57-
# not torch.cuda.is_available(),
58-
# reason="Batching does not work properly on CPU for FAIRchem",
59-
# )
60-
# def test_validate_model_outputs(
61-
# fairchem_model: FairChemModel, device: torch.device
62-
# ) -> None:
63-
# validate_model_outputs(fairchem_model, device, torch.float32)
6497

98+
# Skip this test due to issues with how the older models
99+
# handled supercells (see related issue here: https://github.com/FAIR-Chem/fairchem/issues/428)
65100

66101
test_fairchem_ocp_model_outputs = pytest.mark.skipif(
67-
not torch.cuda.is_available(),
68-
reason="Batching does not work properly on CPU for FAIRchem",
102+
os.environ.get("HF_TOKEN") is None,
103+
reason="Issues in graph construction of older models",
69104
)(
70105
make_validate_model_outputs_test(
71-
model_fixture_name="fairchem_model",
106+
model_fixture_name="eqv2_omat24_model_pbc",
72107
)
73108
)

torch_sim/models/fairchem.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ class FairChemModel(torch.nn.Module, ModelInterface):
7676
7777
Attributes:
7878
neighbor_list_fn (Callable | None): Function to compute neighbor lists
79-
r_max (float): Maximum cutoff radius for atomic interactions in Ångström
8079
config (dict): Complete model configuration dictionary
8180
trainer: FairChem trainer object that contains the model
8281
data_object (Batch): Data object containing system information
@@ -108,9 +107,10 @@ def __init__( # noqa: C901, PLR0915
108107
trainer: str | None = None,
109108
cpu: bool = False,
110109
seed: int | None = None,
111-
r_max: float | None = None, # noqa: ARG002
112110
dtype: torch.dtype | None = None,
113111
compute_stress: bool = False,
112+
pbc: bool = True,
113+
disable_amp: bool = True,
114114
) -> None:
115115
"""Initialize the FairChemModel with specified configuration.
116116
@@ -128,10 +128,10 @@ def __init__( # noqa: C901, PLR0915
128128
trainer (str | None): Name of trainer class to use
129129
cpu (bool): Whether to use CPU instead of GPU for computation
130130
seed (int | None): Random seed for reproducibility
131-
r_max (float | None): Maximum cutoff radius (overrides model default)
132131
dtype (torch.dtype | None): Data type to use for computation
133132
compute_stress (bool): Whether to compute stress tensor
134-
133+
pbc (bool): Whether to use periodic boundary conditions
134+
disable_amp (bool): Whether to disable automatic mixed precision (AMP)
135135
Raises:
136136
RuntimeError: If both model_name and model are specified
137137
NotImplementedError: If local_cache is not set when model_name is used
@@ -150,6 +150,7 @@ def __init__( # noqa: C901, PLR0915
150150
self._compute_stress = compute_stress
151151
self._compute_forces = True
152152
self._memory_scales_with = "n_atoms"
153+
self.pbc = pbc
153154

154155
if model_name is not None:
155156
if model is not None:
@@ -215,6 +216,7 @@ def __init__( # noqa: C901, PLR0915
215216
)
216217

217218
if "backbone" in config["model"]:
219+
config["model"]["backbone"]["use_pbc"] = pbc
218220
config["model"]["backbone"]["use_pbc_single"] = False
219221
if dtype is not None:
220222
try:
@@ -224,14 +226,19 @@ def __init__( # noqa: C901, PLR0915
224226
{"dtype": _DTYPE_DICT[dtype]}
225227
)
226228
except KeyError:
227-
print("dtype not found in backbone, using default float32")
229+
print(
230+
"WARNING: dtype not found in backbone, using default model dtype"
231+
)
228232
else:
233+
config["model"]["use_pbc"] = pbc
229234
config["model"]["use_pbc_single"] = False
230235
if dtype is not None:
231236
try:
232237
config["model"].update({"dtype": _DTYPE_DICT[dtype]})
233238
except KeyError:
234-
print("dtype not found in backbone, using default dtype")
239+
print(
240+
"WARNING: dtype not found in backbone, using default model dtype"
241+
)
235242

236243
### backwards compatibility with OCP v<2.0
237244
config = update_config(config)
@@ -257,8 +264,6 @@ def __init__( # noqa: C901, PLR0915
257264
inference_only=True,
258265
)
259266

260-
self.trainer.model = self.trainer.model.eval()
261-
262267
if dtype is not None:
263268
# Convert model parameters to specified dtype
264269
self.trainer.model = self.trainer.model.to(dtype=self.dtype)
@@ -275,6 +280,9 @@ def __init__( # noqa: C901, PLR0915
275280
else:
276281
self.trainer.set_seed(seed)
277282

283+
if disable_amp:
284+
self.trainer.scaler = None
285+
278286
self.implemented_properties = list(self.config["outputs"])
279287

280288
self._device = self.trainer.device
@@ -335,6 +343,12 @@ def forward(self, state: SimState | StateDict) -> dict:
335343
if state.batch is None:
336344
state.batch = torch.zeros(state.positions.shape[0], dtype=torch.int)
337345

346+
if self.pbc != state.pbc:
347+
raise ValueError(
348+
"PBC mismatch between model and state. "
349+
"For FairChemModel PBC needs to be defined in the model class."
350+
)
351+
338352
natoms = torch.bincount(state.batch)
339353
pbc = torch.tensor(
340354
[state.pbc, state.pbc, state.pbc] * len(natoms), dtype=torch.bool
@@ -350,9 +364,9 @@ def forward(self, state: SimState | StateDict) -> dict:
350364
pbc=pbc,
351365
)
352366

353-
if self._dtype is not None:
354-
self.data_object.pos = self.data_object.pos.to(self._dtype)
355-
self.data_object.cell = self.data_object.cell.to(self._dtype)
367+
if self.dtype is not None:
368+
self.data_object.pos = self.data_object.pos.to(self.dtype)
369+
self.data_object.cell = self.data_object.cell.to(self.dtype)
356370

357371
predictions = self.trainer.predict(
358372
self.data_object, per_image=False, disable_tqdm=True

torch_sim/models/graphpes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,15 +154,15 @@ def __init__(
154154
model if isinstance(model, GraphPESModel) else load_model(model) # type: ignore[arg-type]
155155
),
156156
)
157-
self._gp_model = _model.to(device=self._device, dtype=self._dtype)
157+
self._gp_model = _model.to(device=self.device, dtype=self.dtype)
158158

159159
self._compute_forces = compute_forces
160160
self._compute_stress = compute_stress
161161

162162
self._properties: list[PropertyKey] = ["energy"]
163-
if self._compute_forces:
163+
if self.compute_forces:
164164
self._properties.append("forces")
165-
if self._compute_stress:
165+
if self.compute_stress:
166166
self._properties.append("stress")
167167

168168
if self._gp_model.cutoff.item() < 0.5:

torch_sim/models/lennard_jones.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,9 @@ def __init__(
138138
self.use_neighbor_list = use_neighbor_list
139139

140140
# Convert parameters to tensors
141-
self.sigma = torch.tensor(sigma, dtype=dtype, device=self._device)
142-
self.cutoff = torch.tensor(
143-
cutoff or 2.5 * sigma, dtype=dtype, device=self._device
144-
)
145-
self.epsilon = torch.tensor(epsilon, dtype=dtype, device=self._device)
141+
self.sigma = torch.tensor(sigma, dtype=dtype, device=self.device)
142+
self.cutoff = torch.tensor(cutoff or 2.5 * sigma, dtype=dtype, device=self.device)
143+
self.epsilon = torch.tensor(epsilon, dtype=dtype, device=self.device)
146144

147145
def unbatched_forward(
148146
self,
@@ -209,7 +207,7 @@ def unbatched_forward(
209207
pbc=pbc,
210208
)
211209
# Mask out self-interactions
212-
mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self._device)
210+
mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device)
213211
distances = distances.masked_fill(mask, float("inf"))
214212
# Apply cutoff
215213
mask = distances < self.cutoff
@@ -233,14 +231,14 @@ def unbatched_forward(
233231

234232
if self.per_atom_energies:
235233
atom_energies = torch.zeros(
236-
positions.shape[0], dtype=self._dtype, device=self._device
234+
positions.shape[0], dtype=self.dtype, device=self.device
237235
)
238236
# Each atom gets half of the pair energy
239237
atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies)
240238
atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies)
241239
results["energies"] = atom_energies
242240

243-
if self._compute_forces or self._compute_stress:
241+
if self.compute_forces or self.compute_stress:
244242
# Calculate forces and apply cutoff
245243
pair_forces = lennard_jones_pair_force(
246244
distances, sigma=self.sigma, epsilon=self.epsilon
@@ -250,15 +248,15 @@ def unbatched_forward(
250248
# Project forces along displacement vectors
251249
force_vectors = (pair_forces / distances)[:, None] * dr_vec
252250

253-
if self._compute_forces:
251+
if self.compute_forces:
254252
# Initialize forces tensor
255253
forces = torch.zeros_like(positions)
256254
# Add force contributions (f_ij on i, -f_ij on j)
257255
forces.index_add_(0, mapping[0], -force_vectors)
258256
forces.index_add_(0, mapping[1], force_vectors)
259257
results["forces"] = forces
260258

261-
if self._compute_stress and cell is not None:
259+
if self.compute_stress and cell is not None:
262260
# Compute stress tensor
263261
stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors)
264262
volume = torch.abs(torch.linalg.det(cell))
@@ -268,8 +266,8 @@ def unbatched_forward(
268266
if self.per_atom_stresses:
269267
atom_stresses = torch.zeros(
270268
(state.positions.shape[0], 3, 3),
271-
dtype=self._dtype,
272-
device=self._device,
269+
dtype=self.dtype,
270+
device=self.device,
273271
)
274272
atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair)
275273
atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair)

torch_sim/models/mace.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,8 @@ def __init__(
160160
self.model = model.to(self._device)
161161
self.model = self.model.eval()
162162

163-
if self._dtype is not None:
164-
self.model = self.model.to(dtype=self._dtype)
163+
if self.dtype is not None:
164+
self.model = self.model.to(dtype=self.dtype)
165165

166166
if enable_cueq:
167167
print("Converting models to CuEq for acceleration")
@@ -334,8 +334,8 @@ def forward( # noqa: C901
334334
unit_shifts=unit_shifts,
335335
shifts=shifts_list,
336336
),
337-
compute_force=self._compute_forces,
338-
compute_stress=self._compute_stress,
337+
compute_force=self.compute_forces,
338+
compute_stress=self.compute_stress,
339339
)
340340

341341
results = {}
@@ -348,13 +348,13 @@ def forward( # noqa: C901
348348
results["energy"] = torch.zeros(self.n_systems, device=self.device)
349349

350350
# Process forces
351-
if self._compute_forces:
351+
if self.compute_forces:
352352
forces = out["forces"]
353353
if forces is not None:
354354
results["forces"] = forces.detach()
355355

356356
# Process stress
357-
if self._compute_stress:
357+
if self.compute_stress:
358358
stress = out["stress"]
359359
if stress is not None:
360360
results["stress"] = stress.detach()

torch_sim/models/mattersim.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ def __init__(
8585
self.model = model.to(self._device)
8686
self.model = self.model.eval()
8787

88-
if self._dtype is not None:
89-
self.model = self.model.to(dtype=self._dtype)
88+
if self.dtype is not None:
89+
self.model = self.model.to(dtype=self.dtype)
9090

9191
model_args = self.model.model.model_args
9292
self.two_body_cutoff = model_args["cutoff"]

0 commit comments

Comments
 (0)