Skip to content

Commit 914cbe3

Browse files
committed
on 2nd thought, keep test_torchsim_frechet_cell_fire_vs_ase_mace in a separate file (thanks @CompRhys)
1 parent 2fa6a7e commit 914cbe3

File tree

3 files changed

+121
-119
lines changed

3 files changed

+121
-119
lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ jobs:
4848
--ignore=tests/models/test_sevennet.py \
4949
--ignore=tests/models/test_mattersim.py \
5050
--ignore=tests/models/test_metatensor.py \
51-
--ignore=tests/models/test_torchsim_vs_ase_fire_mace.py \
51+
--ignore=tests/test_optimizers_vs_ase.py \
5252
5353
- name: Upload coverage to Codecov
5454
uses: codecov/codecov-action@v5
@@ -69,10 +69,7 @@ jobs:
6969
- { name: graphpes, test_path: "tests/models/test_graphpes.py" }
7070
- { name: mace, test_path: "tests/models/test_mace.py" }
7171
- { name: mace, test_path: "tests/test_elastic.py" }
72-
- {
73-
name: mace,
74-
test_path: "tests/models/test_torchsim_vs_ase_fire_mace.py",
75-
}
72+
- { name: mace, test_path: "tests/test_optimizers_vs_ase.py" }
7673
- { name: mattersim, test_path: "tests/models/test_mattersim.py" }
7774
- { name: metatensor, test_path: "tests/models/test_metatensor.py" }
7875
- { name: orb, test_path: "tests/models/test_orb.py" }

tests/test_optimizers.py

Lines changed: 0 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,10 @@
11
import copy
2-
import functools
32
from dataclasses import fields
43
from typing import get_args
54

65
import pytest
76
import torch
8-
from ase.filters import FrechetCellFilter
9-
from ase.optimize import FIRE
10-
from mace.calculators import MACECalculator
117

12-
import torch_sim as ts
13-
from torch_sim.io import state_to_atoms
14-
from torch_sim.models.mace import MaceModel
158
from torch_sim.optimizers import (
169
FireState,
1710
FrechetCellFIREState,
@@ -915,110 +908,3 @@ def energy_converged(current_energy: float, prev_energy: float) -> bool:
915908
f"Energy for batch {step} doesn't match position only optimization: "
916909
f"batch={energy_unit_cell}, individual={individual_energies_fire[step]}"
917910
)
918-
919-
920-
def test_torchsim_frechet_cell_fire_vs_ase_mace(
921-
rattled_sio2_sim_state: ts.state.SimState,
922-
torchsim_mace_mpa: MaceModel,
923-
ase_mace_mpa: MACECalculator,
924-
) -> None:
925-
"""Compare torch-sim's Frechet Cell FIRE optimizer with ASE's FIRE + FrechetCellFilter
926-
using MACE-MPA-0.
927-
928-
This test ensures that the custom Frechet Cell FIRE implementation behaves comparably
929-
to the established ASE equivalent when using a MACE force field.
930-
It checks for consistency in final energies, forces, positions, and cell parameters.
931-
"""
932-
# Use float64 for consistency with the MACE model fixture and for precision
933-
dtype = torch.float64
934-
device = torchsim_mace_mpa.device # Use device from the model
935-
936-
# --- Setup Initial State with float64 ---
937-
# Deepcopy to avoid modifying the fixture state for other tests
938-
initial_state = copy.deepcopy(rattled_sio2_sim_state).to(dtype=dtype, device=device)
939-
940-
# Ensure grads are enabled for both positions and cell for optimization
941-
initial_state.positions = initial_state.positions.detach().requires_grad_(
942-
requires_grad=True
943-
)
944-
initial_state.cell = initial_state.cell.detach().requires_grad_(requires_grad=True)
945-
946-
n_steps = 20 # Number of optimization steps
947-
force_tol = 0.02 # Convergence criterion for forces
948-
949-
# --- Run torch-sim Frechet Cell FIRE with MACE model ---
950-
# Use functools.partial to set md_flavor for the frechet_cell_fire optimizer
951-
torch_sim_optimizer = functools.partial(frechet_cell_fire, md_flavor="ase_fire")
952-
953-
custom_opt_state = ts.optimize(
954-
system=initial_state,
955-
model=torchsim_mace_mpa,
956-
optimizer=torch_sim_optimizer,
957-
max_steps=n_steps,
958-
convergence_fn=ts.generate_force_convergence_fn(force_tol=force_tol),
959-
)
960-
961-
# --- Setup ASE System with native MACE calculator ---
962-
# Convert initial SimState to ASE Atoms object
963-
ase_atoms = state_to_atoms(initial_state)[0] # state_to_atoms returns a list
964-
ase_atoms.calc = ase_mace_mpa # Assign the MACE calculator
965-
966-
# --- Run ASE FIRE with FrechetCellFilter ---
967-
# Apply FrechetCellFilter for cell optimization
968-
filtered_ase_atoms = FrechetCellFilter(ase_atoms)
969-
ase_optimizer = FIRE(filtered_ase_atoms)
970-
971-
# Run ASE optimization
972-
ase_optimizer.run(fmax=force_tol, steps=n_steps)
973-
974-
# --- Compare Results ---
975-
final_custom_energy = custom_opt_state.energy.item()
976-
final_custom_forces_max = torch.norm(custom_opt_state.forces, dim=-1).max().item()
977-
final_custom_positions = custom_opt_state.positions.detach()
978-
# Ensure cell is in row vector format and squeezed for comparison
979-
final_custom_cell = custom_opt_state.row_vector_cell.squeeze(0).detach()
980-
981-
final_ase_energy = ase_atoms.get_potential_energy()
982-
ase_forces_raw = ase_atoms.get_forces()
983-
if ase_forces_raw is not None:
984-
final_ase_forces = torch.tensor(ase_forces_raw, device=device, dtype=dtype)
985-
final_ase_forces_max = torch.norm(final_ase_forces, dim=-1).max().item()
986-
else:
987-
# Should not happen if calculator ran and produced forces
988-
final_ase_forces_max = float("nan")
989-
990-
final_ase_positions = torch.tensor(
991-
ase_atoms.get_positions(), device=device, dtype=dtype
992-
)
993-
final_ase_cell = torch.tensor(ase_atoms.get_cell(), device=device, dtype=dtype)
994-
995-
# Compare energies (looser tolerance for ML potentials due to potential minor
996-
# numerical differences)
997-
energy_diff = abs(final_custom_energy - final_ase_energy)
998-
assert energy_diff < 5e-2, (
999-
f"Final energies differ significantly after {n_steps} steps: "
1000-
f"torch-sim={final_custom_energy:.6f}, ASE={final_ase_energy:.6f}, "
1001-
f"Diff={energy_diff:.2e}"
1002-
)
1003-
1004-
# Report forces for diagnostics
1005-
print(
1006-
f"Max Force ({n_steps} steps): torch-sim={final_custom_forces_max:.4f}, "
1007-
f"ASE={final_ase_forces_max:.4f}"
1008-
)
1009-
1010-
# Compare positions (average displacement, looser tolerance)
1011-
avg_displacement = (
1012-
torch.norm(final_custom_positions - final_ase_positions, dim=-1).mean().item()
1013-
)
1014-
assert avg_displacement < 1.0, (
1015-
f"Final positions differ significantly (avg displacement: {avg_displacement:.4f})"
1016-
)
1017-
1018-
# Compare cell matrices (Frobenius norm, looser tolerance)
1019-
cell_diff = torch.norm(final_custom_cell - final_ase_cell).item()
1020-
assert cell_diff < 1.0, (
1021-
f"Final cell matrices differ significantly (Frobenius norm: {cell_diff:.4f})"
1022-
f"\nTorch-sim Cell:\n{final_custom_cell}"
1023-
f"\nASE Cell:\n{final_ase_cell}"
1024-
)

tests/test_optimizers_vs_ase.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import copy
2+
import functools
3+
4+
import torch
5+
from ase.filters import FrechetCellFilter
6+
from ase.optimize import FIRE
7+
from mace.calculators import MACECalculator
8+
9+
import torch_sim as ts
10+
from torch_sim.io import state_to_atoms
11+
from torch_sim.models.mace import MaceModel
12+
from torch_sim.optimizers import frechet_cell_fire
13+
14+
15+
def test_torchsim_frechet_cell_fire_vs_ase_mace(
16+
rattled_sio2_sim_state: ts.state.SimState,
17+
torchsim_mace_mpa: MaceModel,
18+
ase_mace_mpa: MACECalculator,
19+
) -> None:
20+
"""Compare torch-sim's Frechet Cell FIRE optimizer with ASE's FIRE + FrechetCellFilter
21+
using MACE-MPA-0.
22+
23+
This test ensures that the custom Frechet Cell FIRE implementation behaves comparably
24+
to the established ASE equivalent when using a MACE force field.
25+
It checks for consistency in final energies, forces, positions, and cell parameters.
26+
"""
27+
# Use float64 for consistency with the MACE model fixture and for precision
28+
dtype = torch.float64
29+
device = torchsim_mace_mpa.device # Use device from the model
30+
31+
# --- Setup Initial State with float64 ---
32+
# Deepcopy to avoid modifying the fixture state for other tests
33+
initial_state = copy.deepcopy(rattled_sio2_sim_state).to(dtype=dtype, device=device)
34+
35+
# Ensure grads are enabled for both positions and cell for optimization
36+
initial_state.positions = initial_state.positions.detach().requires_grad_(
37+
requires_grad=True
38+
)
39+
initial_state.cell = initial_state.cell.detach().requires_grad_(requires_grad=True)
40+
41+
n_steps = 20 # Number of optimization steps
42+
force_tol = 0.02 # Convergence criterion for forces
43+
44+
# --- Run torch-sim Frechet Cell FIRE with MACE model ---
45+
# Use functools.partial to set md_flavor for the frechet_cell_fire optimizer
46+
torch_sim_optimizer = functools.partial(frechet_cell_fire, md_flavor="ase_fire")
47+
48+
custom_opt_state = ts.optimize(
49+
system=initial_state,
50+
model=torchsim_mace_mpa,
51+
optimizer=torch_sim_optimizer,
52+
max_steps=n_steps,
53+
convergence_fn=ts.generate_force_convergence_fn(force_tol=force_tol),
54+
)
55+
56+
# --- Setup ASE System with native MACE calculator ---
57+
# Convert initial SimState to ASE Atoms object
58+
ase_atoms = state_to_atoms(initial_state)[0] # state_to_atoms returns a list
59+
ase_atoms.calc = ase_mace_mpa # Assign the MACE calculator
60+
61+
# --- Run ASE FIRE with FrechetCellFilter ---
62+
# Apply FrechetCellFilter for cell optimization
63+
filtered_ase_atoms = FrechetCellFilter(ase_atoms)
64+
ase_optimizer = FIRE(filtered_ase_atoms)
65+
66+
# Run ASE optimization
67+
ase_optimizer.run(fmax=force_tol, steps=n_steps)
68+
69+
# --- Compare Results ---
70+
final_custom_energy = custom_opt_state.energy.item()
71+
final_custom_forces_max = torch.norm(custom_opt_state.forces, dim=-1).max().item()
72+
final_custom_positions = custom_opt_state.positions.detach()
73+
# Ensure cell is in row vector format and squeezed for comparison
74+
final_custom_cell = custom_opt_state.row_vector_cell.squeeze(0).detach()
75+
76+
final_ase_energy = ase_atoms.get_potential_energy()
77+
ase_forces_raw = ase_atoms.get_forces()
78+
if ase_forces_raw is not None:
79+
final_ase_forces = torch.tensor(ase_forces_raw, device=device, dtype=dtype)
80+
final_ase_forces_max = torch.norm(final_ase_forces, dim=-1).max().item()
81+
else:
82+
# Should not happen if calculator ran and produced forces
83+
final_ase_forces_max = float("nan")
84+
85+
final_ase_positions = torch.tensor(
86+
ase_atoms.get_positions(), device=device, dtype=dtype
87+
)
88+
final_ase_cell = torch.tensor(ase_atoms.get_cell(), device=device, dtype=dtype)
89+
90+
# Compare energies (looser tolerance for ML potentials due to potential minor
91+
# numerical differences)
92+
energy_diff = abs(final_custom_energy - final_ase_energy)
93+
assert energy_diff < 5e-2, (
94+
f"Final energies differ significantly after {n_steps} steps: "
95+
f"torch-sim={final_custom_energy:.6f}, ASE={final_ase_energy:.6f}, "
96+
f"Diff={energy_diff:.2e}"
97+
)
98+
99+
# Report forces for diagnostics
100+
print(
101+
f"Max Force ({n_steps} steps): torch-sim={final_custom_forces_max:.4f}, "
102+
f"ASE={final_ase_forces_max:.4f}"
103+
)
104+
105+
# Compare positions (average displacement, looser tolerance)
106+
avg_displacement = (
107+
torch.norm(final_custom_positions - final_ase_positions, dim=-1).mean().item()
108+
)
109+
assert avg_displacement < 1.0, (
110+
f"Final positions differ significantly (avg displacement: {avg_displacement:.4f})"
111+
)
112+
113+
# Compare cell matrices (Frobenius norm, looser tolerance)
114+
cell_diff = torch.norm(final_custom_cell - final_ase_cell).item()
115+
assert cell_diff < 1.0, (
116+
f"Final cell matrices differ significantly (Frobenius norm: {cell_diff:.4f})"
117+
f"\nTorch-sim Cell:\n{final_custom_cell}"
118+
f"\nASE Cell:\n{final_ase_cell}"
119+
)

0 commit comments

Comments
 (0)