Fix argument order in scale_by_distance_over_gradients #1501
The helper function _tx is defined with parameters (g, d, g_sos), but jax.tree.map called it with (max_dist, g_sos, updates). As a result, _tx treated max_dist as g, the accumulated gradient sum-of-squares as d, and the raw updates (which can be negative) as g_sos. Whenever a raw gradient entry was negative by more than eps, jnp.sqrt(g_sos + eps) received a negative argument and produced NaN.
This fix corrects the argument order to (updates, max_dist, g_sos) to match the function signature.
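
The mismatch can be reproduced with a small sketch. The body of _tx below is a simplified stand-in, not the library's actual implementation, and EPS and the sample values are assumptions chosen for illustration; only the parameter order (g, d, g_sos) and the two call orders come from the description above.

```python
import jax
import jax.numpy as jnp

EPS = 1e-8  # assumed small constant, for illustration only


def _tx(g, d, g_sos):
    # Simplified stand-in: scale gradient g by distance d over the
    # square root of the accumulated squared gradients g_sos.
    return d * g / jnp.sqrt(g_sos + EPS)


# Hypothetical single-leaf pytrees; one gradient entry is negative.
updates = {"w": jnp.array([0.5, -2.0])}
max_dist = {"w": jnp.array(1.0)}
g_sos = {"w": jnp.array([4.0, 4.0])}

# Old call order: updates lands in the g_sos slot, so sqrt sees -2.0 + EPS.
wrong = jax.tree.map(_tx, max_dist, g_sos, updates)

# Corrected call order, matching the (g, d, g_sos) signature.
right = jax.tree.map(_tx, updates, max_dist, g_sos)

print("old order:  ", wrong["w"])
print("fixed order:", right["w"])
```

With the old order, the entry that corresponds to the negative gradient comes out as NaN, while the corrected order yields finite values for every entry.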