Skip to content
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,7 @@ def __init__(
)

# Loss function
if args.loss_type == "nll":
if args.loss_type == "nll" or args.use_liger_kernel:
pass # use the default loss
elif args.loss_type == "dft":
if compute_loss_func is not None:
Expand Down Expand Up @@ -1095,6 +1095,11 @@ def compute_loss(

# If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
inputs["use_cache"] = False
# Request token accuracy from Liger kernel and set token scaling if using DFT loss
if self.args.use_liger_kernel:
inputs["return_token_accuracy"] = True
inputs["use_token_scaling"] = self.args.loss_type == "dft"

(loss, outputs) = super().compute_loss(
model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
)
Expand Down Expand Up @@ -1133,8 +1138,12 @@ def compute_loss(
self._total_train_tokens += num_tokens_in_batch
self._metrics[mode]["num_tokens"] = [self._total_train_tokens]

# Compute token accuracy if we have labels and if the model is not using Liger (no logits)
if not self.args.use_liger_kernel:
if self.args.use_liger_kernel:
if hasattr(outputs, "token_accuracy") and outputs.token_accuracy is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer using an explicit condition, like this:

if self.args.use_liger_kernel and version("liger_kernel") >= "0.6.4":
    token_accuracy = self.accelerator.gather_for_metrics(outputs.token_accuracy).mean().item()
    self._metrics[mode]["mean_token_accuracy"].append(token_accuracy)

Having the explicit condition makes it much easier to spot when the check can be removed after we bump the minimum liger_kernel version. It’s clearer and more maintainable than hiding the logic elsewhere.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The question, then, is which version will include this feature.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue with this approach is that I cannot test it, since the development version is currently the latest version in their setup.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah true

token_accuracy = self.accelerator.gather_for_metrics(outputs.token_accuracy).mean().item()
self._metrics[mode]["mean_token_accuracy"].append(token_accuracy)
else:
# Compute accuracy from logits using argmax (traditional method)
with torch.no_grad():
if "shift_labels" in inputs:
# When using CP, labels are pre-shifted. We must use these (and cannot manually shift) because:
Expand Down Expand Up @@ -1172,10 +1181,12 @@ def compute_loss(
total_sum = total_tokens.sum()
accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0
self._metrics[mode]["mean_token_accuracy"].append(accuracy)
if self.aux_loss_enabled:
aux_loss = outputs.aux_loss
aux_loss = self.accelerator.gather_for_metrics(aux_loss).mean().item()
self._metrics[mode]["aux_loss"].append(aux_loss)

# Log auxiliary loss if enabled (applies to both Liger and non-Liger)
if self.aux_loss_enabled:
aux_loss = outputs.aux_loss
aux_loss = self.accelerator.gather_for_metrics(aux_loss).mean().item()
self._metrics[mode]["aux_loss"].append(aux_loss)

return (loss, outputs) if return_outputs else loss

Expand Down
Loading