Skip to content

Commit 320677d

Browse files
committed
test: round trip all the formats
1 parent af04abb commit 320677d

File tree

4 files changed

+55
-90
lines changed

4 files changed

+55
-90
lines changed

tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,17 @@ def si_double_sim_state(si_atoms: Atoms, device: torch.device, dtype: torch.dtyp
278278
return atoms_to_state([si_atoms, si_atoms], device, dtype)
279279

280280

281+
@pytest.fixture
282+
def mixed_double_sim_state(
283+
ar_supercell_sim_state: SimState, si_sim_state: SimState
284+
) -> SimState:
285+
"""Create a batched state from ar_fcc_sim_state."""
286+
return concatenate_states(
287+
[ar_supercell_sim_state, si_sim_state],
288+
device=ar_supercell_sim_state.device,
289+
)
290+
291+
281292
@pytest.fixture
282293
def unbatched_lj_model(
283294
device: torch.device, dtype: torch.dtype

tests/test_io.py

Lines changed: 42 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
from typing import Any
23

34
import pytest
@@ -200,109 +201,63 @@ def test_state_to_multiple_phonopy(ar_double_sim_state: SimState) -> None:
200201

201202

202203
@pytest.mark.parametrize(
203-
"sim_state_name",
204-
[
205-
"ar_supercell_sim_state",
206-
"si_sim_state",
207-
"ti_sim_state",
208-
"sio2_sim_state",
209-
"fe_supercell_sim_state",
210-
"cu_sim_state",
211-
],
204+
("sim_state_name", "conversion_functions"),
205+
list(
206+
itertools.product(
207+
[
208+
"ar_supercell_sim_state",
209+
"si_sim_state",
210+
"ti_sim_state",
211+
"sio2_sim_state",
212+
"fe_supercell_sim_state",
213+
"cu_sim_state",
214+
"ar_double_sim_state",
215+
"mixed_double_sim_state",
216+
# TODO: round trip benzene/non-pbc systems
217+
],
218+
[
219+
(state_to_atoms, atoms_to_state),
220+
(state_to_structures, structures_to_state),
221+
(state_to_phonopy, phonopy_to_state),
222+
],
223+
)
224+
),
212225
)
213-
def test_state_to_atoms_round_trip(
226+
def test_state_round_trip(
214227
sim_state_name: str,
228+
conversion_functions: tuple,
215229
request: pytest.FixtureRequest,
216230
device: torch.device,
217231
dtype: torch.dtype,
218232
) -> None:
219-
"""Test round-trip conversion from SimState -> Atoms -> SimState.
233+
"""Test round-trip conversion from SimState through various formats and back.
220234
221235
Args:
222236
sim_state_name: Name of the sim_state fixture to test
237+
conversion_functions: Tuple of (to_format, from_format) conversion functions
223238
request: Pytest fixture request object to get dynamic fixtures
224239
device: Device to run tests on
225240
dtype: Data type to use
226241
"""
227242
# Get the sim_state fixture dynamically using the name
228243
sim_state: SimState = request.getfixturevalue(sim_state_name)
244+
to_format_fn, from_format_fn = conversion_functions
245+
unique_batches = torch.unique(sim_state.batch)
229246

230-
# First convert to atoms
231-
atoms_list = state_to_atoms(sim_state)
232-
assert len(atoms_list) == 1, f"Expected single system for {sim_state_name}"
247+
# Convert to intermediate format
248+
intermediate_format = to_format_fn(sim_state)
249+
assert len(intermediate_format) == len(unique_batches)
233250

234-
# Then convert back to state
235-
round_trip_state = atoms_to_state(atoms_list, device, dtype)
251+
# Convert back to state
252+
round_trip_state: SimState = from_format_fn(intermediate_format, device, dtype)
236253

237254
# Check that all properties match
238-
assert torch.allclose(
239-
sim_state.positions,
240-
round_trip_state.positions,
241-
)
242-
assert torch.allclose(
243-
sim_state.cell,
244-
round_trip_state.cell,
245-
)
246-
assert torch.all(sim_state.atomic_numbers == round_trip_state.atomic_numbers), (
247-
f"Atomic numbers mismatch for {sim_state_name}"
248-
)
249-
assert torch.allclose(
250-
sim_state.masses,
251-
round_trip_state.masses,
252-
)
253-
assert torch.all(sim_state.batch == round_trip_state.batch), (
254-
f"Batch indices mismatch for {sim_state_name}"
255-
)
256-
assert sim_state.pbc == round_trip_state.pbc, f"PBC mismatch for {sim_state_name}"
257-
258-
259-
@pytest.mark.parametrize(
260-
"atoms_name",
261-
[
262-
"ar_atoms",
263-
"cu_atoms",
264-
"fe_atoms",
265-
"ti_atoms",
266-
"si_atoms",
267-
"sio2_atoms",
268-
],
269-
)
270-
def test_atoms_to_state_round_trip(
271-
atoms_name: str,
272-
request: pytest.FixtureRequest,
273-
device: torch.device,
274-
dtype: torch.dtype,
275-
) -> None:
276-
"""Test round-trip conversion from Atoms -> SimState -> Atoms.
277-
278-
Args:
279-
atoms_name: Name of the atoms fixture to test
280-
request: Pytest fixture request object to get dynamic fixtures
281-
device: Device to run tests on
282-
dtype: Data type to use
283-
"""
284-
# Get the atoms fixture dynamically using the name
285-
atoms: Atoms = request.getfixturevalue(atoms_name)
286-
287-
# First convert to state
288-
sim_state = atoms_to_state(atoms, device, dtype)
289-
290-
# Then convert back to atoms
291-
round_trip_atoms = state_to_atoms(sim_state)[0] # Get first system
292-
293-
# Check that all properties match
294-
assert torch.allclose(
295-
torch.tensor(atoms.positions, device=device, dtype=dtype),
296-
torch.tensor(round_trip_atoms.positions, device=device, dtype=dtype),
297-
)
298-
assert torch.allclose(
299-
torch.tensor(atoms.cell[:], device=device, dtype=dtype),
300-
torch.tensor(round_trip_atoms.cell[:], device=device, dtype=dtype),
301-
)
302-
assert (atoms.numbers == round_trip_atoms.numbers).all(), (
303-
f"Atomic numbers mismatch for {atoms_name}"
304-
)
305-
assert (atoms.get_masses() == round_trip_atoms.get_masses()).all(), (
306-
f"Masses mismatch for {atoms_name}"
307-
)
308-
assert (atoms.pbc == round_trip_atoms.pbc).all(), f"PBC mismatch for {atoms_name}"
255+
assert torch.allclose(sim_state.positions, round_trip_state.positions)
256+
assert torch.allclose(sim_state.cell, round_trip_state.cell)
257+
assert torch.all(sim_state.atomic_numbers == round_trip_state.atomic_numbers)
258+
assert torch.all(sim_state.batch == round_trip_state.batch)
259+
assert sim_state.pbc == round_trip_state.pbc
260+
261+
if isinstance(intermediate_format[0], Atoms):
262+
# TODO: the round trip for pmg and phonopy masses is not exact.
263+
assert torch.allclose(sim_state.masses, round_trip_state.masses)

torch_sim/io.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def structures_to_state(
287287

288288
# Stack all properties
289289
cell = torch.tensor(
290-
np.stack([s.lattice.matrix for s in struct_list]), dtype=dtype, device=device
290+
np.stack([s.lattice.matrix.T for s in struct_list]), dtype=dtype, device=device
291291
)
292292
positions = torch.tensor(
293293
np.concatenate([s.cart_coords for s in struct_list]), dtype=dtype, device=device
@@ -373,7 +373,7 @@ def phonopy_to_state(
373373
device=device,
374374
)
375375
cell = torch.tensor(
376-
np.stack([a.cell for a in phonopy_atoms_list]), dtype=dtype, device=device
376+
np.stack([a.cell.T for a in phonopy_atoms_list]), dtype=dtype, device=device
377377
)
378378

379379
# Create batch indices using repeat_interleave

torch_sim/state.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
]
4848

4949

50-
# TODO: change later on
5150
@dataclass
5251
class SimState:
5352
"""State representation for atomistic systems with batched operations support.

0 commit comments

Comments
 (0)