|
1 | | -from dataclasses import asdict |
2 | | -from pathlib import Path |
3 | 1 | from typing import Any |
4 | 2 |
|
5 | 3 | import pytest |
|
13 | 11 | from torch_sim.io import atoms_to_state |
14 | 12 | from torch_sim.models.lennard_jones import LennardJonesModel |
15 | 13 | from torch_sim.state import SimState, concatenate_states |
16 | | -from torch_sim.trajectory import TrajectoryReporter |
17 | 14 | from torch_sim.unbatched.models.lennard_jones import UnbatchedLennardJonesModel |
18 | | -from torch_sim.unbatched.unbatched_integrators import nve |
19 | 15 |
|
20 | 16 |
|
21 | 17 | @pytest.fixture |
@@ -203,8 +199,10 @@ def rattled_sio2_sim_state(sio2_sim_state: SimState) -> SimState: |
203 | 199 | try: |
204 | 200 | # Temporarily set a fixed seed |
205 | 201 | torch.manual_seed(3) |
206 | | - weibull = torch.distributions.weibull.Weibull(scale=0.5, concentration=1.0) |
207 | | - shifts = weibull.sample((sim_state.n_atoms, 3)) |
| 202 | + weibull = torch.distributions.weibull.Weibull(scale=0.1, concentration=1) |
| 203 | + rnd = torch.randn_like(sim_state.positions) |
| 204 | + rnd = rnd / torch.norm(rnd, dim=-1, keepdim=True) |
| 205 | + shifts = weibull.sample(rnd.shape) * rnd |
208 | 206 | sim_state.positions = sim_state.positions + shifts |
209 | 207 | finally: |
210 | 208 | # Restore the original RNG state |
@@ -293,35 +291,3 @@ def lj_model(device: torch.device, dtype: torch.dtype) -> LennardJonesModel: |
293 | 291 | compute_stress=True, |
294 | 292 | cutoff=2.5 * 3.405, |
295 | 293 | ) |
296 | | - |
297 | | - |
298 | | -@pytest.fixture |
299 | | -def torchsim_trajectory( |
300 | | - si_sim_state: SimState, |
301 | | - lj_model: Any, |
302 | | - tmp_path: Path, |
303 | | - device: torch.device, |
304 | | - dtype: torch.dtype, |
305 | | -): |
306 | | - """Test NVE integration conserves energy.""" |
307 | | - # Initialize integrator |
308 | | - kT = torch.tensor(300.0, device=device, dtype=dtype) # Temperature in K |
309 | | - dt = torch.tensor(0.001, device=device, dtype=dtype) # Small timestep for stability |
310 | | - |
311 | | - state, update_fn = nve( |
312 | | - **asdict(si_sim_state), |
313 | | - model=lj_model, |
314 | | - dt=dt, |
315 | | - kT=kT, |
316 | | - ) |
317 | | - |
318 | | - reporter = TrajectoryReporter(tmp_path / "test.hdf5", state_frequency=1) |
319 | | - |
320 | | - # Run several steps |
321 | | - for step in range(10): |
322 | | - state = update_fn(state, dt) |
323 | | - reporter.report(state, step) |
324 | | - |
325 | | - yield reporter.trajectory |
326 | | - |
327 | | - reporter.close() |
0 commit comments