From fda0234e6fccf8eb328167889130368f08acc120 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 16 Jun 2025 08:31:31 -0700 Subject: [PATCH] Silence some pytype errors related to a JAX build refactor This build change allows pytype to propagate annotations that it previously did not, and because of this it starts flagging existing incorrect annotations. PiperOrigin-RevId: 772038309 --- optax/_src/linear_algebra.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optax/_src/linear_algebra.py b/optax/_src/linear_algebra.py index 944b78f77..e984af0fa 100644 --- a/optax/_src/linear_algebra.py +++ b/optax/_src/linear_algebra.py @@ -265,7 +265,7 @@ def _iter_body(state): _iter_condition, _iter_body, init_state ) error = jnp.max(jnp.abs(mat_m - identity)) - is_converged = jnp.asarray(convergence, old_mat_h.dtype) + is_converged = jnp.asarray(convergence, old_mat_h.dtype) # pytype: disable=attribute-error # lax-types resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h resultant_mat_h = jnp.asarray(resultant_mat_h, matrix.dtype) return resultant_mat_h, error