GRPO: ScaleRL -> Support casting LM Head to FP32 #4303
base: main
Changes from 5 commits
GRPOConfig changes:

@@ -40,9 +40,11 @@ class GRPOConfig(TrainingArguments):
         disable_dropout (`bool`, *optional*, defaults to `False`):
             Whether to disable dropout in the model. This is useful for training with a reference model, as it prevents
             the model from generating different logprobs for the same input.
+        cast_lm_head_to_fp32 (`bool`, *optional*, defaults to `False`):
+            Whether to cast the Language Modeling Head of the policy and reference models to float32, as recommended
+            by the [ScaleRL](https://huggingface.co/papers/2510.13786) recipe.

         > Parameters that control the data preprocessing

         remove_unused_columns (`bool`, *optional*, defaults to `False`):
             Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
             requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
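For context, a minimal usage sketch of the new option, assuming a TRL install that contains this PR; the dataset, reward function, model name, and output path follow the usual GRPO quickstart pattern and are placeholders:

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

def reward_len(completions, **kwargs):
    # toy reward: prefer completions close to 50 characters
    return [-abs(50 - len(completion)) for completion in completions]

training_args = GRPOConfig(
    output_dir="Qwen2.5-0.5B-GRPO-fp32-head",  # arbitrary output path
    bf16=True,
    cast_lm_head_to_fp32=True,  # the option added in this PR
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```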
@@ -286,6 +288,13 @@ class GRPOConfig(TrainingArguments):
             "it prevents the model from generating different logprobs for the same input."
         },
     )
+    cast_lm_head_to_fp32: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to cast the Language Modeling Head of the policy and reference models to float32, "
+            "as recommended by the [ScaleRL](https://huggingface.co/papers/2510.13786) recipe."
+        },
+    )

     # Parameters that control the data preprocessing
     # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on

Review discussion on the `default=False` choice (resolved):

- "When/if we confirm the results from the paper, this could be changed to `True`, in my opinion."
- "Let me take a look at the paper that Lewis linked and get back to you on whether I can do something like that with the limited GPU resources I have."
- "Seems fairly easy to implement! All I have to do is take a dataset with a thousand prompts, load the model in bf16, cast its `lm_head` to fp32, and store the logits of the completion tokens, then do the same for a model loaded in vLLM. I can then plot something similar to the MiniMax paper. Will have this done in the next few hours."
- "Wow, so cool. If you have any trouble reproducing, I think it's perfectly fine to open a subsequent PR to change the default."
- "Realized that there were a few edge cases; going to try to wrap it up as soon as possible. If I run into issues reproducing in the next couple of days, I'll go ahead and merge this."
- "Interesting! Do you have one or two models as examples?"
- "Any of the Qwen-3 (0.6B, 4B, 8B) or Qwen-2.5 (7.5B) models. It seems like the smaller models have …"
- "@qgallouedec I'm GPU resource constrained, so if you want to run it on larger MoE models, this is the code; you might need to tinker with the …"
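A simplified sketch of the comparison described in the thread above: it measures how much the per-token logprobs shift when only the LM head runs in float32. The vLLM side of the experiment and the plotting are omitted, the model name and prompt are illustrative, and the explicit hidden-state cast stands in for the forward pre-hook mentioned in the discussion:

```python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-7B-Instruct"  # example; its lm_head is not tied to the embeddings
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model.eval()

ids = tok("The capital of France is", return_tensors="pt").input_ids

def token_logprobs(model, ids):
    # logprob the model assigns to each realized next token
    with torch.no_grad():
        hidden = model.model(ids).last_hidden_state
        # cast hidden states to the head's dtype, as a forward pre-hook would do
        logits = model.lm_head(hidden.to(model.lm_head.weight.dtype))[:, :-1, :]
    logps = F.log_softmax(logits.float(), dim=-1)
    return torch.gather(logps, 2, ids[:, 1:].unsqueeze(-1)).squeeze(-1)

logps_bf16_head = token_logprobs(model, ids)

model.lm_head = model.lm_head.float()  # the cast this PR introduces
logps_fp32_head = token_logprobs(model, ids)

print((logps_bf16_head - logps_fp32_head).abs().max())  # per-token logprob gap
```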
GRPOTrainer changes:

@@ -443,6 +443,12 @@ def __init__(
         if self.ref_model is not None:
             disable_dropout_in_model(self.ref_model)

+        # Cast LM Head To FP32
+        if args.cast_lm_head_to_fp32:
+            model.lm_head = model.lm_head.float()
+            if self.ref_model is not None:
+                self.ref_model.lm_head = self.ref_model.lm_head.float()
+
         # Liger loss
         if self.use_liger_loss:
             if not is_liger_kernel_available():
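A quick sketch of what the cast in `__init__` does to parameter dtypes (model name is illustrative; on checkpoints with tied input/output embeddings the shared weight may be upcast as well):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B-Instruct", torch_dtype=torch.bfloat16)
model.lm_head = model.lm_head.float()  # same cast as in __init__ above

print(model.lm_head.weight.dtype)                      # torch.float32
print(model.model.layers[0].mlp.up_proj.weight.dtype)  # torch.bfloat16 (body stays in bf16)
```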
@@ -842,7 +848,6 @@ def _get_per_token_logps_and_entropies(
             # Divide logits by sampling temperature.
             # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
             logits = logits / self.temperature
-
             completion_ids = input_ids_batch[:, -logits_to_keep:]
             logps = selective_log_softmax(logits, completion_ids)  # compute logprobs
             all_logps.append(logps)
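This hunk only removes a blank line, but since it sits in the logprob path, here is a minimal reference of what `selective_log_softmax` computes; TRL's version is a more memory-efficient equivalent of the same gather:

```python
import torch
import torch.nn.functional as F

def selective_log_softmax_reference(logits, index):
    # log-softmax over the vocabulary, then pick the logprob of each realized token
    logps = F.log_softmax(logits, dim=-1)
    return torch.gather(logps, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)

logits = torch.randn(2, 5, 100)                 # (batch, completion_len, vocab)
completion_ids = torch.randint(0, 100, (2, 5))  # token ids whose logprobs we want
print(selective_log_softmax_reference(logits, completion_ids).shape)  # torch.Size([2, 5])
```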
@@ -1249,6 +1254,8 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
                     unwrapped_model.to(torch.bfloat16)
                 elif self.args.fp16:
                     unwrapped_model.to(torch.float16)
+                if self.args.cast_lm_head_to_fp32:
+                    unwrapped_model.lm_head.to(torch.float32)
                 with torch.inference_mode():
                     all_outputs = unwrapped_model.generate_batch(
                         paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False

Review comment on the added lines: "I don't think I need to register the pre-hook here again, let me know if I'm wrong."
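Regarding the pre-hook question: forward pre-hooks live on the module and are not removed by `.to()` dtype casts, so after the blanket `.to(torch.bfloat16)` only the head's dtype needs restoring. A small sketch; the hook body (upcasting the head's input to fp32) is an assumption based on the discussion above, not code shown in this diff:

```python
import torch
import torch.nn as nn

lm_head = nn.Linear(16, 32, bias=False).float()

def cast_input_to_fp32(module, args):
    # assumed hook behavior: upcast the head's input so the fp32 weights can be used
    return tuple(a.float() if torch.is_tensor(a) else a for a in args)

lm_head.register_forward_pre_hook(cast_input_to_fp32)

lm_head.to(torch.bfloat16)   # blanket cast applied before paged generation...
print(lm_head.weight.dtype)  # torch.bfloat16 -> the fp32 head was downcast

lm_head.to(torch.float32)    # ...so the head is re-upcast, as in this hunk
out = lm_head(torch.randn(2, 16, dtype=torch.bfloat16))
print(out.dtype)             # torch.float32 -> the hook is still active, no re-registration needed
```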



