
Commit 4c70da7

Reduce test flakiness and hotfix docs (#143)
* loosen nequip test in graph pes
* fix docs
* rename autobatchers
* fix typing and clarify return of `_configure_batches_iterator`
1 parent 414fead commit 4c70da7

File tree

12 files changed: +732 −93 lines changed

README.md

Lines changed: 0 additions & 2 deletions
@@ -125,8 +125,6 @@ To understand how TorchSim works, start with the [comprehensive tutorials](https
 
 TorchSim's structure is summarized in the [API reference](https://radical-ai.github.io/torch-sim/reference/index.html) documentation.
 
-> `torch-sim` module graph. Each node represents a Python module. Arrows indicate imports between modules. Node color indicates connectedness: blue nodes have fewer dependents, red nodes have more (up to 16). The number in parentheses is the number of lines of code in the module.
-
 ## License
 
 TorchSim is released under an [MIT license](LICENSE).

docs/reference/index.rst

Lines changed: 7 additions & 1 deletion
@@ -3,7 +3,7 @@
 API reference
 =============
 
-Overview of the torch_sim API.
+Overview of the TorchSim API.
 
 .. currentmodule:: torch_sim
 

@@ -28,6 +28,12 @@ Overview of the torch_sim API.
    transforms
    units
 
+
+TorchSim module graph. Each node represents a Python module. Arrows indicate
+imports between modules. Node color indicates connectedness: blue nodes have fewer
+dependents, red nodes have more (up to 16). The number in parentheses is the number of
+lines of code in the module. Click on nodes to navigate to the file.
+
 .. image:: /_static/torch-sim-module-graph.svg
    :alt: torch-sim Module Graph
    :width: 100%

examples/scripts/4_High_level_api/4.2_auto_batching_api.py

Lines changed: 5 additions & 5 deletions
@@ -18,8 +18,8 @@
 from mace.calculators.foundations_models import mace_mp
 
 from torch_sim.autobatching import (
-    ChunkingAutoBatcher,
-    HotSwappingAutoBatcher,
+    BinningAutoBatcher,
+    InFlightAutoBatcher,
     calculate_memory_scaler,
 )
 from torch_sim.integrators import nvt_langevin

@@ -65,7 +65,7 @@
 # %% TODO: add max steps
 converge_max_force = generate_force_convergence_fn(force_tol=1e-1)
 single_system_memory = calculate_memory_scaler(fire_states[0])
-batcher = HotSwappingAutoBatcher(
+batcher = InFlightAutoBatcher(
     model=mace_model,
     memory_scales_with="n_atoms_x_density",
     max_memory_scaler=single_system_memory * 2.5 if os.getenv("CI") else None,

@@ -86,7 +86,7 @@
 print("Total number of completed states", len(all_completed_states))
 
 
-# %% run chunking autobatcher
+# %% run binning autobatcher
 nvt_init, nvt_update = nvt_langevin(
     model=mace_model, dt=0.001, kT=300 * MetalUnits.temperature
 )

@@ -105,7 +105,7 @@
 
 
 single_system_memory = calculate_memory_scaler(fire_states[0])
-batcher = ChunkingAutoBatcher(
+batcher = BinningAutoBatcher(
     model=mace_model,
     memory_scales_with="n_atoms_x_density",
     max_memory_scaler=single_system_memory * 2.5 if os.getenv("CI") else None,

examples/scripts/5_Workflow/5.3_Hot_Swap_WBM.py renamed to examples/scripts/5_Workflow/5.3_In_Flight_WBM.py

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@
     ts.io.atoms_to_state(atoms=ase_atoms_list, device=device, dtype=dtype)
 )
 
-batcher = ts.autobatching.HotSwappingAutoBatcher(
+batcher = ts.autobatching.InFlightAutoBatcher(
     model=mace_model,
     memory_scales_with="n_atoms_x_density",
     max_memory_scaler=1000 if os.getenv("CI") else None,

examples/tutorials/autobatching_tutorial.py

Lines changed: 14 additions & 14 deletions
@@ -29,7 +29,7 @@
 atoms exceeds available GPU memory. The `torch_sim.autobatching` module solves this by:
 
 1. Automatically determining optimal batch sizes based on GPU memory constraints
-2. Providing two complementary strategies: chunking and hot-swapping
+2. Providing two complementary strategies: binning and in-flight
 3. Efficiently managing memory resources during large-scale simulations
 
 Let's explore how to use these powerful features!

@@ -120,9 +120,9 @@ def mock_determine_max_batch_size(*args, **kwargs):
 This is a verbose way to determine the max memory metric, we'll see a simpler way
 shortly.
 
-## ChunkingAutoBatcher: Fixed Batching Strategy
+## BinningAutoBatcher: Fixed Batching Strategy
 
-Now on to the exciting part, autobatching! The `ChunkingAutoBatcher` groups states into
+Now on to the exciting part, autobatching! The `BinningAutoBatcher` groups states into
 batches with a binpacking algorithm, ensuring that we minimize the total number of
 batches while maximizing the GPU utilization of each batch. This approach is ideal for
 scenarios where all states need to be processed the same number of times, such as

@@ -132,7 +132,7 @@ def mock_determine_max_batch_size(*args, **kwargs):
 """
 
 # %% Initialize the batcher, the max memory scaler will be computed automatically
-batcher = ts.ChunkingAutoBatcher(
+batcher = ts.BinningAutoBatcher(
     model=mace_model,
     memory_scales_with="n_atoms",
 )

@@ -167,11 +167,11 @@ def process_batch(batch):
 maximum safe batch size through test runs on your GPU. However, the max memory scaler
 is typically fixed for a given model and simulation setup. To avoid calculating it
 every time, which is a bit slow, you can calculate it once and then include it in the
-`ChunkingAutoBatcher` constructor.
+`BinningAutoBatcher` constructor.
 """
 
 # %%
-batcher = ts.ChunkingAutoBatcher(
+batcher = ts.BinningAutoBatcher(
     model=mace_model,
     memory_scales_with="n_atoms",
     max_memory_scaler=max_memory_scaler,

@@ -192,7 +192,7 @@ def process_batch(batch):
 nvt_state = nvt_init(state)
 
 # Initialize the batcher
-batcher = ts.ChunkingAutoBatcher(
+batcher = ts.BinningAutoBatcher(
     model=mace_model,
     memory_scales_with="n_atoms",
 )

@@ -217,13 +217,13 @@ def process_batch(batch):
 
 # %% [markdown]
 """
-## HotSwappingAutoBatcher: Dynamic Batching Strategy
+## InFlightAutoBatcher: Dynamic Batching Strategy
 
-The `HotSwappingAutoBatcher` optimizes GPU utilization by dynamically removing
+The `InFlightAutoBatcher` optimizes GPU utilization by dynamically removing
 converged states and adding new ones. This is ideal for processes like geometry
 optimization where different states may converge at different rates.
 
-The `HotSwappingAutoBatcher` is more complex than the `ChunkingAutoBatcher` because
+The `InFlightAutoBatcher` is more complex than the `BinningAutoBatcher` because
 it requires the batch to be dynamically updated. The swapping logic is handled internally,
 but the user must regularly provide a convergence tensor indicating which batches in
 the state have converged.

@@ -236,7 +236,7 @@ def process_batch(batch):
 fire_state = fire_init(state)
 
 # Initialize the batcher
-batcher = ts.HotSwappingAutoBatcher(
+batcher = ts.InFlightAutoBatcher(
     model=mace_model,
     memory_scales_with="n_atoms",
     max_memory_scaler=1000,

@@ -296,7 +296,7 @@ def process_batch(batch):
 """
 
 # %% Initialize with return_indices=True
-batcher = ts.ChunkingAutoBatcher(
+batcher = ts.BinningAutoBatcher(
     model=mace_model,
     memory_scales_with="n_atoms",
     max_memory_scaler=80,

@@ -317,8 +317,8 @@ def process_batch(batch):
 TorchSim's autobatching provides powerful tools for GPU-efficient simulation of
 multiple systems:
 
-1. Use `ChunkingAutoBatcher` for simpler workflows with fixed iteration counts
-2. Use `HotSwappingAutoBatcher` for optimization problems with varying convergence
+1. Use `BinningAutoBatcher` for simpler workflows with fixed iteration counts
+2. Use `InFlightAutoBatcher` for optimization problems with varying convergence
    rates
 3. Let the library handle memory management automatically, or specify limits manually
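The tutorial diff above describes `BinningAutoBatcher` as grouping states via bin packing under a memory cap. The following standalone toy, which does not use torch-sim at all, sketches that idea with a first-fit-decreasing pass; `pack_into_batches`, `memory_scalers`, and `max_memory_scaler` are hypothetical stand-ins for the outputs of `calculate_memory_scaler` and the batcher's memory limit, not the library's actual implementation.

```python
# Illustrative sketch only: first-fit-decreasing bin packing, mimicking how a
# binning autobatcher might group states under a memory cap. All names here
# are hypothetical; this is not torch-sim's algorithm.

def pack_into_batches(memory_scalers: list[float], max_memory_scaler: float) -> list[list[int]]:
    """Group state indices into batches whose total memory scaler stays under the cap."""
    # Place the most expensive states first to reduce the number of batches
    order = sorted(range(len(memory_scalers)), key=lambda i: -memory_scalers[i])
    batches: list[list[int]] = []
    loads: list[float] = []
    for idx in order:
        cost = memory_scalers[idx]
        # Put the state into the first batch with enough headroom
        for b, load in enumerate(loads):
            if load + cost <= max_memory_scaler:
                batches[b].append(idx)
                loads[b] += cost
                break
        else:
            # No existing batch fits: open a new one
            batches.append([idx])
            loads.append(cost)
    return batches


batches = pack_into_batches([60.0, 30.0, 25.0, 10.0, 50.0], max_memory_scaler=80.0)
print(batches)  # -> [[0, 3], [4, 1], [2]]
```

Every batch respects the cap, and every state appears exactly once, which is the fixed-iteration guarantee the tutorial attributes to the binning strategy.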

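The dynamic strategy in the tutorial diff, `InFlightAutoBatcher`, removes converged states and refills the batch with pending ones. A minimal pure-Python sketch of that control flow, assuming each "state" is just a number of steps until convergence (the names `run_in_flight`, `pending`, and `active` are invented for illustration and are not torch-sim internals):

```python
# Illustrative sketch only: the refill/step/swap-out loop of an in-flight
# batcher. Each toy "state" converges after a fixed number of steps.
from collections import deque


def run_in_flight(steps_needed: list[int], max_batch_size: int) -> dict[int, int]:
    """Return {state index: iteration at which that state converged}."""
    pending = deque(range(len(steps_needed)))
    active: dict[int, int] = {}  # state index -> steps completed so far
    finished: dict[int, int] = {}
    iteration = 0
    while pending or active:
        # Refill the batch up to capacity from the pending queue
        while pending and len(active) < max_batch_size:
            active[pending.popleft()] = 0
        iteration += 1
        # One "simulation step" for every active state
        for idx in list(active):
            active[idx] += 1
            # Convergence check: swap the state out as soon as it is done
            if active[idx] >= steps_needed[idx]:
                finished[idx] = iteration
                del active[idx]
    return finished


print(run_in_flight([3, 1, 2, 1], max_batch_size=2))
```

State 1 converges after one iteration and its slot is immediately refilled, so slow and fast states never force each other to wait, which is the benefit the tutorial claims for geometry optimization with varying convergence rates.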