diff --git a/optax/_src/linear_algebra.py b/optax/_src/linear_algebra.py index 944b78f77..e984af0fa 100644 --- a/optax/_src/linear_algebra.py +++ b/optax/_src/linear_algebra.py @@ -265,7 +265,7 @@ def _iter_body(state): _iter_condition, _iter_body, init_state ) error = jnp.max(jnp.abs(mat_m - identity)) - is_converged = jnp.asarray(convergence, old_mat_h.dtype) + is_converged = jnp.asarray(convergence, old_mat_h.dtype) # pytype: disable=attribute-error # lax-types resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h resultant_mat_h = jnp.asarray(resultant_mat_h, matrix.dtype) return resultant_mat_h, error