GRPO: ScaleRL -> Support casting LM Head to FP32 #4303
Changes from all commits: f8c9185, 5156a63, a216558, 6b80f7d, d97e13c, abced5a, 1b19f69, ba565bb, ef22d09, 6972ed2, 0f258e0, 198d9a1, bad967e, 6a64e47
```diff
@@ -41,9 +41,12 @@ class GRPOConfig(TrainingArguments):
         disable_dropout (`bool`, *optional*, defaults to `False`):
             Whether to disable dropout in the model. This is useful for training with a reference model, as it prevents
             the model from generating different logprobs for the same input.
+        cast_lm_head_to_fp32 (`bool`, *optional*, defaults to `False`):
+            Whether to cast the language modeling head of the policy and reference models to float32. As recommended by
+            the [ScaleRL](https://huggingface.co/papers/2510.13786) recipe. This flag is only supported when the model
+            has untied word embedding and language modeling head layers, i.e. `tie_word_embeddings` in the model config is False.

         > Parameters that control the data preprocessing

         remove_unused_columns (`bool`, *optional*, defaults to `False`):
             Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
             requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
```
```diff
@@ -297,6 +300,14 @@ class GRPOConfig(TrainingArguments):
             "it prevents the model from generating different logprobs for the same input."
         },
     )
+    cast_lm_head_to_fp32: bool = field(
+        default=False,
```
**Member:** When/if we confirm the results from the paper, this could be changed to `True`, in my opinion.

**Collaborator (Author):** Lemme take a look at the paper that Lewis linked, and get back to you on whether I can do something like that with the limited GPU resources I have.

**Collaborator (Author):** Seems fairly easy to implement! All I have to do is take a dataset with a thousand prompts, load the model in bf16, cast its `lm_head` to fp32, and store the logits of the tokens in the completions, then do the same for a model loaded in vLLM. I can then plot something similar to the MiniMax paper. Will have this done in the next few hours.
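A rough sketch of the Hugging Face side of that check (the model and dataset names are placeholders, and the vLLM comparison is omitted):

```python
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-7B-Instruct"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

# Cast only the LM head to fp32 and upcast its inputs with a forward pre-hook.
model.lm_head = model.lm_head.float()
model.lm_head.register_forward_pre_hook(lambda module, inputs: (inputs[0].float(),))

prompts = load_dataset("trl-lib/tldr", split="train[:1000]")["prompt"]  # placeholder dataset

all_logps = []
for prompt in prompts:
    prompt_ids = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        completion_ids = model.generate(**prompt_ids, max_new_tokens=64, do_sample=True)
        logits = model(completion_ids).logits[:, :-1]  # next-token logits for every position
    targets = completion_ids[:, 1:]
    logps = logits.log_softmax(-1).gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    all_logps.append(logps[:, prompt_ids["input_ids"].shape[1] - 1 :])  # completion tokens only

# all_logps can then be compared against the log-probs vLLM reports for the
# same completions, to produce a MiniMax-style mismatch plot.
```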
**Member:** Wow, so cool. If you have any trouble reproducing, I think it's perfectly fine to open a subsequent PR to change the default.

**Collaborator (Author):** Realized that there were a few edge cases, going to try to wrap it up as soon as possible. If I run into issues reproducing in the next couple of days, I'll go ahead and merge this.

**Member:** Interesting! Do you have 1/2 models as examples?

**Collaborator (Author):** Any of the Qwen-3 (0.6B, 4B, 8B) or Qwen-2.5 (7.5B) models. It seems like the smaller models have

**Collaborator (Author):** @qgallouedec I'm GPU resource constrained, so if you want to run it on larger MoE models, this is the code you might need to tinker with the
```diff
+        metadata={
+            "help": "Whether to cast the language modeling head of the policy and reference models to float32. "
+            "As recommended by the [ScaleRL](https://huggingface.co/papers/2510.13786) recipe. This flag is only supported when the model"
+            " has untied word embedding and language modeling head layers, i.e. `tie_word_embeddings` in the model config is False."
+        },
+    )

     # Parameters that control the data preprocessing
     # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
```
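For reference, enabling the new option from user code would look roughly like this (a minimal sketch; the `output_dir` value and precision choice are placeholders, and the model being trained must have `tie_word_embeddings=False`):

```python
from trl import GRPOConfig

# Minimal sketch: opt in to the ScaleRL-style fp32 LM head.
# Only valid for models with untied embeddings (tie_word_embeddings=False).
training_args = GRPOConfig(
    output_dir="grpo-scalerl",   # placeholder
    bf16=True,                   # the rest of the model stays in bf16
    cast_lm_head_to_fp32=True,
)
```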
```diff
@@ -477,6 +477,24 @@ def __init__(
         if self.ref_model is not None:
             disable_dropout_in_model(self.ref_model)

+        # Cast LM Head To FP32
+        if args.cast_lm_head_to_fp32:
+            if not model.config.tie_word_embeddings:
+
+                def cast_inputs_to_fp32(module, input):
+                    return (input[0].float(),)
+
+                model.lm_head = model.lm_head.float()
+                model.lm_head.register_forward_pre_hook(cast_inputs_to_fp32)
+                if self.ref_model is not None:
+                    self.ref_model.lm_head = self.ref_model.lm_head.float()
+                    self.ref_model.lm_head.register_forward_pre_hook(cast_inputs_to_fp32)
+            else:
+                raise NotImplementedError(
+                    "`cast_lm_head_to_fp32=True` is only supported when the model has untied word embedding and language modeling head layers, "
+                    "i.e. `tie_word_embeddings` in the model config is False."
+                )

         # Liger loss
         if self.use_liger_kernel:
             if not is_liger_kernel_available():
```
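As a standalone illustration of the pattern used in this hunk (a minimal sketch, not TRL code): the head is cast to fp32, and a forward pre-hook upcasts the incoming bf16 hidden states so the logit projection runs entirely in float32.

```python
import torch
import torch.nn as nn

# Toy "LM head" cast to fp32, mirroring the diff above.
lm_head = nn.Linear(16, 100, bias=False).float()

def cast_inputs_to_fp32(module, inputs):
    # A forward pre-hook may return a new args tuple that replaces the module's input.
    return (inputs[0].float(),)

lm_head.register_forward_pre_hook(cast_inputs_to_fp32)

hidden_states = torch.randn(2, 16, dtype=torch.bfloat16)  # what a bf16 model would produce
logits = lm_head(hidden_states)
print(logits.dtype)  # torch.float32
```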
```diff
@@ -876,7 +894,6 @@ def _get_per_token_logps_and_entropies(
             # Divide logits by sampling temperature.
             # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
             logits = logits / self.temperature
-
             completion_ids = input_ids_batch[:, -logits_to_keep:]
             logps = selective_log_softmax(logits, completion_ids)  # compute logprobs
             all_logps.append(logps)
```
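For context, `selective_log_softmax` is conceptually a gather over the full log-softmax of the (temperature-scaled) logits; a naive, less memory-efficient equivalent is sketched below (shapes are illustrative):

```python
import torch

def naive_selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Log-probability of each chosen token under the given logits.
    logps = torch.log_softmax(logits, dim=-1)                           # (B, T, V)
    return logps.gather(dim=-1, index=index.unsqueeze(-1)).squeeze(-1)  # (B, T)

logits = torch.randn(2, 5, 32000)
completion_ids = torch.randint(0, 32000, (2, 5))
print(naive_selective_log_softmax(logits / 0.7, completion_ids).shape)  # torch.Size([2, 5])
```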
```diff
@@ -1300,6 +1317,8 @@ def _generate_single_turn(self, prompts: list):
                     unwrapped_model.to(torch.bfloat16)
                 elif self.args.fp16:
                     unwrapped_model.to(torch.float16)
+                if self.args.cast_lm_head_to_fp32:
+                    unwrapped_model.lm_head.to(torch.float32)
```
**Collaborator (Author):** I don't think I need to register the pre-hook here again, let me know if I'm wrong.
```diff
                 with torch.inference_mode():
                     # Continuous batching API expects 'inputs' arg only
                     all_outputs = unwrapped_model.generate_batch(
```
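The extra cast is needed here because `Module.to(dtype)` is recursive: casting the whole model to bf16/fp16 for generation also downcasts the previously fp32 head. A minimal sketch of that behaviour (toy modules, not TRL code):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 100)).bfloat16()
model[-1].float()                 # pretend the last layer is the fp32 lm_head
model.to(torch.bfloat16)          # whole-model cast downcasts it again
print(model[-1].weight.dtype)     # torch.bfloat16
model[-1].to(torch.float32)       # re-cast, mirroring the hunk above
print(model[-1].weight.dtype)     # torch.float32
```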




Maybe here it would make sense to assert that `trainer.model.lm_head` is indeed in fp32.

Good call.
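A sketch of what such a check could look like (the model, dataset, and reward function below are placeholders; the model is assumed to have `tie_word_embeddings=False`):

```python
import torch
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

# Tiny placeholder dataset and reward function, just enough to build the trainer.
dataset = Dataset.from_dict({"prompt": ["Hello", "The sky is"]})

def reward_len(completions, **kwargs):
    return [float(len(c)) for c in completions]

args = GRPOConfig(output_dir="tmp-grpo", cast_lm_head_to_fp32=True)
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-7B",   # assumed: a checkpoint with untied embeddings
    reward_funcs=reward_len,
    args=args,
    train_dataset=dataset,
)

# The LM head should now be fp32 even if the rest of the model is bf16/fp16.
assert trainer.model.lm_head.weight.dtype == torch.float32
```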