diff --git a/docs/api/utilities.rst b/docs/api/utilities.rst
index d1520df9c..e10b5d357 100644
--- a/docs/api/utilities.rst
+++ b/docs/api/utilities.rst
@@ -61,7 +61,7 @@ Power iteration
 .. autofunction:: power_iteration
 
 Non-negative least squares
-~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autofunction:: nnls
 
diff --git a/examples/perturbations.ipynb b/examples/perturbations.ipynb
index f484f3ee9..166ed95a6 100644
--- a/examples/perturbations.ipynb
+++ b/examples/perturbations.ipynb
@@ -41,7 +41,6 @@
    "source": [
     "import jax\n",
     "import jax.numpy as jnp\n",
-    "import operator\n",
     "from jax import tree_util as jtu\n",
     "\n",
     "import optax.tree\n",
@@ -773,7 +772,7 @@
     "  pert_softmax = pert_argmax_fun(rng, inputs)\n",
     "  argmax = argmax_tree(inputs)\n",
     "  diffs = jax.tree.map(lambda x, y: jnp.sum((x - y) ** 2 / 4), argmax, pert_softmax)\n",
-    "  return jax.tree.reduce(operator.add, diffs)"
+    "  return optax.tree.sum(diffs)"
    ]
   },
   {
diff --git a/optax/_src/utils_test.py b/optax/_src/utils_test.py
index 4fe6b57e3..7c3729d6f 100644
--- a/optax/_src/utils_test.py
+++ b/optax/_src/utils_test.py
@@ -26,6 +26,7 @@
 from optax._src import transform
 from optax._src import update
 from optax._src import utils
+import optax.tree
 
 
 def _shape_to_tuple(shape):
@@ -40,8 +41,7 @@ class ScaleGradientTest(parameterized.TestCase):
   def test_scale_gradient_pytree(self, scale):
     def fn(inputs):
       outputs = utils.scale_gradient(inputs, scale)
-      outputs = jax.tree.map(lambda x: x**2, outputs)
-      return sum(jax.tree.leaves(outputs))
+      return optax.tree.norm(outputs, squared=True)
 
     inputs = {'a': -1.0, 'b': {'c': (2.0,), 'd': 0.0}}
 
@@ -50,7 +50,7 @@ def fn(inputs):
     jax.tree.map(lambda i, g: self.assertEqual(g, 2 * i * scale), inputs, grads)
     self.assertEqual(
         fn(inputs),
-        sum(jax.tree.leaves(jax.tree.map(lambda x: x**2, inputs))),
+        optax.tree.norm(inputs, squared=True),
     )
 
 
diff --git a/optax/contrib/_privacy_test.py b/optax/contrib/_privacy_test.py
index d8d7381e4..6b6c9b127 100644
--- a/optax/contrib/_privacy_test.py
+++ b/optax/contrib/_privacy_test.py
@@ -20,6 +20,7 @@
 import jax
 import jax.numpy as jnp
 from optax.contrib import _privacy
+import optax.tree
 
 
 class DifferentiallyPrivateAggregateTest(chex.TestCase):
@@ -64,12 +65,8 @@ def test_clipping_norm(self, l2_norm_clip):
     state = dp_agg.init(self.params)
     update_fn = self.variant(dp_agg.update)
 
-    # Shape of the three arrays below is (self.batch_size, )
-    norms = [
-        jnp.linalg.norm(g.reshape(self.batch_size, -1), axis=1)
-        for g in jax.tree.leaves(self.per_eg_grads)
-    ]
-    global_norms = jnp.linalg.norm(jnp.stack(norms), axis=0)
+    global_norms = jax.vmap(optax.tree.norm)(self.per_eg_grads)
+    divisors = jnp.maximum(global_norms / l2_norm_clip, 1.0)
 
     # Since the values of all the parameters are the same within each example,
     # we can easily compute what the values should be:
diff --git a/optax/contrib/_sophia.py b/optax/contrib/_sophia.py
index dbdb4b135..14459b0a8 100644
--- a/optax/contrib/_sophia.py
+++ b/optax/contrib/_sophia.py
@@ -157,10 +157,8 @@ def update_fn(updates, state: SophiaState, params=None, **hess_fn_kwargs):
         lambda m, h: m / jnp.maximum(gamma * h, eps), mu_hat, state.nu
     )
     if clip_threshold is not None:
-      sum_not_clipped = jax.tree.reduce(
-          lambda x, y: x + y,
-          jax.tree.map(lambda u: jnp.sum(jnp.abs(u) < clip_threshold), updates),
-      )
+      not_clipped = jax.tree.map(lambda u: jnp.abs(u) < clip_threshold, updates)
+      sum_not_clipped = optax.tree.sum(not_clipped)
       if verbose:
         win_rate = sum_not_clipped / optax.tree.size(updates)
         jax.lax.cond(
diff --git a/optax/perturbations/_make_pert_test.py b/optax/perturbations/_make_pert_test.py
index 309ba6a7b..d7b6c5bbe 100644
--- a/optax/perturbations/_make_pert_test.py
+++ b/optax/perturbations/_make_pert_test.py
@@ -16,7 +16,6 @@
 """Tests for optax.perturbations, checking values and gradients."""
 
 from functools import partial  # pylint: disable=g-importing-member
-import operator
 
 from absl.testing import absltest
 from absl.testing import parameterized
@@ -159,7 +158,7 @@ def loss(tree):
       pred = apply_element_tree(tree)
       pred_true = apply_element_tree(example_tree)
       tree_loss = jax.tree.map(lambda x, y: (x - y) ** 2, pred, pred_true)
-      list_loss = jax.tree.reduce(operator.add, tree_loss)
+      list_loss = optax.tree.sum(tree_loss)
       return jax.tree.map(lambda *leaves: sum(leaves) / len(leaves), list_loss)
 
     loss_pert = jax.jit(_make_pert.make_perturbed_fun(
diff --git a/optax/transforms/_accumulation.py b/optax/transforms/_accumulation.py
index 4bc816ebb..7779484f7 100644
--- a/optax/transforms/_accumulation.py
+++ b/optax/transforms/_accumulation.py
@@ -176,11 +176,8 @@ def skip_not_finite(
     - `num_not_finite`: total number of inf and NaN found in `updates`.
   """
   del gradient_step, params
-  all_is_finite = [
-      jnp.sum(jnp.logical_not(jnp.isfinite(p)))
-      for p in jax.tree.leaves(updates)
-  ]
-  num_not_finite = jnp.sum(jnp.array(all_is_finite))
+  not_finite = jax.tree.map(lambda x: ~jnp.isfinite(x), updates)
+  num_not_finite = optax.tree.sum(not_finite)
   should_skip = num_not_finite > 0
   return should_skip, {
       'should_skip': should_skip,
@@ -210,9 +207,7 @@ def skip_large_updates(
     - `norm_squared`: overall norm square of the `updates`.
   """
   del gradient_step, params
-  norm_sq = jnp.sum(
-      jnp.array([jnp.sum(p**2) for p in jax.tree.leaves(updates)])
-  )
+  norm_sq = optax.tree.norm(updates, squared=True)
   # This will also return True if `norm_sq` is NaN.
   should_skip = jnp.logical_not(norm_sq < max_squared_norm)
   return should_skip, {'should_skip': should_skip, 'norm_squared': norm_sq}
diff --git a/optax/tree_utils/_tree_math.py b/optax/tree_utils/_tree_math.py
index 9387a0976..73e3d6e03 100644
--- a/optax/tree_utils/_tree_math.py
+++ b/optax/tree_utils/_tree_math.py
@@ -151,7 +151,7 @@ def tree_vdot(tree_x: Any, tree_y: Any) -> chex.Numeric:
     numerical issues.
   """
   vdots = jax.tree.map(_vdot_safe, tree_x, tree_y)
-  return jax.tree.reduce(operator.add, vdots, initializer=0)
+  return tree_sum(vdots)
 
 
 def tree_sum(tree: Any) -> chex.Numeric:
@@ -450,6 +450,4 @@ def tree_allclose(
   def f(a, b):
     return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
   tree = jax.tree.map(f, a, b)
-  leaves = jax.tree.leaves(tree)
-  result = functools.reduce(operator.and_, leaves, True)
-  return result
+  return jax.tree.reduce(operator.and_, tree, True)
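Reviewer note (not part of the patch): the rewrites above all lean on the same equivalence between the removed hand-rolled reductions and the optax.tree helpers already used in the diff (optax.tree.sum, optax.tree.norm). Below is a minimal standalone sketch checking that equivalence; the example pytree values are arbitrary and only the helpers named in the patch are assumed to exist.

import operator

import jax
import jax.numpy as jnp
import optax.tree

# An arbitrary example pytree, not taken from the codebase.
tree = {'a': jnp.array([-1.0, 2.0]), 'b': (jnp.array([3.0]),)}

# optax.tree.sum adds up every entry of every leaf, i.e. the
# jax.tree.reduce(operator.add, jax.tree.map(jnp.sum, ...)) pattern
# removed in _sophia.py, _accumulation.py and the perturbation tests.
manual_sum = jax.tree.reduce(operator.add, jax.tree.map(jnp.sum, tree))
assert jnp.allclose(optax.tree.sum(tree), manual_sum)

# optax.tree.norm(..., squared=True) is the squared global l2 norm,
# i.e. the per-leaf sum-of-squares pattern removed in utils_test.py
# and skip_large_updates.
manual_sq_norm = sum(jnp.sum(x**2) for x in jax.tree.leaves(tree))
assert jnp.allclose(optax.tree.norm(tree, squared=True), manual_sq_norm)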