Skip to content

Commit 60e80b3

Browse files
orionarcherabhijeetganganCompRhys
authored
Support different temperatures in batches (#123)
* allow different batches to be run at different temperatures in nvt, test functionality * add zeroed COM motion to calculate_momenta * modify integrate function to convert temps to kTs earlier * remove unused fixture * allow calculate momenta to take float * lint * only do kT[batch] call if it is a tensor * fix sio2 rattled system * throw error if sevenet is created with float64 * turn sevennet failure on float64 to warning * try fixing OOM on state initialization * lint --------- Co-authored-by: Abhijeet Gangan <[email protected]> Co-authored-by: comprhys <[email protected]>
1 parent 86cff98 commit 60e80b3

File tree

5 files changed

+252
-268
lines changed

5 files changed

+252
-268
lines changed

tests/conftest.py

Lines changed: 4 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from dataclasses import asdict
2-
from pathlib import Path
31
from typing import Any
42

53
import pytest
@@ -13,9 +11,7 @@
1311
from torch_sim.io import atoms_to_state
1412
from torch_sim.models.lennard_jones import LennardJonesModel
1513
from torch_sim.state import SimState, concatenate_states
16-
from torch_sim.trajectory import TrajectoryReporter
1714
from torch_sim.unbatched.models.lennard_jones import UnbatchedLennardJonesModel
18-
from torch_sim.unbatched.unbatched_integrators import nve
1915

2016

2117
@pytest.fixture
@@ -203,8 +199,10 @@ def rattled_sio2_sim_state(sio2_sim_state: SimState) -> SimState:
203199
try:
204200
# Temporarily set a fixed seed
205201
torch.manual_seed(3)
206-
weibull = torch.distributions.weibull.Weibull(scale=0.5, concentration=1.0)
207-
shifts = weibull.sample((sim_state.n_atoms, 3))
202+
weibull = torch.distributions.weibull.Weibull(scale=0.1, concentration=1)
203+
rnd = torch.randn_like(sim_state.positions)
204+
rnd = rnd / torch.norm(rnd, dim=-1, keepdim=True)
205+
shifts = weibull.sample(rnd.shape) * rnd
208206
sim_state.positions = sim_state.positions + shifts
209207
finally:
210208
# Restore the original RNG state
@@ -293,35 +291,3 @@ def lj_model(device: torch.device, dtype: torch.dtype) -> LennardJonesModel:
293291
compute_stress=True,
294292
cutoff=2.5 * 3.405,
295293
)
296-
297-
298-
@pytest.fixture
299-
def torchsim_trajectory(
300-
si_sim_state: SimState,
301-
lj_model: Any,
302-
tmp_path: Path,
303-
device: torch.device,
304-
dtype: torch.dtype,
305-
):
306-
"""Test NVE integration conserves energy."""
307-
# Initialize integrator
308-
kT = torch.tensor(300.0, device=device, dtype=dtype) # Temperature in K
309-
dt = torch.tensor(0.001, device=device, dtype=dtype) # Small timestep for stability
310-
311-
state, update_fn = nve(
312-
**asdict(si_sim_state),
313-
model=lj_model,
314-
dt=dt,
315-
kT=kT,
316-
)
317-
318-
reporter = TrajectoryReporter(tmp_path / "test.hdf5", state_frequency=1)
319-
320-
# Run several steps
321-
for step in range(10):
322-
state = update_fn(state, dt)
323-
reporter.report(state, step)
324-
325-
yield reporter.trajectory
326-
327-
reporter.close()

0 commit comments

Comments
 (0)