Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
0ac4a85
fea: add mattersim, revert sevennet to use reference neighbor list, f…
CompRhys Apr 5, 2025
7288111
fix: cell convention in mace and fairchem. revert fe unit cell to avo…
CompRhys Apr 5, 2025
51b9b4b
fea: create a test factory
CompRhys Apr 5, 2025
7fa161a
wip: parameterize the elastic tests
CompRhys Apr 5, 2025
62a3305
fix: address #113
CompRhys Apr 5, 2025
02bdc07
tests: move more atoms to conftest. rename supercells, default in unb…
CompRhys Apr 5, 2025
7dd7b3a
fea: mace no longer needs to touch default dtype, doc: describe the c…
CompRhys Apr 5, 2025
ab4ab6e
fix: casting Z to tkwargs breaks mace models with agnesi transforms
CompRhys Apr 5, 2025
072d726
fea: more consolidation of mace tests
CompRhys Apr 5, 2025
6a4bb11
fea: add a row vector cell property and setter to reduce number of tr…
CompRhys Apr 5, 2025
e05dfb8
fix: fix the elastic tests using row_vector_cell math, updates the fr…
CompRhys Apr 6, 2025
af04abb
tests: add tests for additions to state and mixin.
CompRhys Apr 6, 2025
320677d
test: round trip all the formats
CompRhys Apr 6, 2025
d09453a
fea: allow fairchem pbc
CompRhys Apr 6, 2025
cabe28e
fix: add device when creating tensors for sevennet data dict
abhijeetgangan Apr 6, 2025
393700f
reorganize fixtures to reduce redundancy and dependencies on unbatche…
orionarcher Apr 7, 2025
e859298
rewrite tests to use new fixtures
orionarcher Apr 7, 2025
04b835a
change back to cpu tests
orionarcher Apr 7, 2025
fa98362
lint
orionarcher Apr 7, 2025
82e730d
fix validate model fix
orionarcher Apr 7, 2025
869d962
fea: revert sevennet to vesin
CompRhys Apr 7, 2025
e0825c4
fea: add a rattled structure to test forces
CompRhys Apr 7, 2025
d8e9f96
tests: move test_elastic to batched code despite single states to bet…
CompRhys Apr 8, 2025
3009a3f
fea: fix row/column convention for classical potenials
CompRhys Apr 8, 2025
b1b8293
tests: parameterize again
CompRhys Apr 8, 2025
1a188d0
fea: make the rattled structure fixed
CompRhys Apr 8, 2025
d2301a7
fea: change rattled seed and make sure everything is in eval mode
CompRhys Apr 8, 2025
a5ae3b8
Merge branch 'fix-classical-row-column' into mattersim
CompRhys Apr 8, 2025
58e4d67
test: increase rtol for mace on rattled structure
CompRhys Apr 8, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,12 @@ jobs:
- { python: "3.11", resolution: highest }
- { python: "3.12", resolution: lowest-direct }
model:
- { name: fairchem, test_path: "tests/models/test_fairchem.py" }
- { name: mace, test_path: "tests/models/test_mace.py" }
- { name: mace, test_path: "tests/test_elastic.py" }
- { name: mattersim, test_path: "tests/models/test_mattersim.py" }
- { name: orb, test_path: "tests/models/test_orb.py" }
- { name: sevenn, test_path: "tests/models/test_sevennet.py" }
- { name: fairchem, test_path: "tests/models/test_fairchem.py" }
runs-on: ${{ matrix.os }}

steps:
Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
"html_image",
]

autodoc_mock_imports = ["mace", "fairchem", "orb", "sevennet"]
autodoc_mock_imports = ["fairchem", "mace", "mattersim", "orb", "sevennet"]

# use type hints
autodoc_typehints = "description"
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@ test = [
"pytest>=8",
]
mace = ["mace-torch>=0.3.11"]
sevenn = ["sevenn>=0.11.0"]
mattersim = ["mattersim>=0.1.2"]

orb = [
"orb-models@git+https://github.com/orbital-materials/orb-models#egg=637a98d49cfb494e2491a457d9bbd28311fecf21",
]
sevenn = ["sevenn>=0.11.0"]
docs = [
"autodoc_pydantic==2.2.0",
"furo==2024.8.6",
Expand Down Expand Up @@ -134,6 +136,7 @@ docstring-code-format = true

[tool.codespell]
check-filenames = true
ignore-words-list = ["convertor"]

[tool.pytest]
addopts = ["--cov-report=term-missing", "--cov=torch_sim", "-v"]
Expand Down
Empty file added tests/__init__.py
Empty file.
178 changes: 154 additions & 24 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
from pathlib import Path
from typing import Any

import ase.spacegroup
import pytest
import torch
from ase import Atoms
from ase.build import bulk, molecule
from phonopy.structure.atoms import PhonopyAtoms
from pymatgen.core import Structure

from torch_sim.io import atoms_to_state
from torch_sim.io import atoms_to_state, state_to_atoms
from torch_sim.models.lennard_jones import LennardJonesModel
from torch_sim.state import SimState, concatenate_states
from torch_sim.trajectory import TrajectoryReporter
Expand All @@ -23,25 +24,53 @@ def device() -> torch.device:


@pytest.fixture
def cu_atoms() -> Any:
def dtype() -> torch.dtype:
return torch.float64


@pytest.fixture
def ar_atoms() -> Atoms:
"""Create a face-centered cubic (FCC) Argon structure."""
return bulk("Ar", "fcc", a=5.26, cubic=True)


@pytest.fixture
def cu_atoms() -> Atoms:
"""Create crystalline copper using ASE."""
return bulk("Cu", "fcc", a=3.58, cubic=True)


@pytest.fixture
def ti_atoms() -> Any:
def fe_atoms() -> Atoms:
"""Create crystalline iron using ASE."""
return bulk("Fe", "fcc", a=5.26, cubic=True)


@pytest.fixture
def ti_atoms() -> Atoms:
"""Create crystalline titanium using ASE."""
return bulk("Ti", "hcp", a=2.94, c=4.64)


@pytest.fixture
def si_atoms() -> Any:
def si_atoms() -> Atoms:
"""Create crystalline silicon using ASE."""
return bulk("Si", "diamond", a=5.43, cubic=True)


@pytest.fixture
def benzene_atoms() -> Any:
def sio2_atoms() -> Atoms:
"""Create an alpha-quartz SiO2 system for testing."""
return ase.spacegroup.crystal(
symbols=["O", "Si"],
basis=[[0.413, 0.2711, 0.2172], [0.4673, 0, 0.3333]],
spacegroup=152,
cellpar=[4.9019, 4.9019, 5.3988, 90, 90, 120],
)


@pytest.fixture
def benzene_atoms() -> Atoms:
"""Create benzene using ASE."""
return molecule("C6H6")

Expand Down Expand Up @@ -88,29 +117,53 @@ def si_phonopy_atoms() -> Any:


@pytest.fixture
def si_sim_state(si_atoms: Any, device: torch.device) -> Any:
def cu_sim_state(cu_atoms: Any, device: torch.device, dtype: torch.dtype) -> Any:
"""Create a basic state from si_structure."""
return atoms_to_state(si_atoms, device, torch.float64)
return atoms_to_state(cu_atoms, device, dtype)


@pytest.fixture
def fe_fcc_sim_state(device: torch.device) -> Any:
fe_atoms = bulk("Fe", "fcc", a=5.26, cubic=True).repeat([4, 4, 4])
return atoms_to_state(fe_atoms, device, torch.float64)
def ti_sim_state(ti_atoms: Any, device: torch.device, dtype: torch.dtype) -> Any:
"""Create a basic state from si_structure."""
return atoms_to_state(ti_atoms, device, dtype)


@pytest.fixture
def si_double_sim_state(si_atoms: Atoms, device: torch.device) -> Any:
def si_sim_state(si_atoms: Any, device: torch.device, dtype: torch.dtype) -> Any:
"""Create a basic state from si_structure."""
return atoms_to_state([si_atoms, si_atoms], device, torch.float64)
return atoms_to_state(si_atoms, device, dtype)


@pytest.fixture
def ar_sim_state(device: torch.device) -> SimState:
"""Create a face-centered cubic (FCC) Argon structure."""
# Create FCC Ar using ASE, with 4x4x4 supercell
ar_atoms = bulk("Ar", "fcc", a=5.26, cubic=True).repeat([2, 2, 2])
return atoms_to_state(ar_atoms, device, torch.float64)
def sio2_sim_state(sio2_atoms: Any, device: torch.device, dtype: torch.dtype) -> Any:
"""Create a basic state from si_structure."""
return atoms_to_state(sio2_atoms, device, dtype)


@pytest.fixture
def benzene_sim_state(
benzene_atoms: Any, device: torch.device, dtype: torch.dtype
) -> Any:
"""Create a basic state from benzene_atoms."""
return atoms_to_state(benzene_atoms, device, dtype)


@pytest.fixture
def fe_fcc_sim_state(fe_atoms: Atoms, device: torch.device, dtype: torch.dtype) -> Any:
"""Create a face-centered cubic (FCC) iron structure with 4x4x4 supercell."""
return atoms_to_state(fe_atoms.repeat([4, 4, 4]), device, dtype)


@pytest.fixture
def si_double_sim_state(si_atoms: Atoms, device: torch.device, dtype: torch.dtype) -> Any:
"""Create a basic state from si_structure."""
return atoms_to_state([si_atoms, si_atoms], device, dtype)


@pytest.fixture
def ar_sim_state(ar_atoms: Atoms, device: torch.device, dtype: torch.dtype) -> SimState:
"""Create a face-centered cubic (FCC) Argon structure with 2x2x2 supercell."""
return atoms_to_state(ar_atoms.repeat([2, 2, 2]), device, dtype)


@pytest.fixture
Expand All @@ -120,41 +173,49 @@ def ar_double_sim_state(ar_sim_state: SimState) -> SimState:


@pytest.fixture
def unbatched_lj_model(device: torch.device) -> UnbatchedLennardJonesModel:
def unbatched_lj_model(
device: torch.device, dtype: torch.dtype
) -> UnbatchedLennardJonesModel:
"""Create a Lennard-Jones model with reasonable parameters for Ar."""
return UnbatchedLennardJonesModel(
use_neighbor_list=True,
sigma=3.405,
epsilon=0.0104,
device=device,
dtype=torch.float64,
dtype=dtype,
compute_forces=True,
compute_stress=True,
cutoff=2.5 * 3.405,
)


@pytest.fixture
def lj_model(device: torch.device) -> LennardJonesModel:
def lj_model(device: torch.device, dtype: torch.dtype) -> LennardJonesModel:
"""Create a Lennard-Jones model with reasonable parameters for Ar."""
return LennardJonesModel(
use_neighbor_list=True,
sigma=3.405,
epsilon=0.0104,
device=device,
dtype=torch.float64,
dtype=dtype,
compute_forces=True,
compute_stress=True,
cutoff=2.5 * 3.405,
)


@pytest.fixture
def torchsim_trajectory(si_sim_state: SimState, lj_model: Any, tmp_path: Path):
def torchsim_trajectory(
si_sim_state: SimState,
lj_model: Any,
tmp_path: Path,
device: torch.device,
dtype: torch.dtype,
):
"""Test NVE integration conserves energy."""
# Initialize integrator
kT = torch.tensor(300.0) # Temperature in K
dt = torch.tensor(0.001) # Small timestep for stability
kT = torch.tensor(300.0, device=device, dtype=dtype) # Temperature in K
dt = torch.tensor(0.001, device=device, dtype=dtype) # Small timestep for stability

state, update_fn = nve(
**asdict(si_sim_state),
Expand All @@ -173,3 +234,72 @@ def torchsim_trajectory(si_sim_state: SimState, lj_model: Any, tmp_path: Path):
yield reporter.trajectory

reporter.close()


def make_model_calculator_consistency_test(
test_name: str,
model_fixture_name: str,
calculator_fixture_name: str,
sim_state_names: list[str],
rtol: float = 1e-5,
atol: float = 1e-5,
):
"""Factory function to create model-calculator consistency tests.

Args:
test_name: Name of the test (used in the function name and messages)
model_fixture_name: Name of the model fixture
calculator_fixture_name: Name of the calculator fixture
sim_state_names: List of sim_state fixture names to test
rtol: Relative tolerance for numerical comparisons
atol: Absolute tolerance for numerical comparisons
"""

@pytest.mark.parametrize("sim_state_name", sim_state_names)
def test_model_calculator_consistency(
sim_state_name: str,
request: pytest.FixtureRequest,
device: torch.device,
dtype: torch.dtype,
) -> None:
"""Test consistency between model and calculator implementations."""
# Get the model and calculator fixtures dynamically
model = request.getfixturevalue(model_fixture_name)
calculator = request.getfixturevalue(calculator_fixture_name)

# Get the sim_state fixture dynamically using the name
sim_state = request.getfixturevalue(sim_state_name).to(device, dtype)

# Set up ASE calculator
atoms = state_to_atoms(sim_state)[0]
atoms.calc = calculator

# Get model results
model_results = model(sim_state)

# Get calculator results
calc_forces = torch.tensor(
atoms.get_forces(),
device=device,
dtype=model_results["forces"].dtype,
)

# Test consistency with specified tolerances
torch.testing.assert_close(
model_results["energy"].item(),
atoms.get_potential_energy(),
rtol=rtol,
atol=atol,
msg=f"Energy mismatch for {sim_state_name}",
)
torch.testing.assert_close(
model_results["forces"],
calc_forces,
rtol=rtol,
atol=atol,
msg=f"Forces mismatch for {sim_state_name}",
)

# Rename the function to include the test name
test_model_calculator_consistency.__name__ = f"test_{test_name}_consistency"
return test_model_calculator_consistency
56 changes: 15 additions & 41 deletions tests/models/test_fairchem.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import pytest
import torch
from ase.build import bulk

from torch_sim.io import atoms_to_state
from tests.conftest import make_model_calculator_consistency_test
from torch_sim.models.interface import validate_model_outputs
from torch_sim.state import SimState


try:
Expand All @@ -25,14 +23,6 @@ def model_path(tmp_path_factory: pytest.TempPathFactory) -> str:
)


@pytest.fixture
def si_system(dtype: torch.dtype, device: torch.device) -> SimState:
# Create diamond cubic Silicon
si_dc = bulk("Si", "diamond", a=5.43)

return atoms_to_state([si_dc], device, dtype)


@pytest.fixture
def fairchem_model(model_path: str, device: torch.device) -> FairChemModel:
cpu = device.type == "cpu"
Expand All @@ -48,36 +38,20 @@ def ocp_calculator(model_path: str) -> OCPCalculator:
return OCPCalculator(checkpoint_path=model_path, cpu=False, seed=0)


def test_fairchem_ocp_consistency(
fairchem_model: FairChemModel,
ocp_calculator: OCPCalculator,
device: torch.device,
) -> None:
# Set up ASE calculator
si_dc = bulk("Si", "diamond", a=5.43)
si_dc.calc = ocp_calculator

si_state = atoms_to_state([si_dc], device, torch.float32)
# Get FairChem results
fairchem_results = fairchem_model.forward(si_state)

# Get OCP results
ocp_forces = torch.tensor(
si_dc.get_forces(),
device=device,
dtype=fairchem_results["forces"].dtype,
)

# Test consistency with reasonable tolerances
torch.testing.assert_close(
fairchem_results["energy"].item(),
si_dc.get_potential_energy(),
rtol=1e-2,
atol=1e-2,
)
torch.testing.assert_close(
fairchem_results["forces"], ocp_forces, rtol=1e-2, atol=1e-2
)
test_fairchem_ocp_consistency = make_model_calculator_consistency_test(
test_name="fairchem_ocp",
model_fixture_name="fairchem_model",
calculator_fixture_name="ocp_calculator",
sim_state_names=[
"cu_sim_state",
"ti_sim_state",
"si_sim_state",
"sio2_sim_state",
# "benzene_sim_state", # TODO: Turn on when #111 fixed
],
rtol=1e-2,
atol=1e-2,
)


# fairchem batching is broken on CPU, do not replicate this skipping
Expand Down
Loading
Loading