Skip to content

Add ListMLE Loss #130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions keras_rs/api/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
since your modifications would be overwritten.
"""

from keras_rs.src.losses.list_mle_loss import (
ListMLELoss as ListMLELoss,
)

from keras_rs.src.losses.pairwise_hinge_loss import (
PairwiseHingeLoss as PairwiseHingeLoss,
)
Expand Down
194 changes: 194 additions & 0 deletions keras_rs/src/losses/list_mle_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
from typing import Any

import keras
from keras import ops

from keras_rs.src import types
from keras_rs.src.metrics.utils import standardize_call_inputs_ranks
from keras_rs.src.api_export import keras_rs_export
from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores


@keras_rs_export("keras_rs.losses.ListMLELoss")
class ListMLELoss(keras.losses.Loss):
"""Implements ListMLE (Maximum Likelihood Estimation) loss for ranking.

ListMLE loss is a listwise ranking loss that maximizes the likelihood of
the ground truth ranking. It works by:
1. Sorting items by their relevance scores (labels)
2. Computing the probability of observing this ranking given the
predicted scores
3. Maximizing this likelihood (minimizing negative log-likelihood)

The loss is computed as the negative log-likelihood of the ground truth
ranking given the predicted scores:

```
loss = -sum(log(exp(s_i) / sum(exp(s_j) for j >= i)))
```

where s_i is the predicted score for item i in the sorted order.

Args:
temperature: Temperature parameter for scaling logits. Higher values
make the probability distribution more uniform. Defaults to 1.0.
reduction: Type of reduction to apply to the loss. In almost all cases
this should be `"sum_over_batch_size"`. Supported options are
`"sum"`, `"sum_over_batch_size"`, `"mean"`,
`"mean_with_sample_weight"` or `None`. Defaults to
`"sum_over_batch_size"`.
name: Optional name for the loss instance.
dtype: The dtype of the loss's computations. Defaults to `None`.

Examples:
```python
# Basic usage
loss_fn = ListMLELoss()

# With temperature scaling
loss_fn = ListMLELoss(temperature=0.5)

# Example with synthetic data
y_true = [[3, 2, 1, 0]] # Relevance scores
y_pred = [[0.8, 0.6, 0.4, 0.2]] # Predicted scores
loss = loss_fn(y_true, y_pred)
```
"""

def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None:
    """Create the loss instance and validate `temperature`.

    Args:
        temperature: Positive scaling factor applied to logits before
            the likelihood computation.
        **kwargs: Forwarded to `keras.losses.Loss` (e.g. `reduction`,
            `name`, `dtype`).

    Raises:
        ValueError: If `temperature` is not strictly positive.
    """
    super().__init__(**kwargs)

    if temperature <= 0.0:
        message = (
            f"`temperature` should be a positive float. Received: "
            f"`temperature` = {temperature}."
        )
        raise ValueError(message)

    # Divisor for the logits; larger values flatten the distribution.
    self.temperature = temperature
    # Additive constant that keeps `log` away from zero.
    self._epsilon = 1e-10
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we define it here, like this, or should we pull it from keras.config.epsilon()? What do you think?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am thinking — importing from keras.config.epsilon() would work, but in this case, defining epsilon locally gives us the flexibility to choose a value other than the default 1e-7.


def compute_unreduced_loss(
    self,
    labels: types.Tensor,
    logits: types.Tensor,
    mask: types.Tensor | None = None,
) -> tuple[types.Tensor, types.Tensor]:
    """Compute the per-list (unreduced) ListMLE loss.

    Items with a negative label (the -1 "ignore" sentinel documented in
    `call`) or a falsy entry in `mask` are excluded from the likelihood.
    Lists with no valid items contribute a loss of exactly 0.

    Args:
        labels: Ground truth relevance scores of
            shape [batch_size,list_size].
        logits: Predicted scores of shape [batch_size, list_size].
        mask: Optional mask of shape [batch_size, list_size].

    Returns:
        Tuple of (losses, weights) where losses has shape [batch_size, 1]
        and weights has the same shape. `weights` is all ones; it exists
        so callers can multiply losses and weights uniformly.
    """

    # An item is valid when its label is non-negative (-1 marks "ignore").
    valid_mask = ops.greater_equal(labels, ops.cast(0.0, labels.dtype))

    # Intersect with the caller-provided mask, if any.
    if mask is not None:
        valid_mask = ops.logical_and(valid_mask, ops.cast(mask, dtype="bool"))

    num_valid_items = ops.sum(ops.cast(valid_mask, dtype=labels.dtype),
                              axis=1, keepdims=True)

    # [batch_size, 1] bool: lists that have at least one valid item.
    batch_has_valid_items = ops.greater(num_valid_items, 0.0)


    # Push invalid entries to -1e9 so they end up at the tail after
    # sorting by label (and never win the max below).
    labels_for_sorting = ops.where(valid_mask, labels, ops.full_like(labels, -1e9))
    logits_masked = ops.where(valid_mask, logits, ops.full_like(logits, -1e9))

    # Reorder logits (and the validity mask, so it stays aligned) by
    # label. NOTE(review): assumes `sort_by_scores` sorts in descending
    # score order, i.e. ground-truth ranking first — confirm against
    # `ranking_metrics_utils.sort_by_scores`.
    sorted_logits, sorted_valid_mask = sort_by_scores(
        tensors_to_sort=[logits_masked, valid_mask],
        scores=labels_for_sorting,
        mask=None,
        shuffle_ties=False,
        seed=None
    )

    # Temperature scaling: higher temperature flattens the distribution.
    sorted_logits = ops.divide(
        sorted_logits,
        ops.cast(self.temperature, dtype=sorted_logits.dtype)
    )

    # Numerical stability: subtract the per-list max over *valid* items
    # before exponentiating (standard log-sum-exp shift).
    valid_logits_for_max = ops.where(sorted_valid_mask, sorted_logits,
                                     ops.full_like(sorted_logits, -1e9))
    raw_max = ops.max(valid_logits_for_max, axis=1, keepdims=True)
    # For fully-masked lists the max would be -1e9; use 0 instead so the
    # subtraction below cannot produce huge values.
    raw_max = ops.where(batch_has_valid_items, raw_max, ops.zeros_like(raw_max))
    sorted_logits = sorted_logits - raw_max

    exp_logits = ops.exp(sorted_logits)
    # Zero out invalid items so they do not contribute to normalizers.
    exp_logits = ops.where(sorted_valid_mask, exp_logits, ops.zeros_like(exp_logits))

    # Suffix sums via flip/cumsum/flip: entry i holds
    # sum(exp(s_j) for j >= i), the ListMLE normalizer for position i.
    reversed_exp = ops.flip(exp_logits, axis=1)
    reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
    cumsum_from_right = ops.flip(reversed_cumsum, axis=1)

    # log P(item i ranked here) = s_i - log(sum_{j>=i} exp(s_j)).
    # Epsilon guards log(0) at fully-masked tail positions.
    log_normalizers = ops.log(cumsum_from_right + self._epsilon)
    log_probs = sorted_logits - log_normalizers

    # Invalid positions contribute exactly 0 to the sum below.
    log_probs = ops.where(sorted_valid_mask, log_probs, ops.zeros_like(log_probs))

    negative_log_likelihood = -ops.sum(log_probs, axis=1, keepdims=True)

    # Lists with no valid items get a loss of 0, not a spurious value.
    negative_log_likelihood = ops.where(batch_has_valid_items, negative_log_likelihood,
                                        ops.zeros_like(negative_log_likelihood))

    # Uniform per-list weights; shape matches the losses.
    weights = ops.ones_like(negative_log_likelihood)

    return negative_log_likelihood, weights

def call(
    self,
    y_true: types.Tensor,
    y_pred: types.Tensor,
) -> types.Tensor:
    """Compute the ListMLE loss for a batch of lists.

    Args:
        y_true: tensor or dict. Ground truth values. If tensor, of shape
            `(list_size)` for unbatched inputs or `(batch_size, list_size)`
            for batched inputs. Items labeled -1 are ignored. If a dict,
            it must contain `"labels"` and may contain `"mask"`, which
            excludes elements from the loss computation.
        y_pred: tensor. Predicted values of the same shape as the labels.

    Returns:
        The loss tensor of shape [batch_size].

    Raises:
        ValueError: If `y_true` is a dict without a `"labels"` key.
    """
    labels = y_true
    mask = None
    if isinstance(labels, dict):
        if "labels" not in labels:
            raise ValueError(
                '`"labels"` should be present in `y_true`. Received: '
                f"`y_true` = {y_true}"
            )

        mask = labels.get("mask", None)
        labels = labels["labels"]

    labels = ops.convert_to_tensor(labels)
    scores = ops.convert_to_tensor(y_pred)
    mask = ops.convert_to_tensor(mask) if mask is not None else None

    # Promote unbatched (rank-1) inputs to a batch of one list.
    labels, scores, mask, _ = standardize_call_inputs_ranks(
        labels, scores, mask
    )

    per_list_loss, per_list_weight = self.compute_unreduced_loss(
        labels=labels, logits=scores, mask=mask
    )
    weighted = ops.multiply(per_list_loss, per_list_weight)
    # Drop the trailing singleton axis: [batch_size, 1] -> [batch_size].
    return ops.squeeze(weighted, axis=-1)

def get_config(self) -> dict[str, Any]:
    """Return the serialization config, extending the base loss config."""
    base_config: dict[str, Any] = super().get_config()
    return {**base_config, "temperature": self.temperature}
87 changes: 87 additions & 0 deletions keras_rs/src/losses/list_mle_loss_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import keras
from absl.testing import parameterized
from keras import ops
from keras.losses import deserialize
from keras.losses import serialize

from keras_rs.src import testing
from keras_rs.src.losses.list_mle_loss import ListMLELoss

class ListMLELossTest(testing.TestCase, parameterized.TestCase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quick question - have you verified the outputs with TFRS' ListMLELoss?

def setUp(self):
    """Build shared score/label fixtures for all test cases."""
    # Fix: run the base TestCase setup too — skipping super().setUp()
    # silently drops any framework-level per-test initialization.
    super().setUp()

    # Unbatched (rank-1) example with 5 items.
    self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8])
    self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0])

    # Batched (rank-2) example: 2 lists of 5 items each.
    self.batched_scores = ops.array(
        [[1.0, 3.0, 2.0, 4.0, 0.8], [1.0, 1.8, 2.0, 3.0, 2.0]]
    )
    self.batched_labels = ops.array(
        [[1.0, 0.0, 1.0, 3.0, 2.0], [0.0, 1.0, 2.0, 3.0, 1.5]]
    )
    # Reference losses for ListMLELoss(reduction="none") on the batch.
    self.expected_output = ops.array([6.865693, 3.088192])

def test_unbatched_input(self):
    """Rank-1 inputs yield a single-element batch of losses."""
    loss_fn = ListMLELoss(reduction="none")
    result = loss_fn(
        y_true=self.unbatched_labels, y_pred=self.unbatched_scores
    )
    self.assertEqual(result.shape, (1,))
    self.assertTrue(ops.convert_to_numpy(result[0]) > 0)
    self.assertAllClose(result, [self.expected_output[0]], atol=1e-5)

def test_batched_input(self):
    """Each list in a batch gets its own positive loss value."""
    loss_fn = ListMLELoss(reduction="none")
    result = loss_fn(y_true=self.batched_labels, y_pred=self.batched_scores)
    self.assertEqual(result.shape, (2,))
    for i in range(2):
        self.assertTrue(ops.convert_to_numpy(result[i]) > 0)
    self.assertAllClose(result, self.expected_output, atol=1e-5)

def test_temperature(self):
    """Temperature scaling changes the per-list loss values."""
    loss_fn = ListMLELoss(temperature=0.5, reduction="none")
    result = loss_fn(y_true=self.batched_labels, y_pred=self.batched_scores)
    self.assertAllClose(result, [10.969891, 2.1283305], atol=1e-5)

def test_invalid_input_rank(self):
    """Rank-3 inputs must be rejected with a ValueError."""
    # Fix: the variable was previously named `rank_1_input`, but
    # ops.ones((2, 3, 4)) is a rank-3 tensor.
    rank_3_input = ops.ones((2, 3, 4))

    loss_fn = ListMLELoss()
    with self.assertRaises(ValueError):
        loss_fn(y_true=rank_3_input, y_pred=rank_3_input)

def test_loss_reduction(self):
    """`sum_over_batch_size` collapses the batch to a scalar."""
    loss_fn = ListMLELoss(reduction="sum_over_batch_size")
    result = loss_fn(y_true=self.batched_labels, y_pred=self.batched_scores)
    self.assertAlmostEqual(
        ops.convert_to_numpy(result), 4.9769425, places=5
    )

def test_scalar_sample_weight(self):
    """A scalar sample weight scales every per-list loss uniformly."""
    weight = ops.array(5.0)
    loss_fn = ListMLELoss(reduction="none")
    result = loss_fn(
        y_true=self.batched_labels,
        y_pred=self.batched_scores,
        sample_weight=weight,
    )
    expected = self.expected_output * weight
    self.assertAllClose(result, expected, atol=1e-5)

def test_model_fit(self):
    """The loss integrates with model.compile/fit end to end."""
    model_input = keras.Input(shape=(20,), dtype="float32")
    model_output = keras.layers.Dense(5)(model_input)
    model = keras.Model(inputs=model_input, outputs=model_output)
    model.compile(loss=ListMLELoss(), optimizer="adam")

    features = keras.random.normal((2, 20))
    targets = keras.random.randint((2, 5), minval=0, maxval=2)
    model.fit(x=features, y=targets)

def test_serialization(self):
    """A serialize/deserialize round trip preserves the config."""
    original = ListMLELoss(temperature=0.8)
    round_tripped = deserialize(serialize(original))
    self.assertDictEqual(original.get_config(), round_tripped.get_config())