Skip to content

Conversation

@Aaryan-549
Copy link

Fixes #1500

The momo and momo_adam update functions declared the value parameter as Optional[jax.Array] but never enforced that type. When calling value.astype(state.barf.dtype), passing a Python float or NumPy scalar would raise AttributeError since these types don't have an astype method.

This fix uses jnp.asarray(value, dtype=state.barf.dtype) instead of value.astype(state.barf.dtype) to handle Python floats, NumPy scalars, and JAX arrays uniformly.

@rdyro
Copy link
Collaborator

rdyro commented Nov 17, 2025

Thanks! Can you adjust the type annotation to reflect the value requirement? Should probably drop Optional and use jax.typing.ArrayLike

@Aaryan-549 Aaryan-549 force-pushed the fix-momo-python-float-1500 branch from d181d3a to cf081e8 Compare November 18, 2025 04:31
@Aaryan-549
Copy link
Author

@rdyro Thanks for the review! I've made the suggested changes and updated the type annotation to use jax.typing.ArrayLike. Ready for another look.

Fixes google-deepmind#1500

The momo and momo_adam update functions declared the value parameter as
Optional[jax.Array] but never enforced that type. When calling
value.astype(state.barf.dtype), passing a Python float or NumPy scalar
would raise AttributeError since these types don't have an astype method.

This fix uses jnp.asarray(value, dtype=state.barf.dtype) instead of
value.astype(state.barf.dtype) to handle Python floats, NumPy scalars,
and JAX arrays uniformly.
@Aaryan-549 Aaryan-549 force-pushed the fix-momo-python-float-1500 branch from cf081e8 to 3482972 Compare November 18, 2025 12:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

contrib.momo crashes when loss value is a Python float

2 participants