Description
To reproduce, increase many_cu_atoms (going from 50 to 5000 copies was enough to raise an OOM error on my device):
import torch
import torch_sim as ts
from mace.calculators.foundations_models import mace_mp
from torch_sim.models import MaceModel
device = torch.device("cuda")
mace = mace_mp(model="small", return_raw_model=True)
mace_model = MaceModel(model=mace, device=device)
from ase.build import bulk
cu_atoms = bulk("Cu", "fcc", a=3.58, cubic=True).repeat((2, 2, 2))
many_cu_atoms = [cu_atoms] * 5000 # << 50 => 5000 to raise OOM error
trajectory_files = [f"Cu_traj_{i}.h5md" for i in range(len(many_cu_atoms))]
# run them all simultaneously with batching
final_state = ts.integrate(
system=many_cu_atoms,
model=mace_model,
n_steps=50,
timestep=0.002,
temperature=1000,
integrator=ts.nvt_langevin,
trajectory_reporter=dict(filenames=trajectory_files, state_frequency=10),
autobatcher=True
)
final_atoms_list = final_state.to_atoms()

Relevant backtrace:
In [7]: # run them all simultaneously with batching
...: final_state = ts.optimize(
...: system=many_cu_atoms,
...: model=mace_model,
...: optimizer=ts.frechet_cell_fire,
...: autobatcher=True
...:
...: )
...: final_atoms_list = final_state.to_atoms()
---------------------------------------------------------------------------
OutOfMemoryError Traceback (most recent call last)
Cell In[7], line 2
1 # run them all simultaneously with batching
----> 2 final_state = ts.optimize(
3 system=many_cu_atoms,
4 model=mace_model,
5 optimizer=ts.frechet_cell_fire,
6 autobatcher=True
7
8 )
9 final_atoms_list = final_state.to_atoms()
File ~/.local/venv11/mace/lib/python3.11/site-packages/torch_sim/runners.py:317, in optimize(system, model, optimizer, convergence_fn, trajectory_reporter, autobatcher, max_steps, steps_between_swaps, **optimizer_kwargs)
315 state: SimState = initialize_state(system, model.device, model.dtype)
316 init_fn, update_fn = optimizer(model=model, **optimizer_kwargs)
--> 317 state = init_fn(state)
319 max_attempts = max_steps // steps_between_swaps
320 autobatcher = _configure_hot_swapping_autobatcher(
321 model, state, autobatcher, max_attempts
322 )
File ~/.local/venv11/mace/lib/python3.11/site-packages/torch_sim/optimizers.py:1058, in frechet_cell_fire.<locals>.fire_init(state, cell_factor, scalar_pressure, dt_start, alpha_start)
1055 pressure = pressure.unsqueeze(0).expand(n_batches, -1, -1)
1057 # Get initial forces and energy from model
-> 1058 model_output = model(state)
1060 energy = model_output["energy"] # [n_batches]
1061 forces = model_output["forces"] # [n_total_atoms, 3]

While torch-sim has an autobatching algorithm to mitigate OOM errors, it is not applied in the init_fn stage of many optimizers and integrators.
The state = init_fn(state) line should be placed after the autobatcher is configured, and init_fn should take the autobatcher as an argument so that its batching algorithm is applied from the very first model call.
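A rough sketch of what this reordering could look like in torch_sim/runners.py's optimize(), assuming the state can be split into autobatcher-sized pieces for initialization and re-merged afterwards (the split()/concatenate_states helpers are my assumption about the SimState API, not the confirmed fix):

# Sketch only: possible reordering in torch_sim/runners.py::optimize.
# split()/concatenate_states are assumed helpers for illustration; the point is
# that the first model call happens on memory-safe sub-batches, not the full state.
state: SimState = initialize_state(system, model.device, model.dtype)
init_fn, update_fn = optimizer(model=model, **optimizer_kwargs)

max_attempts = max_steps // steps_between_swaps
autobatcher = _configure_hot_swapping_autobatcher(
    model, state, autobatcher, max_attempts
)

# Initialize per sub-batch instead of calling init_fn (and therefore the model)
# on the full concatenated state in one shot.
initialized = [init_fn(sub_state) for sub_state in state.split()]
state = concatenate_states(initialized)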
Moreover, torch_sim.state.initialize_state sends all of the state tensors to the GPU. If the state is too large, this is another potential source of an OOM error.
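In the meantime, a user-side workaround is to chunk the system list in plain Python and call ts.optimize (or ts.integrate) per chunk, so that initialize_state and init_fn only ever see one chunk at a time. The chunk_size below is an arbitrary illustrative value, and the snippet reuses many_cu_atoms and mace_model from the reproduction above:

# Illustrative workaround (not a fix): process the systems in chunks so that
# initialize_state and init_fn only ever see a GPU-sized subset at once.
chunk_size = 200  # arbitrary; pick whatever fits in GPU memory
final_atoms_list = []
for start in range(0, len(many_cu_atoms), chunk_size):
    chunk = many_cu_atoms[start : start + chunk_size]
    chunk_state = ts.optimize(
        system=chunk,
        model=mace_model,
        optimizer=ts.frechet_cell_fire,
        autobatcher=True,
    )
    final_atoms_list.extend(chunk_state.to_atoms())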
Personally, I think this has high priority, as torch-sim users will want to evaluate a large number of different systems in parallel, and there is no upper bound on how many systems can be passed in.