Skip to content

Commit ad950b0

Browse files
committed
define MaceUrls StrEnum to avoid breaking tests when "small" checkpoints get redirected in mace-torch
1 parent 914cbe3 commit ad950b0

File tree

3 files changed

+20
-15
lines changed

3 files changed

+20
-15
lines changed

tests/conftest.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import TYPE_CHECKING, Any, Final
1+
from enum import StrEnum
2+
from typing import TYPE_CHECKING, Any
23

34
import pytest
45
import torch
@@ -19,6 +20,13 @@
1920
from mace.calculators import MACECalculator
2021

2122

23+
class MaceUrls(StrEnum):
24+
mace_small = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_small.model"
25+
mace_off_small = (
26+
"https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model"
27+
)
28+
29+
2230
@pytest.fixture
2331
def device() -> torch.device:
2432
return torch.device("cpu")
@@ -324,18 +332,13 @@ def lj_model(device: torch.device, dtype: torch.dtype) -> LennardJonesModel:
324332
)
325333

326334

327-
MACE_CHECKPOINT_URL: Final[str] = (
328-
"https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_small.model"
329-
)
330-
331-
332335
@pytest.fixture
333336
def ase_mace_mpa() -> "MACECalculator":
334337
"""Provides an ASE MACECalculator instance using mace_mp."""
335338
from mace.calculators.foundations_models import mace_mp
336339

337340
# Ensure dtype matches the one used in the torchsim fixture (float64)
338-
return mace_mp(model=MACE_CHECKPOINT_URL, default_dtype="float64")
341+
return mace_mp(model=MaceUrls.mace_small, default_dtype="float64")
339342

340343

341344
@pytest.fixture
@@ -346,7 +349,7 @@ def torchsim_mace_mpa() -> MaceModel:
346349
# Use float64 for potentially higher precision needed in optimization
347350
dtype = getattr(torch, dtype_str := "float64")
348351
raw_mace = mace_mp(
349-
model=MACE_CHECKPOINT_URL, return_raw_model=True, default_dtype=dtype_str
352+
model=MaceUrls.mace_small, return_raw_model=True, default_dtype=dtype_str
350353
)
351354
return MaceModel(
352355
model=raw_mace,

tests/models/test_mace.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ase.atoms import Atoms
44

55
import torch_sim as ts
6+
from tests.conftest import MaceUrls
67
from tests.models.conftest import (
78
consistency_test_simstate_fixtures,
89
make_model_calculator_consistency_test,
@@ -19,8 +20,8 @@
1920
pytest.skip("MACE not installed", allow_module_level=True)
2021

2122

22-
mace_model = mace_mp(model="small", return_raw_model=True)
23-
mace_off_model = mace_off(model="small", return_raw_model=True)
23+
mace_model = mace_mp(model=MaceUrls.mace_small, return_raw_model=True)
24+
mace_off_model = mace_off(model=MaceUrls.mace_off_small, return_raw_model=True)
2425

2526

2627
@pytest.fixture
@@ -32,7 +33,7 @@ def dtype() -> torch.dtype:
3233
@pytest.fixture
3334
def ase_mace_calculator() -> MACECalculator:
3435
return mace_mp(
35-
model="small",
36+
model=MaceUrls.mace_small,
3637
device="cpu",
3738
default_dtype="float32",
3839
dispersion=False,
@@ -96,7 +97,7 @@ def benzene_system(
9697
@pytest.fixture
9798
def ase_mace_off_calculator() -> MACECalculator:
9899
return mace_off(
99-
model="small",
100+
model=MaceUrls.mace_off_small,
100101
device="cpu",
101102
default_dtype="float32",
102103
dispersion=False,

tests/unbatched/test_unbatched_mace.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ase.atoms import Atoms
44

55
import torch_sim as ts
6+
from tests.conftest import MaceUrls
67
from tests.unbatched.conftest import make_unbatched_model_calculator_consistency_test
78

89

@@ -15,8 +16,8 @@
1516
pytest.skip("MACE not installed", allow_module_level=True)
1617

1718

18-
mace_model = mace_mp(model="small", return_raw_model=True)
19-
mace_off_model = mace_off(model="small", return_raw_model=True)
19+
mace_model = mace_mp(model=MaceUrls.mace_small, return_raw_model=True)
20+
mace_off_model = mace_off(model=MaceUrls.mace_off_small, return_raw_model=True)
2021

2122

2223
@pytest.fixture
@@ -28,7 +29,7 @@ def dtype() -> torch.dtype:
2829
@pytest.fixture
2930
def ase_mace_calculator() -> MACECalculator:
3031
return mace_mp(
31-
model="small",
32+
model=MaceUrls.mace_small,
3233
device="cpu",
3334
default_dtype="float32",
3435
dispersion=False,

0 commit comments

Comments
 (0)