diff --git a/optax/schedules/_inject.py b/optax/schedules/_inject.py index 612b43228..7b27fe5cb 100644 --- a/optax/schedules/_inject.py +++ b/optax/schedules/_inject.py @@ -152,9 +152,7 @@ def wrapped_transform( def init_fn(params): count = jnp.zeros([], jnp.int32) if hyperparam_dtype is None: - dtype = getattr( - next(iter(jax.tree.leaves(params)), None), 'dtype', None - ) + dtype = jnp.float32 else: dtype = hyperparam_dtype hparams = { @@ -175,9 +173,7 @@ def init_fn(params): def update_fn(updates, state, params=None, **extra_args): if hyperparam_dtype is None: - dtype = getattr( - next(iter(jax.tree.leaves(updates)), None), 'dtype', None - ) + dtype = jnp.float32 else: dtype = hyperparam_dtype hparams = {