2 changes: 1 addition & 1 deletion docs/api/utilities.rst
@@ -61,7 +61,7 @@ Power iteration
.. autofunction:: power_iteration

Non-negative least squares
~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: nnls


3 changes: 1 addition & 2 deletions examples/perturbations.ipynb
@@ -41,7 +41,6 @@
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import operator\n",
"from jax import tree_util as jtu\n",
"\n",
"import optax.tree\n",
@@ -773,7 +772,7 @@
" pert_softmax = pert_argmax_fun(rng, inputs)\n",
" argmax = argmax_tree(inputs)\n",
" diffs = jax.tree.map(lambda x, y: jnp.sum((x - y) ** 2 / 4), argmax, pert_softmax)\n",
" return jax.tree.reduce(operator.add, diffs)"
" return optax.tree.sum(diffs)"
]
},
{
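A minimal sketch of the equivalence this notebook change relies on, assuming an optax version recent enough to expose the optax.tree namespace used in this diff; the example tree is made up for illustration. Reducing scalar leaves with operator.add and calling optax.tree.sum give the same result:

import operator
import jax
import jax.numpy as jnp
import optax.tree

# A pytree of scalar leaves, standing in for the per-leaf squared differences.
diffs = {'a': jnp.asarray(1.0), 'b': (jnp.asarray(2.0), jnp.asarray(3.0))}
old = jax.tree.reduce(operator.add, diffs)  # 6.0, the pattern being replaced
new = optax.tree.sum(diffs)                 # 6.0, the replacement
assert jnp.allclose(old, new)
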
6 changes: 3 additions & 3 deletions optax/_src/utils_test.py
@@ -26,6 +26,7 @@
from optax._src import transform
from optax._src import update
from optax._src import utils
import optax.tree


def _shape_to_tuple(shape):
@@ -40,8 +41,7 @@ class ScaleGradientTest(parameterized.TestCase):
def test_scale_gradient_pytree(self, scale):
def fn(inputs):
outputs = utils.scale_gradient(inputs, scale)
outputs = jax.tree.map(lambda x: x**2, outputs)
return sum(jax.tree.leaves(outputs))
return optax.tree.norm(outputs, squared=True)

inputs = {'a': -1.0, 'b': {'c': (2.0,), 'd': 0.0}}

@@ -50,7 +50,7 @@ def fn(inputs):
jax.tree.map(lambda i, g: self.assertEqual(g, 2 * i * scale), inputs, grads)
self.assertEqual(
fn(inputs),
sum(jax.tree.leaves(jax.tree.map(lambda x: x**2, inputs))),
optax.tree.norm(inputs, squared=True),
)


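The test now leans on a single identity, sketched below under the assumption that optax.tree.norm(tree, squared=True) returns the sum of squared entries across all leaves; the values mirror the test's inputs:

import jax
import jax.numpy as jnp
import optax.tree

inputs = {'a': -1.0, 'b': {'c': (2.0,), 'd': 0.0}}
manual = sum(jax.tree.leaves(jax.tree.map(lambda x: x**2, inputs)))  # 5.0
via_norm = optax.tree.norm(inputs, squared=True)
assert jnp.allclose(manual, via_norm)
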
9 changes: 3 additions & 6 deletions optax/contrib/_privacy_test.py
@@ -20,6 +20,7 @@
import jax
import jax.numpy as jnp
from optax.contrib import _privacy
import optax.tree


class DifferentiallyPrivateAggregateTest(chex.TestCase):
@@ -64,12 +65,8 @@ def test_clipping_norm(self, l2_norm_clip):
state = dp_agg.init(self.params)
update_fn = self.variant(dp_agg.update)

# Shape of the three arrays below is (self.batch_size, )
norms = [
jnp.linalg.norm(g.reshape(self.batch_size, -1), axis=1)
for g in jax.tree.leaves(self.per_eg_grads)
]
global_norms = jnp.linalg.norm(jnp.stack(norms), axis=0)
global_norms = jax.vmap(optax.tree.norm)(self.per_eg_grads)

divisors = jnp.maximum(global_norms / l2_norm_clip, 1.0)
# Since the values of all the parameters are the same within each example,
# we can easily compute what the values should be:
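A sketch of the per-example norm computation the test now uses, assuming the per-example gradients are stored as a pytree whose leaves share a leading batch dimension; the shapes here are illustrative:

import jax
import jax.numpy as jnp
import optax.tree

batch_size = 4
per_eg_grads = {'w': jnp.ones((batch_size, 3, 2)), 'b': jnp.ones((batch_size, 2))}

# vmap maps optax.tree.norm over the batch axis: one global l2 norm per example.
global_norms = jax.vmap(optax.tree.norm)(per_eg_grads)  # shape (batch_size,)

# Same quantity computed the old way, leaf by leaf.
norms = [jnp.linalg.norm(g.reshape(batch_size, -1), axis=1)
         for g in jax.tree.leaves(per_eg_grads)]
manual = jnp.linalg.norm(jnp.stack(norms), axis=0)
assert jnp.allclose(global_norms, manual)
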
6 changes: 2 additions & 4 deletions optax/contrib/_sophia.py
@@ -157,10 +157,8 @@ def update_fn(updates, state: SophiaState, params=None, **hess_fn_kwargs):
lambda m, h: m / jnp.maximum(gamma * h, eps), mu_hat, state.nu
)
if clip_threshold is not None:
sum_not_clipped = jax.tree.reduce(
lambda x, y: x + y,
jax.tree.map(lambda u: jnp.sum(jnp.abs(u) < clip_threshold), updates),
)
not_clipped = jax.tree.map(lambda u: jnp.abs(u) < clip_threshold, updates)
sum_not_clipped = optax.tree.sum(not_clipped)
if verbose:
win_rate = sum_not_clipped / optax.tree.size(updates)
jax.lax.cond(
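The Sophia change counts un-clipped entries by summing a boolean pytree. A small sketch of that counting trick, assuming jnp.sum promotes booleans to integers so optax.tree.sum of the mask counts its True entries; the values are illustrative:

import jax
import jax.numpy as jnp
import optax.tree

clip_threshold = 1.0
updates = {'w': jnp.array([0.5, -2.0, 0.1]), 'b': jnp.array([3.0])}
not_clipped = jax.tree.map(lambda u: jnp.abs(u) < clip_threshold, updates)
# Two of the four entries are below the threshold.
win_rate = optax.tree.sum(not_clipped) / optax.tree.size(updates)  # 2 / 4 = 0.5
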
3 changes: 1 addition & 2 deletions optax/perturbations/_make_pert_test.py
@@ -16,7 +16,6 @@
"""Tests for optax.perturbations, checking values and gradients."""

from functools import partial # pylint: disable=g-importing-member
import operator

from absl.testing import absltest
from absl.testing import parameterized
@@ -159,7 +158,7 @@ def loss(tree):
pred = apply_element_tree(tree)
pred_true = apply_element_tree(example_tree)
tree_loss = jax.tree.map(lambda x, y: (x - y) ** 2, pred, pred_true)
list_loss = jax.tree.reduce(operator.add, tree_loss)
list_loss = optax.tree.sum(tree_loss)
return jax.tree.map(lambda *leaves: sum(leaves) / len(leaves), list_loss)

loss_pert = jax.jit(_make_pert.make_perturbed_fun(
11 changes: 3 additions & 8 deletions optax/transforms/_accumulation.py
@@ -176,11 +176,8 @@ def skip_not_finite(
- `num_not_finite`: total number of inf and NaN found in `updates`.
"""
del gradient_step, params
all_is_finite = [
jnp.sum(jnp.logical_not(jnp.isfinite(p)))
for p in jax.tree.leaves(updates)
]
num_not_finite = jnp.sum(jnp.array(all_is_finite))
not_finite = jax.tree.map(lambda x: ~jnp.isfinite(x), updates)
num_not_finite = optax.tree.sum(not_finite)
should_skip = num_not_finite > 0
return should_skip, {
'should_skip': should_skip,
@@ -210,9 +207,7 @@ def skip_large_updates(
- `norm_squared`: overall norm square of the `updates`.
"""
del gradient_step, params
norm_sq = jnp.sum(
jnp.array([jnp.sum(p**2) for p in jax.tree.leaves(updates)])
)
norm_sq = optax.tree.norm(updates, squared=True)
# This will also return True if `norm_sq` is NaN.
should_skip = jnp.logical_not(norm_sq < max_squared_norm)
return should_skip, {'should_skip': should_skip, 'norm_squared': norm_sq}
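Both skip predicates now reduce the whole pytree in one call. A sketch of the non-finite check, under the same assumption that summing a boolean pytree with optax.tree.sum counts its True entries; the updates tree is illustrative:

import jax
import jax.numpy as jnp
import optax.tree

updates = {'w': jnp.array([1.0, jnp.inf]), 'b': jnp.array([jnp.nan, 0.0])}
not_finite = jax.tree.map(lambda x: ~jnp.isfinite(x), updates)
num_not_finite = optax.tree.sum(not_finite)  # 2: one inf and one NaN
should_skip = num_not_finite > 0             # True
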
6 changes: 2 additions & 4 deletions optax/tree_utils/_tree_math.py
@@ -151,7 +151,7 @@ def tree_vdot(tree_x: Any, tree_y: Any) -> chex.Numeric:
numerical issues.
"""
vdots = jax.tree.map(_vdot_safe, tree_x, tree_y)
return jax.tree.reduce(operator.add, vdots, initializer=0)
return tree_sum(vdots)


def tree_sum(tree: Any) -> chex.Numeric:
@@ -450,6 +450,4 @@ def tree_allclose(
def f(a, b):
return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
tree = jax.tree.map(f, a, b)
leaves = jax.tree.leaves(tree)
result = functools.reduce(operator.and_, leaves, True)
return result
return jax.tree.reduce(operator.and_, tree, True)
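
A sketch of the reduction tree_allclose now performs, assuming jax.tree.reduce folds operator.and_ over the per-leaf allclose results with True as the initializer (so an empty tree compares as close); the trees are illustrative:

import operator
import jax
import jax.numpy as jnp

a = {'x': jnp.array([1.0, 2.0]), 'y': jnp.array(3.0)}
b = {'x': jnp.array([1.0, 2.0]), 'y': jnp.array(3.0)}
per_leaf = jax.tree.map(lambda u, v: jnp.allclose(u, v), a, b)
assert bool(jax.tree.reduce(operator.and_, per_leaf, True))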