Skip to content

Conversation

@Aaryan-549
Copy link

The helper function _tx is defined with parameters (g, d, g_sos), but jax.tree.map was calling it with (max_dist, g_sos, updates). This caused _tx to interpret the accumulated gradient sum-of-squares as d and the raw updates (which can be negative) as g_sos. When the raw gradient had any negative entry, the code would execute jnp.sqrt(g_sos + eps) with a negative argument and produce NaN.

This fix corrects the argument order to (updates, max_dist, g_sos) to match the function signature.

The helper function _tx is defined with parameters (g, d, g_sos), but
jax.tree.map was calling it with (max_dist, g_sos, updates). This caused
_tx to interpret the accumulated gradient sum-of-squares as d and the
raw updates (which can be negative) as g_sos. When the raw gradient had
any negative entry, the code would execute jnp.sqrt(g_sos + eps) with a
negative argument and produce NaN.

This fix corrects the argument order to (updates, max_dist, g_sos) to
match the function signature.
@Aaryan-549 Aaryan-549 force-pushed the fix-scale-by-distance-arg-order branch from 9caf80a to 292dcbf Compare November 18, 2025 12:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant