|
1 | 1 | import typing |
2 | 2 | from dataclasses import asdict |
3 | 3 |
|
| 4 | +import pytest |
4 | 5 | import torch |
5 | 6 |
|
6 | 7 | from torch_sim.integrators import MDState |
7 | 8 | from torch_sim.state import ( |
| 9 | + DeformGradMixin, |
8 | 10 | SimState, |
9 | 11 | _normalize_batch_indices, |
10 | 12 | _pop_states, |
@@ -457,3 +459,143 @@ def test_row_vector_cell(si_sim_state: SimState) -> None: |
457 | 459 |
|
458 | 460 | # Test consistency of getter after setting |
459 | 461 | assert torch.allclose(si_sim_state.row_vector_cell, new_cell.transpose(-2, -1)) |
| 462 | + |
| 463 | + |
| 464 | +def test_column_vector_cell(si_sim_state: SimState) -> None: |
| 465 | + """Test the column_vector_cell property getter and setter.""" |
| 466 | + # Test getter - should return cell directly since it's already in column vector format |
| 467 | + original_cell = si_sim_state.cell.clone() |
| 468 | + column_vector = si_sim_state.column_vector_cell |
| 469 | + assert torch.allclose(column_vector, original_cell) |
| 470 | + |
| 471 | + # Test setter - should update cell directly |
| 472 | + new_cell = torch.randn_like(original_cell) |
| 473 | + si_sim_state.column_vector_cell = new_cell |
| 474 | + assert torch.allclose(si_sim_state.cell, new_cell) |
| 475 | + |
| 476 | + # Test consistency of getter after setting |
| 477 | + assert torch.allclose(si_sim_state.column_vector_cell, new_cell) |
| 478 | + |
| 479 | + |
| 480 | +class DeformState(SimState, DeformGradMixin): |
| 481 | + """Test class that combines SimState with DeformGradMixin.""" |
| 482 | + |
| 483 | + def __init__( |
| 484 | + self, |
| 485 | + *args, |
| 486 | + velocities: torch.Tensor | None = None, |
| 487 | + reference_cell: torch.Tensor | None = None, |
| 488 | + **kwargs, |
| 489 | + ) -> None: |
| 490 | + super().__init__(*args, **kwargs) |
| 491 | + self.velocities = velocities |
| 492 | + self.reference_cell = reference_cell |
| 493 | + |
| 494 | + |
| 495 | +@pytest.fixture |
| 496 | +def deform_grad_state(device: torch.device) -> DeformState: |
| 497 | + """Create a test state with deformation gradient support.""" |
| 498 | + |
| 499 | + positions = torch.randn(10, 3, device=device) |
| 500 | + masses = torch.ones(10, device=device) |
| 501 | + velocities = torch.randn(10, 3, device=device) |
| 502 | + reference_cell = torch.eye(3, device=device).unsqueeze(0) |
| 503 | + current_cell = 2 * reference_cell |
| 504 | + |
| 505 | + return DeformState( |
| 506 | + positions=positions, |
| 507 | + masses=masses, |
| 508 | + cell=current_cell, |
| 509 | + pbc=True, |
| 510 | + atomic_numbers=torch.ones(10, device=device, dtype=torch.long), |
| 511 | + velocities=velocities, |
| 512 | + reference_cell=reference_cell, |
| 513 | + ) |
| 514 | + |
| 515 | + |
| 516 | +def test_deform_grad_momenta(deform_grad_state: DeformState) -> None: |
| 517 | + """Test momenta calculation in DeformGradMixin.""" |
| 518 | + expected_momenta = deform_grad_state.velocities * deform_grad_state.masses.unsqueeze( |
| 519 | + -1 |
| 520 | + ) |
| 521 | + assert torch.allclose(deform_grad_state.momenta, expected_momenta) |
| 522 | + |
| 523 | + |
| 524 | +def test_deform_grad_reference_cell(deform_grad_state: DeformState) -> None: |
| 525 | + """Test reference cell getter/setter in DeformGradMixin.""" |
| 526 | + original_ref_cell = deform_grad_state.reference_cell.clone() |
| 527 | + |
| 528 | + # Test getter |
| 529 | + assert torch.allclose( |
| 530 | + deform_grad_state.reference_row_vector_cell, original_ref_cell.transpose(-2, -1) |
| 531 | + ) |
| 532 | + |
| 533 | + # Test setter |
| 534 | + new_ref_cell = 3 * torch.eye(3, device=deform_grad_state.device).unsqueeze(0) |
| 535 | + deform_grad_state.reference_row_vector_cell = new_ref_cell.transpose(-2, -1) |
| 536 | + assert torch.allclose(deform_grad_state.reference_cell, new_ref_cell) |
| 537 | + |
| 538 | + |
| 539 | +def test_deform_grad_uniform(deform_grad_state: DeformState) -> None: |
| 540 | + """Test deformation gradient calculation for uniform deformation.""" |
| 541 | + # For 2x uniform expansion, deformation gradient should be 2x identity matrix |
| 542 | + deform_grad = deform_grad_state.deform_grad() |
| 543 | + expected = 2 * torch.eye(3, device=deform_grad_state.device).unsqueeze(0) |
| 544 | + assert torch.allclose(deform_grad, expected) |
| 545 | + |
| 546 | + |
| 547 | +def test_deform_grad_non_uniform(device: torch.device) -> None: |
| 548 | + """Test deformation gradient calculation for non-uniform deformation.""" |
| 549 | + reference_cell = torch.eye(3, device=device).unsqueeze(0) |
| 550 | + current_cell = torch.tensor( |
| 551 | + [[[2.0, 0.1, 0.0], [0.1, 1.5, 0.0], [0.0, 0.0, 1.8]]], device=device |
| 552 | + ) |
| 553 | + |
| 554 | + state = DeformState( |
| 555 | + positions=torch.randn(10, 3, device=device), |
| 556 | + masses=torch.ones(10, device=device), |
| 557 | + cell=current_cell, |
| 558 | + pbc=True, |
| 559 | + atomic_numbers=torch.ones(10, device=device, dtype=torch.long), |
| 560 | + velocities=torch.randn(10, 3, device=device), |
| 561 | + reference_cell=reference_cell, |
| 562 | + ) |
| 563 | + |
| 564 | + deform_grad = state.deform_grad() |
| 565 | + # Verify that deformation gradient correctly transforms reference cell to current cell |
| 566 | + reconstructed_cell = torch.matmul(reference_cell, deform_grad.transpose(-2, -1)) |
| 567 | + assert torch.allclose(reconstructed_cell, current_cell) |
| 568 | + |
| 569 | + |
| 570 | +def test_deform_grad_batched(device: torch.device) -> None: |
| 571 | + """Test deformation gradient calculation with batched states.""" |
| 572 | + batch_size = 3 |
| 573 | + n_atoms = 10 |
| 574 | + |
| 575 | + reference_cell = torch.eye(3, device=device).unsqueeze(0).repeat(batch_size, 1, 1) |
| 576 | + current_cell = torch.stack( |
| 577 | + [ |
| 578 | + 2.0 * torch.eye(3, device=device), # Uniform expansion |
| 579 | + torch.eye(3, device=device), # No deformation |
| 580 | + 0.5 * torch.eye(3, device=device), # Uniform compression |
| 581 | + ] |
| 582 | + ) |
| 583 | + |
| 584 | + state = DeformState( |
| 585 | + positions=torch.randn(n_atoms * batch_size, 3, device=device), |
| 586 | + masses=torch.ones(n_atoms * batch_size, device=device), |
| 587 | + cell=current_cell, |
| 588 | + pbc=True, |
| 589 | + atomic_numbers=torch.ones(n_atoms * batch_size, device=device, dtype=torch.long), |
| 590 | + velocities=torch.randn(n_atoms * batch_size, 3, device=device), |
| 591 | + reference_cell=reference_cell, |
| 592 | + batch=torch.repeat_interleave(torch.arange(batch_size, device=device), n_atoms), |
| 593 | + ) |
| 594 | + |
| 595 | + deform_grad = state.deform_grad() |
| 596 | + assert deform_grad.shape == (batch_size, 3, 3) |
| 597 | + |
| 598 | + expected_factors = torch.tensor([2.0, 1.0, 0.5], device=device) |
| 599 | + for i in range(batch_size): |
| 600 | + expected = expected_factors[i] * torch.eye(3, device=device) |
| 601 | + assert torch.allclose(deform_grad[i], expected) |
0 commit comments