Skip to content

Commit af04abb

Browse files
committed
tests: add tests for additions to state and mixin.
1 parent e05dfb8 commit af04abb

File tree

2 files changed

+144
-2
lines changed

2 files changed

+144
-2
lines changed

tests/test_state.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import typing
22
from dataclasses import asdict
33

4+
import pytest
45
import torch
56

67
from torch_sim.integrators import MDState
78
from torch_sim.state import (
9+
DeformGradMixin,
810
SimState,
911
_normalize_batch_indices,
1012
_pop_states,
@@ -457,3 +459,143 @@ def test_row_vector_cell(si_sim_state: SimState) -> None:
457459

458460
# Test consistency of getter after setting
459461
assert torch.allclose(si_sim_state.row_vector_cell, new_cell.transpose(-2, -1))
462+
463+
464+
def test_column_vector_cell(si_sim_state: SimState) -> None:
465+
"""Test the column_vector_cell property getter and setter."""
466+
# Test getter - should return cell directly since it's already in column vector format
467+
original_cell = si_sim_state.cell.clone()
468+
column_vector = si_sim_state.column_vector_cell
469+
assert torch.allclose(column_vector, original_cell)
470+
471+
# Test setter - should update cell directly
472+
new_cell = torch.randn_like(original_cell)
473+
si_sim_state.column_vector_cell = new_cell
474+
assert torch.allclose(si_sim_state.cell, new_cell)
475+
476+
# Test consistency of getter after setting
477+
assert torch.allclose(si_sim_state.column_vector_cell, new_cell)
478+
479+
480+
class DeformState(SimState, DeformGradMixin):
481+
"""Test class that combines SimState with DeformGradMixin."""
482+
483+
def __init__(
484+
self,
485+
*args,
486+
velocities: torch.Tensor | None = None,
487+
reference_cell: torch.Tensor | None = None,
488+
**kwargs,
489+
) -> None:
490+
super().__init__(*args, **kwargs)
491+
self.velocities = velocities
492+
self.reference_cell = reference_cell
493+
494+
495+
@pytest.fixture
496+
def deform_grad_state(device: torch.device) -> DeformState:
497+
"""Create a test state with deformation gradient support."""
498+
499+
positions = torch.randn(10, 3, device=device)
500+
masses = torch.ones(10, device=device)
501+
velocities = torch.randn(10, 3, device=device)
502+
reference_cell = torch.eye(3, device=device).unsqueeze(0)
503+
current_cell = 2 * reference_cell
504+
505+
return DeformState(
506+
positions=positions,
507+
masses=masses,
508+
cell=current_cell,
509+
pbc=True,
510+
atomic_numbers=torch.ones(10, device=device, dtype=torch.long),
511+
velocities=velocities,
512+
reference_cell=reference_cell,
513+
)
514+
515+
516+
def test_deform_grad_momenta(deform_grad_state: DeformState) -> None:
517+
"""Test momenta calculation in DeformGradMixin."""
518+
expected_momenta = deform_grad_state.velocities * deform_grad_state.masses.unsqueeze(
519+
-1
520+
)
521+
assert torch.allclose(deform_grad_state.momenta, expected_momenta)
522+
523+
524+
def test_deform_grad_reference_cell(deform_grad_state: DeformState) -> None:
525+
"""Test reference cell getter/setter in DeformGradMixin."""
526+
original_ref_cell = deform_grad_state.reference_cell.clone()
527+
528+
# Test getter
529+
assert torch.allclose(
530+
deform_grad_state.reference_row_vector_cell, original_ref_cell.transpose(-2, -1)
531+
)
532+
533+
# Test setter
534+
new_ref_cell = 3 * torch.eye(3, device=deform_grad_state.device).unsqueeze(0)
535+
deform_grad_state.reference_row_vector_cell = new_ref_cell.transpose(-2, -1)
536+
assert torch.allclose(deform_grad_state.reference_cell, new_ref_cell)
537+
538+
539+
def test_deform_grad_uniform(deform_grad_state: DeformState) -> None:
540+
"""Test deformation gradient calculation for uniform deformation."""
541+
# For 2x uniform expansion, deformation gradient should be 2x identity matrix
542+
deform_grad = deform_grad_state.deform_grad()
543+
expected = 2 * torch.eye(3, device=deform_grad_state.device).unsqueeze(0)
544+
assert torch.allclose(deform_grad, expected)
545+
546+
547+
def test_deform_grad_non_uniform(device: torch.device) -> None:
548+
"""Test deformation gradient calculation for non-uniform deformation."""
549+
reference_cell = torch.eye(3, device=device).unsqueeze(0)
550+
current_cell = torch.tensor(
551+
[[[2.0, 0.1, 0.0], [0.1, 1.5, 0.0], [0.0, 0.0, 1.8]]], device=device
552+
)
553+
554+
state = DeformState(
555+
positions=torch.randn(10, 3, device=device),
556+
masses=torch.ones(10, device=device),
557+
cell=current_cell,
558+
pbc=True,
559+
atomic_numbers=torch.ones(10, device=device, dtype=torch.long),
560+
velocities=torch.randn(10, 3, device=device),
561+
reference_cell=reference_cell,
562+
)
563+
564+
deform_grad = state.deform_grad()
565+
# Verify that deformation gradient correctly transforms reference cell to current cell
566+
reconstructed_cell = torch.matmul(reference_cell, deform_grad.transpose(-2, -1))
567+
assert torch.allclose(reconstructed_cell, current_cell)
568+
569+
570+
def test_deform_grad_batched(device: torch.device) -> None:
571+
"""Test deformation gradient calculation with batched states."""
572+
batch_size = 3
573+
n_atoms = 10
574+
575+
reference_cell = torch.eye(3, device=device).unsqueeze(0).repeat(batch_size, 1, 1)
576+
current_cell = torch.stack(
577+
[
578+
2.0 * torch.eye(3, device=device), # Uniform expansion
579+
torch.eye(3, device=device), # No deformation
580+
0.5 * torch.eye(3, device=device), # Uniform compression
581+
]
582+
)
583+
584+
state = DeformState(
585+
positions=torch.randn(n_atoms * batch_size, 3, device=device),
586+
masses=torch.ones(n_atoms * batch_size, device=device),
587+
cell=current_cell,
588+
pbc=True,
589+
atomic_numbers=torch.ones(n_atoms * batch_size, device=device, dtype=torch.long),
590+
velocities=torch.randn(n_atoms * batch_size, 3, device=device),
591+
reference_cell=reference_cell,
592+
batch=torch.repeat_interleave(torch.arange(batch_size, device=device), n_atoms),
593+
)
594+
595+
deform_grad = state.deform_grad()
596+
assert deform_grad.shape == (batch_size, 3, 3)
597+
598+
expected_factors = torch.tensor([2.0, 1.0, 0.5], device=device)
599+
for i in range(batch_size):
600+
expected = expected_factors[i] * torch.eye(3, device=device)
601+
assert torch.allclose(deform_grad[i], expected)

torch_sim/state.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ class SimState:
6868
masses (torch.Tensor): Atomic masses with shape (n_atoms,)
6969
cell (torch.Tensor): Unit cell vectors with shape (n_batches, 3, 3).
7070
Note that we use a column vector convention, i.e. the cell vectors are
71-
stored as `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]` as opposed to
72-
the row vector convention `[[a1, b1, c1], [a2, b2, c2], [a3, b3, c3]]`
71+
stored as `[[a1, b1, c1], [a2, b2, c2], [a3, b3, c3]]` as opposed to
72+
the row vector convention `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`
7373
used by ASE.
7474
pbc (bool): Boolean indicating whether to use periodic boundary conditions
7575
atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,)

0 commit comments

Comments
 (0)