Merged
29 changes: 29 additions & 0 deletions tests/test_grpo_trainer.py
@@ -703,6 +703,35 @@ def test_training_beta_non_zero(self):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

def test_training_with_cast_lm_head_to_fp32(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
training_args = GRPOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1,
per_device_train_batch_size=3,
num_generations=3,
max_completion_length=8,
report_to="none",
cast_lm_head_to_fp32=True,
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

Member:

maybe here it would make sense to assert that trainer.model.lm_head is indeed in fp32

Collaborator Author:

Good call.

assert trainer.state.log_history[-1]["train_loss"] is not None
assert trainer.model.lm_head.weight.dtype == torch.float32

# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

def test_training_with_entropy_filter(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
training_args = GRPOConfig(
13 changes: 12 additions & 1 deletion trl/trainer/grpo_config.py
@@ -41,9 +41,12 @@ class GRPOConfig(TrainingArguments):
disable_dropout (`bool`, *optional*, defaults to `False`):
Whether to disable dropout in the model. This is useful for training with a reference model, as it prevents
the model from generating different logprobs for the same input.
cast_lm_head_to_fp32 (`bool`, *optional*, defaults to `False`):
Whether to cast the language modeling head of the policy and reference models to float32, as recommended
by the [ScaleRL](https://huggingface.co/papers/2510.13786) recipe. This flag is only supported when the
model has untied word embedding and language modeling head layers, i.e. `tie_word_embeddings` in the
model config is `False`.

> Parameters that control the data preprocessing

remove_unused_columns (`bool`, *optional*, defaults to `False`):
Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
@@ -297,6 +300,14 @@ class GRPOConfig(TrainingArguments):
"it prevents the model from generating different logprobs for the same input."
},
)
cast_lm_head_to_fp32: bool = field(
default=False,
Member:

When/if we confirm the results from the paper, this could be changed to True in my opinion.

Collaborator Author:

Lemme take a look at the paper that Lewis linked, and get back to you on whether I can do something like that with the limited GPU resources I have.

Collaborator Author:

Seems fairly easy to implement! All I have to do is take a dataset with a thousand prompts, load the model in bf16 mode, cast its lm_head to fp32, and store the logits of the tokens in the completions, and then do the same for a model loaded in vLLM.

I can then plot something similar to the Minimax paper. Will have this done in the next few hours.
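
A minimal sketch of the Hugging Face side of that comparison (illustrative only, not the actual script behind the plots below; the model ID and prompt are assumptions, the lm_head cast and pre-hook mirror this PR, and the vLLM half is omitted):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-7B-Instruct"  # assumption: a model with tie_word_embeddings=False
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

prompt = "The capital of France is"
completion = " Paris."
prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
completion_ids = tokenizer(completion, add_special_tokens=False, return_tensors="pt").input_ids
full_ids = torch.cat([prompt_ids, completion_ids], dim=1).to(model.device)
completion_ids = completion_ids.to(model.device)
prompt_len = prompt_ids.shape[1]


def completion_logprobs(m):
    # Logits at position i predict token i + 1, so shift by one to score the completion tokens
    with torch.no_grad():
        logits = m(full_ids).logits[:, prompt_len - 1 : -1]
    logps = torch.log_softmax(logits.float(), dim=-1)
    return logps.gather(-1, completion_ids.unsqueeze(-1)).squeeze(-1)


logps_bf16 = completion_logprobs(model)

# Same approach as this PR: cast only the lm_head to fp32 and upcast its inputs via a pre-hook
model.lm_head = model.lm_head.float()
model.lm_head.register_forward_pre_hook(lambda module, args: (args[0].float(),))
logps_fp32 = completion_logprobs(model)

abs_diff = (logps_bf16.exp() - logps_fp32.exp()).abs().sum().item()
print(f"Summed absolute difference in completion-token probs: {abs_diff:.6f}")
```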

Member:

Wow, so cool! If you have any trouble reproducing, I think it's perfectly fine to open a subsequent PR to change the default.

Collaborator Author:

Realized that there were a few edge cases; going to try to wrap it up as soon as possible. If I run into issues reproducing in the next couple of days, I'll go ahead and merge this.

Member:

Interesting! Do you have 1-2 models as examples?

Collaborator Author (@pramodith, Oct 22, 2025):

Any of the Qwen-3 (0.6B, 4B, 8B) or Qwen-2.5 (7.5B) models. It seems like the smaller models have tie_word_embeddings set to True and the larger models (like 32B) do not.
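
A quick way to check this from the model configs (a minimal snippet; the model IDs are just the ones mentioned above):

```python
from transformers import AutoConfig

for model_id in ["Qwen/Qwen3-0.6B", "Qwen/Qwen3-32B"]:
    config = AutoConfig.from_pretrained(model_id)
    print(model_id, "tie_word_embeddings =", config.tie_word_embeddings)
```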

Collaborator Author (@pramodith, Oct 23, 2025):

Here are some results from Qwen3-32B (had similar results for the smaller models: 4B, 8B, 14B). We see marginal improvements in correlation and summed absolute differences in probs.

[Plots: qwen32_bf16, qwen32_fp32]

Collaborator Author:

Similar results for Qwen3-30B-A3B-Thinking.

[Plots: qwen_moe_bf16, qwen_moe_fp32]

Collaborator Author:

@qgallouedec I'm GPU resource constrained, so if you want to run it on larger MoE models, this is the code; you might need to tinker with the `tensor_parallel_size` and `enable_expert_parallel`.
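
The attached script itself isn't captured above; the following is only a rough sketch of what the vLLM side might look like (the model ID and parallelism values are assumptions; `tensor_parallel_size` and `enable_expert_parallel` are the knobs mentioned in the comment):

```python
from vllm import LLM, SamplingParams

# Assumed checkpoint; swap in whichever large MoE model is being tested
llm = LLM(
    model="Qwen/Qwen3-30B-A3B-Thinking-2507",
    dtype="bfloat16",
    tensor_parallel_size=8,       # tune to the number of available GPUs
    enable_expert_parallel=True,  # shard MoE experts across GPUs
)

sampling_params = SamplingParams(max_tokens=64, temperature=1.0, logprobs=1)
outputs = llm.generate(["Prove that the square root of 2 is irrational."], sampling_params)
# outputs[0].outputs[0].logprobs holds per-token logprobs to compare against the HF-side values
```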

metadata={
"help": "Whether to cast the language modeling head of the policy and reference models to float32, "
"as recommended by the [ScaleRL](https://huggingface.co/papers/2510.13786) recipe. This flag is only "
"supported when the model has untied word embedding and language modeling head layers, i.e. "
"`tie_word_embeddings` in the model config is False."
},
)

# Parameters that control the data preprocessing
# The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
Expand Down
21 changes: 20 additions & 1 deletion trl/trainer/grpo_trainer.py
@@ -477,6 +477,24 @@ def __init__(
if self.ref_model is not None:
disable_dropout_in_model(self.ref_model)

# Cast LM Head To FP32
if args.cast_lm_head_to_fp32:
if not model.config.tie_word_embeddings:
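# The backbone may still run in bf16/fp16, so a forward pre-hook upcasts the lm_head inputs to fp32 and the logits are computed in full precision.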

def cast_inputs_to_fp32(module, input):
return (input[0].float(),)

model.lm_head = model.lm_head.float()
model.lm_head.register_forward_pre_hook(cast_inputs_to_fp32)
if self.ref_model is not None:
self.ref_model.lm_head = self.ref_model.lm_head.float()
self.ref_model.lm_head.register_forward_pre_hook(cast_inputs_to_fp32)
else:
raise NotImplementedError(
"`cast_lm_head_to_fp32=True` is only supported when the model has untied word embedding and "
"language modeling head layers, i.e. `tie_word_embeddings` in the model config is False."
)

# Liger loss
if self.use_liger_kernel:
if not is_liger_kernel_available():
@@ -876,7 +894,6 @@ def _get_per_token_logps_and_entropies(
# Divide logits by sampling temperature.
# See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
logits = logits / self.temperature

completion_ids = input_ids_batch[:, -logits_to_keep:]
logps = selective_log_softmax(logits, completion_ids) # compute logprobs
all_logps.append(logps)
@@ -1300,6 +1317,8 @@ def _generate_single_turn(self, prompts: list):
unwrapped_model.to(torch.bfloat16)
elif self.args.fp16:
unwrapped_model.to(torch.float16)
if self.args.cast_lm_head_to_fp32:
unwrapped_model.lm_head.to(torch.float32)
Collaborator Author:

I don't think I need to register the pre-hook here again, let me know if I'm wrong.

with torch.inference_mode():
# Continuous batching API expects 'inputs' arg only
all_outputs = unwrapped_model.generate_batch(