
Commit 5356d0b

CompRhys and janosh committed

fea: use batched vdot

---------

Co-authored-by: Janosh Riebesell <[email protected]>

1 parent ef5c912 · commit 5356d0b

File tree

2 files changed, +61 -12 lines


torch_sim/math.py

Lines changed: 25 additions & 0 deletions
@@ -987,3 +987,28 @@ def matrix_log_33(
         print(msg)
     # Fall back to scipy implementation
     return matrix_log_scipy(matrix).to(sim_dtype)
+
+
+def batched_vdot(
+    x: torch.Tensor, y: torch.Tensor, batch_indices: torch.Tensor
+) -> torch.Tensor:
+    """Computes batched vdot (sum of element-wise products) for groups of vectors.
+    Equivalent to (x[batch_indices == b] * y[batch_indices == b]).sum() per batch b.
+
+    Args:
+        x: Tensor of shape [N_total_entities, D] (e.g., forces, velocities).
+        y: Tensor of shape [N_total_entities, D].
+        batch_indices: Tensor of shape [N_total_entities] indicating batch membership.
+
+    Returns:
+        Tensor: shape [n_batches] where each element is the sum(x_i * y_i) for
+            entities belonging to that batch, summed over all components D and
+            all entities in the batch.
+    """
+    if x.ndim != 2 or batch_indices.ndim != 1 or x.shape[0] != batch_indices.shape[0]:
+        raise ValueError(f"Invalid input shapes: {x.shape=}, {batch_indices.shape=}")
+
+    output = torch.zeros(batch_indices.max() + 1, dtype=x.dtype, device=x.device)
+    output.scatter_add_(dim=0, index=batch_indices, src=(x * y).sum(dim=1))
+
+    return output
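For reference, a minimal usage sketch of batched_vdot, checked against a plain per-batch loop. The `import torch_sim.math as tsm` alias is an assumption (it matches the `tsm.` calls in the optimizers.py diff below); the tensors and batch layout are made up for illustration.

import torch

import torch_sim.math as tsm  # assumed alias, matching the `tsm.` calls in the diff

# Two structures flattened into one tensor: atoms 0-1 in batch 0, atoms 2-4 in batch 1.
x = torch.randn(5, 3)
y = torch.randn(5, 3)
batch_indices = torch.tensor([0, 0, 1, 1, 1])

result = tsm.batched_vdot(x, y, batch_indices)  # shape [2]

# Reference: the same per-batch sums from a plain Python loop.
expected = torch.stack(
    [(x[batch_indices == b] * y[batch_indices == b]).sum() for b in range(2)]
)
torch.testing.assert_close(result, expected)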

torch_sim/optimizers.py

Lines changed: 36 additions & 12 deletions
@@ -1481,14 +1481,28 @@ def _ase_fire_step( # noqa: C901, PLR0915

     # 3. Velocity mixing BEFORE acceleration (ASE ordering)
     # Atoms
-    v_norm_atom = torch.norm(state.velocities, dim=1, keepdim=True)
-    f_norm_atom = torch.norm(state.forces, dim=1, keepdim=True)
-    f_unit_atom = state.forces / (f_norm_atom + eps)
-    alpha_atom = state.alpha[state.batch].unsqueeze(-1)
-    pos_mask_atom = pos_mask_batch[state.batch].unsqueeze(-1)
-    v_new_atom = (
-        1.0 - alpha_atom
-    ) * state.velocities + alpha_atom * f_unit_atom * v_norm_atom
+    # per-batch sums of squares, shape [n_batches]
+    v_sum_sq_batch = tsm.batched_vdot(state.velocities, state.velocities, state.batch)
+    f_sum_sq_batch = tsm.batched_vdot(state.forces, state.forces, state.batch)
+
+    # Expand to per-atom for applying to vectors. These are
+    # sqrt(sum ||v_i||^2) and sqrt(sum ||f_i||^2) per batch,
+    # i.e. |V|_batch and |F|_batch in the mixing formula.
+    sqrt_v_sum_sq_batch_expanded = torch.sqrt(v_sum_sq_batch[state.batch].unsqueeze(-1))
+    sqrt_f_sum_sq_batch_expanded = torch.sqrt(f_sum_sq_batch[state.batch].unsqueeze(-1))
+
+    alpha_atom = state.alpha[state.batch].unsqueeze(-1)  # per-atom alpha
+    pos_mask_atom = pos_mask_batch[state.batch].unsqueeze(-1)  # per-atom mask
+
+    # ASE formula: v_new = (1 - a) * v + a * (f / |F|_batch) * |V|_batch
+    #                    = (1 - a) * v + a * f * (|V|_batch / |F|_batch)
+    mixing_term_atom = state.forces * (
+        sqrt_v_sum_sq_batch_expanded / (sqrt_f_sum_sq_batch_expanded + eps)
+    )
+
+    v_new_atom = (1.0 - alpha_atom) * state.velocities + alpha_atom * mixing_term_atom
     state.velocities = torch.where(
         pos_mask_atom, v_new_atom, torch.zeros_like(state.velocities)
     )
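The point of the hunk above: ASE's FIRE mixes velocities with norms taken over a whole structure, v_new = (1 - a) * v + a * f * (|V| / |F|), whereas the old code normalized each atom's vector separately (dim=1). A small sketch of the batched formula checked against a per-structure loop; alpha, eps, and the batch layout are illustrative, and tsm is again assumed to alias torch_sim.math.

import torch

import torch_sim.math as tsm  # assumed alias, matching the diff

velocities = torch.randn(5, 3)
forces = torch.randn(5, 3)
batch = torch.tensor([0, 0, 1, 1, 1])
alpha = torch.tensor([0.10, 0.25])  # hypothetical per-structure mixing parameters
eps = 1e-8

# Batched path: per-structure norms |V| and |F|, broadcast back to atom rows.
v_norm = torch.sqrt(tsm.batched_vdot(velocities, velocities, batch))[batch].unsqueeze(-1)
f_norm = torch.sqrt(tsm.batched_vdot(forces, forces, batch))[batch].unsqueeze(-1)
a = alpha[batch].unsqueeze(-1)
v_mixed = (1.0 - a) * velocities + a * forces * v_norm / (f_norm + eps)

# Reference path: the same formula applied one structure at a time.
for b in range(2):
    v_b, f_b = velocities[batch == b], forces[batch == b]
    ref = (1.0 - alpha[b]) * v_b + alpha[b] * f_b * v_b.norm() / (f_b.norm() + eps)
    torch.testing.assert_close(v_mixed[batch == b], ref)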
@@ -1524,12 +1538,22 @@ def _ase_fire_step( # noqa: C901, PLR0915
         dr_cell = cell_dt * state.cell_velocities

     # 6. Clamp to max_step
-    dr_norm_atom = torch.norm(dr_atom, dim=1, keepdim=True)
-    mask_atom_max_step = dr_norm_atom > max_step
-    dr_atom = torch.where(
-        mask_atom_max_step, max_step * dr_atom / (dr_norm_atom + eps), dr_atom
+    dr_atom_sum_sq_batch = tsm.batched_vdot(dr_atom, dr_atom, state.batch)
+    norm_dr_atom_per_batch = torch.sqrt(dr_atom_sum_sq_batch)  # shape [n_batches]
+
+    mask_clamp_batch = norm_dr_atom_per_batch > max_step  # shape [n_batches]
+
+    scaling_factor_batch = torch.ones_like(norm_dr_atom_per_batch)
+    safe_norm_for_clamped_batches = norm_dr_atom_per_batch[mask_clamp_batch]
+    scaling_factor_batch[mask_clamp_batch] = max_step / (
+        safe_norm_for_clamped_batches + eps
     )

+    # shape [N_atoms, 1]
+    atom_wise_scaling_factor = scaling_factor_batch[state.batch].unsqueeze(-1)
+
+    dr_atom = dr_atom * atom_wise_scaling_factor
+
     old_row_vector_cell: torch.Tensor | None = None
     if is_cell_optimization:
         assert isinstance(state, (UnitCellFireState, FrechetCellFIREState))
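Step clamping follows the same per-structure pattern: when a batch's aggregate displacement norm exceeds max_step, every atom in that batch is scaled by one common factor. A self-contained sketch of just that scaling logic, with made-up displacements and an illustrative max_step.

import torch

dr_atom = torch.randn(5, 3) * 0.3  # hypothetical per-atom displacements
batch = torch.tensor([0, 0, 1, 1, 1])
max_step, eps = 0.2, 1e-8

# Per-structure displacement norm: sqrt of the summed squares over a batch's atoms.
sum_sq = torch.zeros(2).scatter_add_(dim=0, index=batch, src=(dr_atom**2).sum(dim=1))
norm_per_batch = torch.sqrt(sum_sq)

# Scale down only the structures whose aggregate step exceeds max_step.
scale = torch.ones_like(norm_per_batch)
clamp = norm_per_batch > max_step
scale[clamp] = max_step / (norm_per_batch[clamp] + eps)
dr_atom = dr_atom * scale[batch].unsqueeze(-1)

# Invariant: after scaling, no structure's displacement norm exceeds max_step.
new_sum_sq = torch.zeros(2).scatter_add_(dim=0, index=batch, src=(dr_atom**2).sum(dim=1))
assert torch.all(torch.sqrt(new_sum_sq) <= max_step + 1e-6)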
