|
| 1 | +import itertools |
1 | 2 | from typing import Any |
2 | 3 |
|
3 | 4 | import pytest |
@@ -200,109 +201,63 @@ def test_state_to_multiple_phonopy(ar_double_sim_state: SimState) -> None: |
200 | 201 |
|
201 | 202 |
|
202 | 203 | @pytest.mark.parametrize( |
203 | | - "sim_state_name", |
204 | | - [ |
205 | | - "ar_supercell_sim_state", |
206 | | - "si_sim_state", |
207 | | - "ti_sim_state", |
208 | | - "sio2_sim_state", |
209 | | - "fe_supercell_sim_state", |
210 | | - "cu_sim_state", |
211 | | - ], |
| 204 | + ("sim_state_name", "conversion_functions"), |
| 205 | + list( |
| 206 | + itertools.product( |
| 207 | + [ |
| 208 | + "ar_supercell_sim_state", |
| 209 | + "si_sim_state", |
| 210 | + "ti_sim_state", |
| 211 | + "sio2_sim_state", |
| 212 | + "fe_supercell_sim_state", |
| 213 | + "cu_sim_state", |
| 214 | + "ar_double_sim_state", |
| 215 | + "mixed_double_sim_state", |
| 216 | + # TODO: round trip benzene/non-pbc systems |
| 217 | + ], |
| 218 | + [ |
| 219 | + (state_to_atoms, atoms_to_state), |
| 220 | + (state_to_structures, structures_to_state), |
| 221 | + (state_to_phonopy, phonopy_to_state), |
| 222 | + ], |
| 223 | + ) |
| 224 | + ), |
212 | 225 | ) |
213 | | -def test_state_to_atoms_round_trip( |
| 226 | +def test_state_round_trip( |
214 | 227 | sim_state_name: str, |
| 228 | + conversion_functions: tuple, |
215 | 229 | request: pytest.FixtureRequest, |
216 | 230 | device: torch.device, |
217 | 231 | dtype: torch.dtype, |
218 | 232 | ) -> None: |
219 | | - """Test round-trip conversion from SimState -> Atoms -> SimState. |
| 233 | + """Test round-trip conversion from SimState through various formats and back. |
220 | 234 |
|
221 | 235 | Args: |
222 | 236 | sim_state_name: Name of the sim_state fixture to test |
| 237 | + conversion_functions: Tuple of (to_format, from_format) conversion functions |
223 | 238 | request: Pytest fixture request object to get dynamic fixtures |
224 | 239 | device: Device to run tests on |
225 | 240 | dtype: Data type to use |
226 | 241 | """ |
227 | 242 | # Get the sim_state fixture dynamically using the name |
228 | 243 | sim_state: SimState = request.getfixturevalue(sim_state_name) |
| 244 | + to_format_fn, from_format_fn = conversion_functions |
| 245 | + unique_batches = torch.unique(sim_state.batch) |
229 | 246 |
|
230 | | - # First convert to atoms |
231 | | - atoms_list = state_to_atoms(sim_state) |
232 | | - assert len(atoms_list) == 1, f"Expected single system for {sim_state_name}" |
| 247 | + # Convert to intermediate format |
| 248 | + intermediate_format = to_format_fn(sim_state) |
| 249 | + assert len(intermediate_format) == len(unique_batches) |
233 | 250 |
|
234 | | - # Then convert back to state |
235 | | - round_trip_state = atoms_to_state(atoms_list, device, dtype) |
| 251 | + # Convert back to state |
| 252 | + round_trip_state: SimState = from_format_fn(intermediate_format, device, dtype) |
236 | 253 |
|
237 | 254 | # Check that all properties match |
238 | | - assert torch.allclose( |
239 | | - sim_state.positions, |
240 | | - round_trip_state.positions, |
241 | | - ) |
242 | | - assert torch.allclose( |
243 | | - sim_state.cell, |
244 | | - round_trip_state.cell, |
245 | | - ) |
246 | | - assert torch.all(sim_state.atomic_numbers == round_trip_state.atomic_numbers), ( |
247 | | - f"Atomic numbers mismatch for {sim_state_name}" |
248 | | - ) |
249 | | - assert torch.allclose( |
250 | | - sim_state.masses, |
251 | | - round_trip_state.masses, |
252 | | - ) |
253 | | - assert torch.all(sim_state.batch == round_trip_state.batch), ( |
254 | | - f"Batch indices mismatch for {sim_state_name}" |
255 | | - ) |
256 | | - assert sim_state.pbc == round_trip_state.pbc, f"PBC mismatch for {sim_state_name}" |
257 | | - |
258 | | - |
259 | | -@pytest.mark.parametrize( |
260 | | - "atoms_name", |
261 | | - [ |
262 | | - "ar_atoms", |
263 | | - "cu_atoms", |
264 | | - "fe_atoms", |
265 | | - "ti_atoms", |
266 | | - "si_atoms", |
267 | | - "sio2_atoms", |
268 | | - ], |
269 | | -) |
270 | | -def test_atoms_to_state_round_trip( |
271 | | - atoms_name: str, |
272 | | - request: pytest.FixtureRequest, |
273 | | - device: torch.device, |
274 | | - dtype: torch.dtype, |
275 | | -) -> None: |
276 | | - """Test round-trip conversion from Atoms -> SimState -> Atoms. |
277 | | -
|
278 | | - Args: |
279 | | - atoms_name: Name of the atoms fixture to test |
280 | | - request: Pytest fixture request object to get dynamic fixtures |
281 | | - device: Device to run tests on |
282 | | - dtype: Data type to use |
283 | | - """ |
284 | | - # Get the atoms fixture dynamically using the name |
285 | | - atoms: Atoms = request.getfixturevalue(atoms_name) |
286 | | - |
287 | | - # First convert to state |
288 | | - sim_state = atoms_to_state(atoms, device, dtype) |
289 | | - |
290 | | - # Then convert back to atoms |
291 | | - round_trip_atoms = state_to_atoms(sim_state)[0] # Get first system |
292 | | - |
293 | | - # Check that all properties match |
294 | | - assert torch.allclose( |
295 | | - torch.tensor(atoms.positions, device=device, dtype=dtype), |
296 | | - torch.tensor(round_trip_atoms.positions, device=device, dtype=dtype), |
297 | | - ) |
298 | | - assert torch.allclose( |
299 | | - torch.tensor(atoms.cell[:], device=device, dtype=dtype), |
300 | | - torch.tensor(round_trip_atoms.cell[:], device=device, dtype=dtype), |
301 | | - ) |
302 | | - assert (atoms.numbers == round_trip_atoms.numbers).all(), ( |
303 | | - f"Atomic numbers mismatch for {atoms_name}" |
304 | | - ) |
305 | | - assert (atoms.get_masses() == round_trip_atoms.get_masses()).all(), ( |
306 | | - f"Masses mismatch for {atoms_name}" |
307 | | - ) |
308 | | - assert (atoms.pbc == round_trip_atoms.pbc).all(), f"PBC mismatch for {atoms_name}" |
| 255 | + assert torch.allclose(sim_state.positions, round_trip_state.positions) |
| 256 | + assert torch.allclose(sim_state.cell, round_trip_state.cell) |
| 257 | + assert torch.all(sim_state.atomic_numbers == round_trip_state.atomic_numbers) |
| 258 | + assert torch.all(sim_state.batch == round_trip_state.batch) |
| 259 | + assert sim_state.pbc == round_trip_state.pbc |
| 260 | + |
| 261 | + if isinstance(intermediate_format[0], Atoms): |
| 262 | + # TODO: the round trip for pmg and phonopy masses is not exact. |
| 263 | + assert torch.allclose(sim_state.masses, round_trip_state.masses) |
0 commit comments