Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions optax/losses/_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,10 +466,10 @@ def multiclass_perceptron_loss(
return jnp.max(scores, axis=-1) - _dot_last_dim(scores, one_hot_labels)


@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
def poly_loss_cross_entropy(
logits: jax.typing.ArrayLike,
labels: jax.typing.ArrayLike,
*,
epsilon: float = 2.0,
axis: Union[int, tuple[int, ...], None] = -1,
where: Union[jax.typing.ArrayLike, None] = None,
Expand Down Expand Up @@ -631,12 +631,12 @@ def convex_kl_divergence(
return x + y


@functools.partial(chex.warn_only_n_pos_args_in_future, n=4)
def ctc_loss_with_forward_probs(
logits: jax.typing.ArrayLike,
logit_paddings: jax.typing.ArrayLike,
labels: jax.typing.ArrayLike,
label_paddings: jax.typing.ArrayLike,
*,
blank_id: int = 0,
log_epsilon: float = -1e5,
) -> tuple[jax.Array, jax.Array, jax.Array]:
Expand Down Expand Up @@ -772,12 +772,12 @@ def loop_body(prev, x):
return per_seq_loss, logalpha_phi, logalpha_emit


@functools.partial(chex.warn_only_n_pos_args_in_future, n=4)
def ctc_loss(
logits: jax.typing.ArrayLike,
logit_paddings: jax.typing.ArrayLike,
labels: jax.typing.ArrayLike,
label_paddings: jax.typing.ArrayLike,
*,
blank_id: int = 0,
log_epsilon: float = -1e5,
) -> jax.Array:
Expand Down Expand Up @@ -818,10 +818,10 @@ def ctc_loss(
return per_seq_loss


@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
def sigmoid_focal_loss(
logits: jax.typing.ArrayLike,
labels: jax.typing.ArrayLike,
*,
alpha: Optional[float] = None,
gamma: float = 2.0,
) -> jax.Array:
Expand Down
8 changes: 3 additions & 5 deletions optax/losses/_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@
# ==============================================================================
"""Regression losses."""

import functools
from typing import Optional, Union

import chex
import jax
import jax.numpy as jnp
from optax._src import utils
Expand Down Expand Up @@ -74,10 +72,10 @@ def l2_loss(
return 0.5 * squared_error(predictions, targets)


@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
def huber_loss(
predictions: jax.typing.ArrayLike,
targets: Optional[jax.typing.ArrayLike] = None,
*,
delta: float = 1.0,
) -> jax.Array:
"""Huber loss, similar to L2 loss close to zero, L1 loss away from zero.
Expand Down Expand Up @@ -135,10 +133,10 @@ def log_cosh(
return jnp.logaddexp(errors, -errors) - jnp.log(2.0).astype(errors.dtype)


@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
def cosine_similarity(
predictions: jax.typing.ArrayLike,
targets: jax.typing.ArrayLike,
*,
epsilon: float = 0.0,
axis: Union[int, tuple[int, ...], None] = -1,
where: Union[jax.typing.ArrayLike, None] = None,
Expand Down Expand Up @@ -185,10 +183,10 @@ def cosine_similarity(
return (a_unit * b_unit).sum(axis=axis, where=where)


@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
def cosine_distance(
predictions: jax.typing.ArrayLike,
targets: jax.typing.ArrayLike,
*,
epsilon: float = 0.0,
axis: Union[int, tuple[int, ...], None] = -1,
where: Union[jax.typing.ArrayLike, None] = None,
Expand Down
Loading