@@ -1481,14 +1481,28 @@ def _ase_fire_step( # noqa: C901, PLR0915
14811481
14821482 # 3. Velocity mixing BEFORE acceleration (ASE ordering)
14831483 # Atoms
1484- v_norm_atom = torch .norm (state .velocities , dim = 1 , keepdim = True )
1485- f_norm_atom = torch .norm (state .forces , dim = 1 , keepdim = True )
1486- f_unit_atom = state .forces / (f_norm_atom + eps )
1487- alpha_atom = state .alpha [state .batch ].unsqueeze (- 1 )
1488- pos_mask_atom = pos_mask_batch [state .batch ].unsqueeze (- 1 )
1489- v_new_atom = (
1490- 1.0 - alpha_atom
1491- ) * state .velocities + alpha_atom * f_unit_atom * v_norm_atom
1484+ # Velocity mixing now uses per-batch (global) norms via batched_vdot,
1485+ v_sum_sq_batch = tsm .batched_vdot (state .velocities , state .velocities , state .batch )
1486+ # sum_sq per batch, shape [n_batches]
1487+ f_sum_sq_batch = tsm .batched_vdot (state .forces , state .forces , state .batch )
1488+ # sum_sq per batch, shape [n_batches]
1489+
1490+ # Expand to per-atom for applying to vectors
1491+ # These are sqrt(sum ||v_i||^2)_batch and sqrt(sum ||f_i||^2)_batch
1492+ # Effectively |V|_batch and |F|_batch for the mixing formula
1493+ sqrt_v_sum_sq_batch_expanded = torch .sqrt (v_sum_sq_batch [state .batch ].unsqueeze (- 1 ))
1494+ sqrt_f_sum_sq_batch_expanded = torch .sqrt (f_sum_sq_batch [state .batch ].unsqueeze (- 1 ))
1495+
1496+ alpha_atom = state .alpha [state .batch ].unsqueeze (- 1 ) # per-atom alpha
1497+ pos_mask_atom = pos_mask_batch [state .batch ].unsqueeze (- 1 ) # per-atom mask
1498+
1499+ # ASE formula: v_new = (1-a)*v + a * (f / |F|_batch) * |V|_batch
1500+ # = (1-a)*v + a * f * (|V|_batch / |F|_batch)
1501+ mixing_term_atom = state .forces * (
1502+ sqrt_v_sum_sq_batch_expanded / (sqrt_f_sum_sq_batch_expanded + eps )
1503+ )
1504+
1505+ v_new_atom = (1.0 - alpha_atom ) * state .velocities + alpha_atom * mixing_term_atom
14921506 state .velocities = torch .where (
14931507 pos_mask_atom , v_new_atom , torch .zeros_like (state .velocities )
14941508 )
@@ -1524,12 +1538,22 @@ def _ase_fire_step( # noqa: C901, PLR0915
15241538 dr_cell = cell_dt * state .cell_velocities
15251539
15261540 # 6. Clamp to max_step
1527- dr_norm_atom = torch .norm (dr_atom , dim = 1 , keepdim = True )
1528- mask_atom_max_step = dr_norm_atom > max_step
1529- dr_atom = torch .where (
1530- mask_atom_max_step , max_step * dr_atom / (dr_norm_atom + eps ), dr_atom
1541+ dr_atom_sum_sq_batch = tsm .batched_vdot (dr_atom , dr_atom , state .batch )
1542+ norm_dr_atom_per_batch = torch .sqrt (dr_atom_sum_sq_batch ) # shape [n_batches]
1543+
1544+ mask_clamp_batch = norm_dr_atom_per_batch > max_step # shape [n_batches]
1545+
1546+ scaling_factor_batch = torch .ones_like (norm_dr_atom_per_batch )
1547+ safe_norm_for_clamped_batches = norm_dr_atom_per_batch [mask_clamp_batch ]
1548+ scaling_factor_batch [mask_clamp_batch ] = max_step / (
1549+ safe_norm_for_clamped_batches + eps
15311550 )
15321551
1552+ # shape [N_atoms, 1]
1553+ atom_wise_scaling_factor = scaling_factor_batch [state .batch ].unsqueeze (- 1 )
1554+
1555+ dr_atom = dr_atom * atom_wise_scaling_factor
1556+
15331557 old_row_vector_cell : torch .Tensor | None = None
15341558 if is_cell_optimization :
15351559 assert isinstance (state , (UnitCellFireState , FrechetCellFIREState ))