4 changes: 3 additions & 1 deletion tests/test_grpo_trainer.py
@@ -661,7 +661,8 @@ def test_training_with_sync_ref_model(self):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

def test_training_beta_non_zero(self):
@parameterized.expand([(False,), (True,)])
def test_training_beta_non_zero(self, cast_lm_head_to_fp32):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
training_args = GRPOConfig(
output_dir=self.tmp_dir,
Expand All @@ -671,6 +672,7 @@ def test_training_beta_non_zero(self):
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
report_to="none",
cast_lm_head_to_fp32=cast_lm_head_to_fp32,
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
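For context, `parameterized.expand` generates one test case per tuple, so the test above now runs once with the cast disabled and once with it enabled. A minimal sketch of the pattern (illustrative, not the repo's actual test):

```python
import unittest

from parameterized import parameterized


class TestCastFlag(unittest.TestCase):
    # Expands into two tests: ..._0 (False) and ..._1 (True)
    @parameterized.expand([(False,), (True,)])
    def test_flag_values(self, cast_lm_head_to_fp32):
        self.assertIn(cast_lm_head_to_fp32, (False, True))


if __name__ == "__main__":
    unittest.main()
```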
11 changes: 10 additions & 1 deletion trl/trainer/grpo_config.py
@@ -40,9 +40,11 @@ class GRPOConfig(TrainingArguments):
disable_dropout (`bool`, *optional*, defaults to `False`):
Whether to disable dropout in the model. This is useful for training with a reference model, as it prevents
the model from generating different logprobs for the same input.
cast_lm_head_to_fp32 (`bool`, *optional*, defaults to `False`):
Whether to cast the language modeling head of the policy and reference models to float32, as recommended
by the [ScaleRL](https://huggingface.co/papers/2510.13786) recipe.

> Parameters that control the data preprocessing

remove_unused_columns (`bool`, *optional*, defaults to `False`):
Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
@@ -286,6 +288,13 @@ class GRPOConfig(TrainingArguments):
"it prevents the model from generating different logprobs for the same input."
},
)
cast_lm_head_to_fp32: bool = field(
default=False,
**Member:** When/if we confirm the results from the paper, this could be changed to `True`, in my opinion.

**Collaborator (author):** Let me take a look at the paper that Lewis linked and get back to you on whether I can do something like that with the limited GPU resources I have.

**Collaborator (author):** Seems fairly easy to implement! All I have to do is take a dataset with a thousand prompts, load the model in bf16, cast its `lm_head` to fp32, and store the logits of the completion tokens; then do the same for a model loaded in vLLM.

I can then plot something similar to the Minimax paper. Will have this done in the next few hours.
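A minimal sketch of the HF-side comparison described above (the model ID, prompt, and metric are assumptions):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen3-0.6B"  # assumed example model
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tok("The capital of France is", return_tensors="pt")
with torch.inference_mode():
    # Final-position hidden state from the bf16 backbone
    hidden = model.model(**inputs).last_hidden_state[:, -1]
    logits_bf16 = model.lm_head(hidden).float()          # bf16 head, upcast afterwards
    logits_fp32 = model.lm_head.float()(hidden.float())  # head cast to fp32 first

p_bf16 = torch.softmax(logits_bf16, dim=-1)
p_fp32 = torch.softmax(logits_fp32, dim=-1)
print("summed |delta prob|:", (p_bf16 - p_fp32).abs().sum().item())
```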

**Member:** Wow, so cool. If you have any trouble reproducing, I think it's perfectly fine to open a subsequent PR to change the default.

**Collaborator (author):** Realized that there were a few edge cases; going to try to wrap it up as soon as possible. If I run into issues reproducing in the next couple of days, I'll go ahead and merge this.

**Member:** Interesting! Do you have one or two models as examples?

**@pramodith (author), Oct 22, 2025:** Any of the Qwen-3 (0.6B, 4B, 8B) or Qwen-2.5 (7.5B) models. It seems like the smaller models have `tie_word_embeddings` set to `True`, while the larger models (like 32B) do not.
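A quick way to check that flag (model IDs assumed):

```python
from transformers import AutoConfig

for model_id in ["Qwen/Qwen3-0.6B", "Qwen/Qwen3-32B"]:
    cfg = AutoConfig.from_pretrained(model_id)
    print(model_id, "tie_word_embeddings =", cfg.tie_word_embeddings)
```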

**@pramodith (author), Oct 23, 2025:** Here are some results from Qwen3-32B (I had similar results for the smaller 4B, 8B, and 14B models). We see marginal improvement in correlation and in the summed absolute differences in probabilities.

[attached plots: qwen32_bf16, qwen32_fp32]

**Collaborator (author):** Qwen3-30B-A3-Thinking shows similar results.

[attached plots: qwen_moe_bf16, qwen_moe_fp32]

**Collaborator (author):** @qgallouedec I'm GPU resource constrained, so if you want to run it on larger MoE models, this is the code; you might need to tinker with `tensor_parallel_size` and `enable_expert_parallel`.
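For reference, those two vLLM knobs look roughly like this (the model name and parallelism values here are assumptions, and the author's linked script is not shown):

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-30B-A3B-Thinking-2507",  # assumed MoE example
    tensor_parallel_size=8,       # shard weights across 8 GPUs
    enable_expert_parallel=True,  # distribute MoE experts across ranks
)
outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(max_tokens=8, logprobs=1),
)
```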

metadata={
"help": "Whether to cast the language modeling head of the policy and reference models to float32, "
"as recommended by the [ScaleRL](https://huggingface.co/papers/2510.13786) recipe."
},
)

# Parameters that control the data preprocessing
# The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
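Taken together, the new config field would be enabled from user code like this (a minimal sketch; the model, dataset, and reward function are placeholders, not from the PR):

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")  # placeholder dataset

def reward_len(completions, **kwargs):
    # Toy reward: prefer shorter completions
    return [-float(len(c)) for c in completions]

training_args = GRPOConfig(
    output_dir="grpo-fp32-head",
    cast_lm_head_to_fp32=True,  # the flag added in this PR
    num_generations=4,
    max_completion_length=32,
    report_to="none",
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model
    args=training_args,
    reward_funcs=reward_len,
    train_dataset=dataset,
)
trainer.train()
```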
9 changes: 8 additions & 1 deletion trl/trainer/grpo_trainer.py
@@ -443,6 +443,12 @@ def __init__(
if self.ref_model is not None:
disable_dropout_in_model(self.ref_model)

# Cast the LM head to fp32
if args.cast_lm_head_to_fp32:
model.lm_head = model.lm_head.float()
if self.ref_model is not None:
self.ref_model.lm_head = self.ref_model.lm_head.float()

# Liger loss
if self.use_liger_loss:
if not is_liger_kernel_available():
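One detail worth keeping in mind with this cast: `.float()` converts the module's parameters in place, and for checkpoints with `tie_word_embeddings=True` the `lm_head` weight is the same tensor as the input embedding, so both end up in fp32. A minimal illustration of the in-place cast:

```python
import torch

head = torch.nn.Linear(16, 100, bias=False).to(torch.bfloat16)
same = head.float()  # casts parameters in place and returns the same module
print(head.weight.dtype, same is head)  # torch.float32 True
```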
@@ -842,7 +848,6 @@ def _get_per_token_logps_and_entropies(
# Divide logits by sampling temperature.
# See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
logits = logits / self.temperature

completion_ids = input_ids_batch[:, -logits_to_keep:]
logps = selective_log_softmax(logits, completion_ids) # compute logprobs
all_logps.append(logps)
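For readers unfamiliar with `selective_log_softmax`, its semantics are a log-softmax over the vocabulary followed by a gather of the target-token log-probs; TRL's real implementation is more memory-efficient, but a reference version looks like this:

```python
import torch

def selective_log_softmax_reference(logits, index):
    # logits: (batch, seq, vocab); index: (batch, seq) target token ids
    logps = torch.log_softmax(logits, dim=-1)
    return torch.gather(logps, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
```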
@@ -1249,6 +1254,8 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
unwrapped_model.to(torch.bfloat16)
elif self.args.fp16:
unwrapped_model.to(torch.float16)
if self.args.cast_lm_head_to_fp32:
unwrapped_model.lm_head.to(torch.float32)
**Collaborator (author):** I don't think I need to register the pre-hook here again; let me know if I'm wrong.
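For context, the pre-hook mentioned here presumably upcasts the hidden states entering the fp32 `lm_head`; a self-contained sketch of that pattern (an assumption, not the PR's exact code):

```python
import torch

def cast_inputs_to_fp32(module, args):
    # Forward pre-hook: upcast positional tensor inputs before the fp32 matmul
    return tuple(a.float() if torch.is_tensor(a) else a for a in args)

lm_head = torch.nn.Linear(16, 100, bias=False).float()
handle = lm_head.register_forward_pre_hook(cast_inputs_to_fp32)
out = lm_head(torch.randn(2, 16, dtype=torch.bfloat16))  # hook upcasts the input
print(out.dtype)  # torch.float32
```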

with torch.inference_mode():
all_outputs = unwrapped_model.generate_batch(
paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False