|
1 | 1 | import copy |
2 | | -import functools |
3 | 2 | from dataclasses import fields |
4 | 3 | from typing import get_args |
5 | 4 |
|
6 | 5 | import pytest |
7 | 6 | import torch |
8 | | -from ase.filters import FrechetCellFilter |
9 | | -from ase.optimize import FIRE |
10 | | -from mace.calculators import MACECalculator |
11 | 7 |
|
12 | | -import torch_sim as ts |
13 | | -from torch_sim.io import state_to_atoms |
14 | | -from torch_sim.models.mace import MaceModel |
15 | 8 | from torch_sim.optimizers import ( |
16 | 9 | FireState, |
17 | 10 | FrechetCellFIREState, |
@@ -915,110 +908,3 @@ def energy_converged(current_energy: float, prev_energy: float) -> bool: |
915 | 908 | f"Energy for batch {step} doesn't match position only optimization: " |
916 | 909 | f"batch={energy_unit_cell}, individual={individual_energies_fire[step]}" |
917 | 910 | ) |
918 | | - |
919 | | - |
920 | | -def test_torchsim_frechet_cell_fire_vs_ase_mace( |
921 | | - rattled_sio2_sim_state: ts.state.SimState, |
922 | | - torchsim_mace_mpa: MaceModel, |
923 | | - ase_mace_mpa: MACECalculator, |
924 | | -) -> None: |
925 | | - """Compare torch-sim's Frechet Cell FIRE optimizer with ASE's FIRE + FrechetCellFilter |
926 | | - using MACE-MPA-0. |
927 | | -
|
928 | | - This test ensures that the custom Frechet Cell FIRE implementation behaves comparably |
929 | | - to the established ASE equivalent when using a MACE force field. |
930 | | - It checks for consistency in final energies, forces, positions, and cell parameters. |
931 | | - """ |
932 | | - # Use float64 for consistency with the MACE model fixture and for precision |
933 | | - dtype = torch.float64 |
934 | | - device = torchsim_mace_mpa.device # Use device from the model |
935 | | - |
936 | | - # --- Setup Initial State with float64 --- |
937 | | - # Deepcopy to avoid modifying the fixture state for other tests |
938 | | - initial_state = copy.deepcopy(rattled_sio2_sim_state).to(dtype=dtype, device=device) |
939 | | - |
940 | | - # Ensure grads are enabled for both positions and cell for optimization |
941 | | - initial_state.positions = initial_state.positions.detach().requires_grad_( |
942 | | - requires_grad=True |
943 | | - ) |
944 | | - initial_state.cell = initial_state.cell.detach().requires_grad_(requires_grad=True) |
945 | | - |
946 | | - n_steps = 20 # Number of optimization steps |
947 | | - force_tol = 0.02 # Convergence criterion for forces |
948 | | - |
949 | | - # --- Run torch-sim Frechet Cell FIRE with MACE model --- |
950 | | - # Use functools.partial to set md_flavor for the frechet_cell_fire optimizer |
951 | | - torch_sim_optimizer = functools.partial(frechet_cell_fire, md_flavor="ase_fire") |
952 | | - |
953 | | - custom_opt_state = ts.optimize( |
954 | | - system=initial_state, |
955 | | - model=torchsim_mace_mpa, |
956 | | - optimizer=torch_sim_optimizer, |
957 | | - max_steps=n_steps, |
958 | | - convergence_fn=ts.generate_force_convergence_fn(force_tol=force_tol), |
959 | | - ) |
960 | | - |
961 | | - # --- Setup ASE System with native MACE calculator --- |
962 | | - # Convert initial SimState to ASE Atoms object |
963 | | - ase_atoms = state_to_atoms(initial_state)[0] # state_to_atoms returns a list |
964 | | - ase_atoms.calc = ase_mace_mpa # Assign the MACE calculator |
965 | | - |
966 | | - # --- Run ASE FIRE with FrechetCellFilter --- |
967 | | - # Apply FrechetCellFilter for cell optimization |
968 | | - filtered_ase_atoms = FrechetCellFilter(ase_atoms) |
969 | | - ase_optimizer = FIRE(filtered_ase_atoms) |
970 | | - |
971 | | - # Run ASE optimization |
972 | | - ase_optimizer.run(fmax=force_tol, steps=n_steps) |
973 | | - |
974 | | - # --- Compare Results --- |
975 | | - final_custom_energy = custom_opt_state.energy.item() |
976 | | - final_custom_forces_max = torch.norm(custom_opt_state.forces, dim=-1).max().item() |
977 | | - final_custom_positions = custom_opt_state.positions.detach() |
978 | | - # Ensure cell is in row vector format and squeezed for comparison |
979 | | - final_custom_cell = custom_opt_state.row_vector_cell.squeeze(0).detach() |
980 | | - |
981 | | - final_ase_energy = ase_atoms.get_potential_energy() |
982 | | - ase_forces_raw = ase_atoms.get_forces() |
983 | | - if ase_forces_raw is not None: |
984 | | - final_ase_forces = torch.tensor(ase_forces_raw, device=device, dtype=dtype) |
985 | | - final_ase_forces_max = torch.norm(final_ase_forces, dim=-1).max().item() |
986 | | - else: |
987 | | - # Should not happen if calculator ran and produced forces |
988 | | - final_ase_forces_max = float("nan") |
989 | | - |
990 | | - final_ase_positions = torch.tensor( |
991 | | - ase_atoms.get_positions(), device=device, dtype=dtype |
992 | | - ) |
993 | | - final_ase_cell = torch.tensor(ase_atoms.get_cell(), device=device, dtype=dtype) |
994 | | - |
995 | | - # Compare energies (looser tolerance for ML potentials due to potential minor |
996 | | - # numerical differences) |
997 | | - energy_diff = abs(final_custom_energy - final_ase_energy) |
998 | | - assert energy_diff < 5e-2, ( |
999 | | - f"Final energies differ significantly after {n_steps} steps: " |
1000 | | - f"torch-sim={final_custom_energy:.6f}, ASE={final_ase_energy:.6f}, " |
1001 | | - f"Diff={energy_diff:.2e}" |
1002 | | - ) |
1003 | | - |
1004 | | - # Report forces for diagnostics |
1005 | | - print( |
1006 | | - f"Max Force ({n_steps} steps): torch-sim={final_custom_forces_max:.4f}, " |
1007 | | - f"ASE={final_ase_forces_max:.4f}" |
1008 | | - ) |
1009 | | - |
1010 | | - # Compare positions (average displacement, looser tolerance) |
1011 | | - avg_displacement = ( |
1012 | | - torch.norm(final_custom_positions - final_ase_positions, dim=-1).mean().item() |
1013 | | - ) |
1014 | | - assert avg_displacement < 1.0, ( |
1015 | | - f"Final positions differ significantly (avg displacement: {avg_displacement:.4f})" |
1016 | | - ) |
1017 | | - |
1018 | | - # Compare cell matrices (Frobenius norm, looser tolerance) |
1019 | | - cell_diff = torch.norm(final_custom_cell - final_ase_cell).item() |
1020 | | - assert cell_diff < 1.0, ( |
1021 | | - f"Final cell matrices differ significantly (Frobenius norm: {cell_diff:.4f})" |
1022 | | - f"\nTorch-sim Cell:\n{final_custom_cell}" |
1023 | | - f"\nASE Cell:\n{final_ase_cell}" |
1024 | | - ) |
0 commit comments