diff --git a/optax/losses/_classification.py b/optax/losses/_classification.py
index d7917ebe4..ed30a20c5 100644
--- a/optax/losses/_classification.py
+++ b/optax/losses/_classification.py
@@ -466,10 +466,10 @@ def multiclass_perceptron_loss(
   return jnp.max(scores, axis=-1) - _dot_last_dim(scores, one_hot_labels)


-@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
 def poly_loss_cross_entropy(
     logits: jax.typing.ArrayLike,
     labels: jax.typing.ArrayLike,
+    *,
     epsilon: float = 2.0,
     axis: Union[int, tuple[int, ...], None] = -1,
     where: Union[jax.typing.ArrayLike, None] = None,
@@ -631,12 +631,12 @@ def convex_kl_divergence(
   return x + y


-@functools.partial(chex.warn_only_n_pos_args_in_future, n=4)
 def ctc_loss_with_forward_probs(
     logits: jax.typing.ArrayLike,
     logit_paddings: jax.typing.ArrayLike,
     labels: jax.typing.ArrayLike,
     label_paddings: jax.typing.ArrayLike,
+    *,
     blank_id: int = 0,
     log_epsilon: float = -1e5,
 ) -> tuple[jax.Array, jax.Array, jax.Array]:
@@ -772,12 +772,12 @@ def loop_body(prev, x):
   return per_seq_loss, logalpha_phi, logalpha_emit


-@functools.partial(chex.warn_only_n_pos_args_in_future, n=4)
 def ctc_loss(
     logits: jax.typing.ArrayLike,
     logit_paddings: jax.typing.ArrayLike,
     labels: jax.typing.ArrayLike,
     label_paddings: jax.typing.ArrayLike,
+    *,
     blank_id: int = 0,
     log_epsilon: float = -1e5,
 ) -> jax.Array:
@@ -818,10 +818,10 @@ def ctc_loss(
   return per_seq_loss


-@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
 def sigmoid_focal_loss(
     logits: jax.typing.ArrayLike,
     labels: jax.typing.ArrayLike,
+    *,
     alpha: Optional[float] = None,
     gamma: float = 2.0,
 ) -> jax.Array:
diff --git a/optax/losses/_regression.py b/optax/losses/_regression.py
index 6762be45a..6952429e0 100644
--- a/optax/losses/_regression.py
+++ b/optax/losses/_regression.py
@@ -14,10 +14,8 @@
 # ==============================================================================
 """Regression losses."""

-import functools
 from typing import Optional, Union

-import chex
 import jax
 import jax.numpy as jnp
 from optax._src import utils
@@ -74,10 +72,10 @@ def l2_loss(
   return 0.5 * squared_error(predictions, targets)


-@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
 def huber_loss(
     predictions: jax.typing.ArrayLike,
     targets: Optional[jax.typing.ArrayLike] = None,
+    *,
     delta: float = 1.0,
 ) -> jax.Array:
   """Huber loss, similar to L2 loss close to zero, L1 loss away from zero.
@@ -135,10 +133,10 @@ def log_cosh(
   return jnp.logaddexp(errors, -errors) - jnp.log(2.0).astype(errors.dtype)


-@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
 def cosine_similarity(
     predictions: jax.typing.ArrayLike,
     targets: jax.typing.ArrayLike,
+    *,
     epsilon: float = 0.0,
     axis: Union[int, tuple[int, ...], None] = -1,
     where: Union[jax.typing.ArrayLike, None] = None,
@@ -185,10 +183,10 @@ def cosine_distance(
   return (a_unit * b_unit).sum(axis=axis, where=where)


-@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
 def cosine_distance(
     predictions: jax.typing.ArrayLike,
     targets: jax.typing.ArrayLike,
+    *,
     epsilon: float = 0.0,
     axis: Union[int, tuple[int, ...], None] = -1,
     where: Union[jax.typing.ArrayLike, None] = None,
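Reviewer note, with an illustrative call-site sketch (the snippet below is not part of the patch): the removed chex.warn_only_n_pos_args_in_future shim only emitted a deprecation warning when trailing arguments were passed positionally, whereas the bare `*` now makes those arguments strictly keyword-only, so the same calls fail loudly instead.

import jax.numpy as jnp
import optax

predictions = jnp.array([0.1, -0.4])
targets = jnp.array([0.2, 0.0])

# Keyword form: valid both before and after this change.
loss = optax.losses.huber_loss(predictions, targets, delta=0.5)

# Positional form: previously allowed with a deprecation warning by the
# chex.warn_only_n_pos_args_in_future wrapper (n=2 positional arguments
# tolerated); with the bare `*` in the signature it raises a TypeError.
loss = optax.losses.huber_loss(predictions, targets, 0.5)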