
[Enhancement] Initial evaluation of state causes an OOM error if state is too large #128

@YutackPark

Description


To reproduce, increase the number of systems in many_cu_atoms (going from 50 to 5000 was enough to raise an OOM error on my device):

  import torch
  import torch_sim as ts
  from mace.calculators.foundations_models import mace_mp
  from torch_sim.models import MaceModel

  device = torch.device("cuda")
  mace = mace_mp(model="small", return_raw_model=True)
  mace_model = MaceModel(model=mace, device=device)

  from ase.build import bulk
  cu_atoms = bulk("Cu", "fcc", a=3.58, cubic=True).repeat((2, 2, 2))
  many_cu_atoms = [cu_atoms] * 5000  # << 50 => 5000 to raise OOM error
  trajectory_files = [f"Cu_traj_{i}.h5md" for i in range(len(many_cu_atoms))]

  # run them all simultaneously with batching
  final_state = ts.integrate(
      system=many_cu_atoms,
      model=mace_model,
      n_steps=50,
      timestep=0.002,
      temperature=1000,
      integrator=ts.nvt_langevin,
      trajectory_reporter=dict(filenames=trajectory_files, state_frequency=10),
      autobatcher=True
  )
  final_atoms_list = final_state.to_atoms()

Relevant backtrace:

In [7]: # run them all simultaneously with batching
   ...: final_state = ts.optimize(
   ...:     system=many_cu_atoms,
   ...:     model=mace_model,
   ...:     optimizer=ts.frechet_cell_fire,
   ...:     autobatcher=True
   ...:
   ...: )
   ...: final_atoms_list = final_state.to_atoms()
---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
Cell In[7], line 2
      1 # run them all simultaneously with batching
----> 2 final_state = ts.optimize(
      3     system=many_cu_atoms,
      4     model=mace_model,
      5     optimizer=ts.frechet_cell_fire,
      6     autobatcher=True
      7
      8 )
      9 final_atoms_list = final_state.to_atoms()

File ~/.local/venv11/mace/lib/python3.11/site-packages/torch_sim/runners.py:317, in optimize(system, model, optimizer, convergence_fn, trajectory_reporter, autobatcher, max_steps, steps_between_swaps, **optimizer_kwargs)
    315 state: SimState = initialize_state(system, model.device, model.dtype)
    316 init_fn, update_fn = optimizer(model=model, **optimizer_kwargs)
--> 317 state = init_fn(state)
    319 max_attempts = max_steps // steps_between_swaps
    320 autobatcher = _configure_hot_swapping_autobatcher(
    321     model, state, autobatcher, max_attempts
    322 )

File ~/.local/venv11/mace/lib/python3.11/site-packages/torch_sim/optimizers.py:1058, in frechet_cell_fire.<locals>.fire_init(state, cell_factor, scalar_pressure, dt_start, alpha_start)
   1055 pressure = pressure.unsqueeze(0).expand(n_batches, -1, -1)
   1057 # Get initial forces and energy from model
-> 1058 model_output = model(state)
   1060 energy = model_output["energy"]  # [n_batches]
   1061 forces = model_output["forces"]  # [n_total_atoms, 3]

While torch-sim has an autobatching algorithm to mitigate OOM errors, it is not applied in the init_fn stage of many optimizers and integrators.
The state = init_fn(state) line should come after the autobatcher is configured, and should take the autobatcher as an argument so that its batching is used from the very first model evaluation.
Moreover, torch_sim.state.initialize_state sends all of the state tensors to the GPU at once; if the state is too large, this is another potential source of an OOM error.
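
For illustration, a rough sketch of that reordering in torch_sim/runners.py, reusing the names visible in the backtrace (initialize_state, _configure_hot_swapping_autobatcher, init_fn). The per-batch iteration (iter_batches) and the concatenate_states recombination are assumptions about the autobatcher/state API here, not a tested patch:

  # Sketch only: run init_fn per autobatcher batch instead of on the full state.
  state: SimState = initialize_state(system, model.device, model.dtype)
  init_fn, update_fn = optimizer(model=model, **optimizer_kwargs)

  max_attempts = max_steps // steps_between_swaps
  autobatcher = _configure_hot_swapping_autobatcher(
      model, state, autobatcher, max_attempts
  )

  # Hypothetical per-batch initialization: each sub-state fits in GPU memory,
  # so the model is never evaluated on the whole concatenated state at once.
  sub_states = [init_fn(sub) for sub in autobatcher.iter_batches(state)]  # iter_batches is hypothetical
  state = concatenate_states(sub_states)  # assumes a concatenation helper in torch_sim.state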

Personally, I think this has high priority: torch-sim users will want to evaluate a large number of different systems in parallel, and there is no upper limit on how many systems they may pass in.
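
In the meantime, a user-side workaround is to split the input list into fixed-size chunks and call ts.integrate once per chunk, so initialize_state and init_fn never see all 5000 systems at once. A minimal sketch based on the reproducer above (chunk_size is arbitrary and should be tuned to the available GPU memory):

  chunk_size = 100  # arbitrary; pick the largest value that fits in GPU memory
  final_atoms_list = []
  for start in range(0, len(many_cu_atoms), chunk_size):
      chunk = many_cu_atoms[start : start + chunk_size]
      chunk_files = trajectory_files[start : start + chunk_size]
      chunk_state = ts.integrate(
          system=chunk,
          model=mace_model,
          n_steps=50,
          timestep=0.002,
          temperature=1000,
          integrator=ts.nvt_langevin,
          trajectory_reporter=dict(filenames=chunk_files, state_frequency=10),
          autobatcher=True,
      )
      final_atoms_list.extend(chunk_state.to_atoms())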
