From 292dcbf796611178c6694d6b0b8e95d51483afe6 Mon Sep 17 00:00:00 2001 From: Aaryan-549 Date: Tue, 18 Nov 2025 00:27:20 +0530 Subject: [PATCH] Fix argument order in scale_by_distance_over_gradients The helper function _tx is defined with parameters (g, d, g_sos), but jax.tree.map was calling it with (max_dist, g_sos, updates). This caused _tx to interpret the accumulated gradient sum-of-squares as d and the raw updates (which can be negative) as g_sos. When the raw gradient had any negative entry, the code would execute jnp.sqrt(g_sos + eps) with a negative argument and produce NaN. This fix corrects the argument order to (updates, max_dist, g_sos) to match the function signature. --- optax/_src/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optax/_src/transform.py b/optax/_src/transform.py index 857965f3b..c5b7488fc 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -1406,7 +1406,7 @@ def _tx(g, d, g_sos): eta = global_scale * (d / jnp.sqrt(g_sos + eps)) return eta * g - updates = jax.tree.map(_tx, max_dist, g_sos, updates) + updates = jax.tree.map(_tx, updates, max_dist, g_sos) # new state state = ScaleByDistanceOverGradientsState(