GRPO: ScaleRL -> Support casting LM Head to FP32 #4303
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I wasn't able to find a config in …
I would approach this a bit differently: if the user passes a string as the model ID, it means they're relying on trl to choose the best way to initialize the model. In that case, we can safely cast the `lm_head`:

```python
GRPOTrainer(
    model="my_model",  # lm_head automatically cast to float32
    ...
)
```

However, if the user wants more fine-grained control over how the model is initialized, they should handle the initialization themselves, for example:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("my_model")
GRPOTrainer(
    model=model,
    ...
)
```

Consequently, we wouldn't need a new parameter in `GRPOConfig`.

What do you think?
And also, do you mind adding a section in the paper index as well? Something like "currently, only the lm head to fp32 trick is supported and enabled by default".
I did consider the approach that you're proposing too, and I do see its merit. My concern with the lib automatically setting the LM head to fp32 without the user explicitly asking for it is that the memory consumed by the LM head, and subsequently the logits, entropy, etc., is going to be fairly large, which could lead to OOM issues. The user would then need to spend time understanding why there's an OOM. If we document things really well, like throwing in a log message stating …
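For illustration, here is a minimal sketch (not the PR's code) of the kind of auto-cast-plus-log-message behavior being discussed, reusing the placeholder model ID from the comment above:

```python
import logging

import torch
from transformers import AutoModelForCausalLM

logger = logging.getLogger(__name__)

# Illustrative only: if the trainer received a model ID string, it could cast the
# LM head itself and log it, so a later OOM is easier to diagnose.
model = AutoModelForCausalLM.from_pretrained("my_model", torch_dtype=torch.bfloat16)
model.lm_head.to(torch.float32)
logger.info(
    "Casting lm_head to float32: logits, entropy, etc. will use noticeably more memory. "
    "Pass a pre-initialized model instead to keep full control over dtypes."
)
```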
I'm going to be adding some of the other parts of ScaleRL too, and was planning on adding ScaleRL to the paper index once we have all of those, along with the CISPO PR, merged.
Very clean PR @pramodith! Overall LGTM. If time / compute permits, it could be quite interesting to see what training / inference probabilities look like with and without FP32, as done in the MiniMax M1 paper: https://arxiv.org/abs/2506.13585
 
      
```python
        },
    )
    cast_lm_head_to_fp32: bool = field(
        default=False,
```
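For reference, enabling the new flag would look roughly like this (a sketch; `output_dir` and the rest of the trainer wiring are placeholders):

```python
from trl import GRPOConfig

# Opt in to the ScaleRL-style trick: keep the LM head (and therefore the logits
# used for the loss) in full precision while the rest of the model stays in bf16.
training_args = GRPOConfig(
    output_dir="grpo-fp32-lm-head",  # placeholder
    bf16=True,
    cast_lm_head_to_fp32=True,
)
```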
When/if we confirm the results from the paper, this could be changed to True in my opinion.
Lemme take a look at the paper that Lewis linked and get back to you on whether I can do something like that with the limited GPU resources I have.
Seems fairly easy to implement! All I have to do is take a dataset with a thousand prompts, load the model in bf16, cast its lm_head to fp32, and store the logits of the tokens in the completions, then do the same for a model loaded in vLLM.
I can then plot something similar to the MiniMax paper. Will have this done in the next few hours.
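A rough sketch of the transformers side of that comparison (model ID, prompts, and helper names are placeholders, not the actual script):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen3-0.6B"  # placeholder; any of the models discussed below works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
model.lm_head.to(torch.float32)  # the trick under discussion


@torch.no_grad()
def completion_logprobs(prompt: str, completion: str) -> torch.Tensor:
    """Per-token log-probabilities of the completion under the bf16 model with an fp32 lm_head."""
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids.to(model.device)
    logits = model(full_ids).logits.float()
    # logits at position t predict token t+1
    logprobs = torch.log_softmax(logits[:, :-1], dim=-1)
    targets = full_ids[:, 1:]
    token_logprobs = logprobs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    # keep only the completion tokens
    return token_logprobs[:, prompt_ids.shape[1] - 1 :]
```

The vLLM-side values could then be pulled from its prompt logprobs and plotted against these, as in the MiniMax M1 paper.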
Wow, so cool! If you have any trouble reproducing, I think it's perfectly fine to open a subsequent PR to change the default.
Realized that there were a few edge cases; going to try to wrap it up as soon as possible. If I run into issues reproducing in the next couple of days, I'll go ahead and merge this.
Interesting! Do you have one or two models as examples?
Any of the Qwen-3 (0.6B, 4B, 8B) or Qwen-2.5 (7B) models. It seems like the smaller models have `tie_word_embeddings` set to `True`, while the larger models (like 32B) do not.
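A quick way to check that from the configs (hub IDs below are just examples):

```python
from transformers import AutoConfig

# print which checkpoints tie the input embeddings to the LM head
for model_id in ["Qwen/Qwen3-0.6B", "Qwen/Qwen3-8B", "Qwen/Qwen2.5-32B"]:
    config = AutoConfig.from_pretrained(model_id)
    print(model_id, config.tie_word_embeddings)
```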
@qgallouedec I'm GPU-resource-constrained, so if you want to run it on larger MoE models, this is the code; you might need to tinker with `tensor_parallel_size` and `enable_expert_parallel`.
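For the vLLM side on larger MoE models, those knobs would be set roughly like this (a sketch; the model ID and values are placeholders, and `enable_expert_parallel` assumes a reasonably recent vLLM):

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-30B-A3B",   # placeholder MoE checkpoint
    dtype="bfloat16",
    tensor_parallel_size=4,       # shard weights across the available GPUs
    enable_expert_parallel=True,  # distribute experts instead of replicating them per rank
)

# prompt_logprobs returns log-probabilities for the prompt tokens, which can be
# compared against the trainer-side values computed with the fp32 lm_head.
sampling_params = SamplingParams(max_tokens=1, prompt_logprobs=0)
outputs = llm.generate(["example prompt"], sampling_params)
```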
        
          
tests/test_grpo_trainer.py (Outdated)
          
        
```python
def test_training_with_cast_lm_head_to_fp32(self):
    dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
    training_args = GRPOConfig(
        fp16=True,
```
```python
fp16=True,
```
for consistency with other tests. Unless there is something specific with fp16?
```python
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()
```
Maybe here it would make sense to assert that `trainer.model.lm_head` is indeed in fp32.
Good call.
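Something along these lines, assuming the test already has a `trainer` built with `cast_lm_head_to_fp32=True` and `torch` imported:

```python
# the rest of the model stays in half precision, but the LM head weights should be fp32
self.assertEqual(trainer.model.lm_head.weight.dtype, torch.float32)
```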
```python
elif self.args.fp16:
    unwrapped_model.to(torch.float16)
if self.args.cast_lm_head_to_fp32:
    unwrapped_model.lm_head.to(torch.float32)
```
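For context on the pre-hook mentioned in the next comment: since the fp32 head receives bf16/fp16 hidden states, the PR presumably registers a forward pre-hook that upcasts them. A minimal illustrative sketch (not the exact implementation; `unwrapped_model` is the variable from the diff above):

```python
import torch

def _upcast_lm_head_inputs(module, args):
    # cast incoming hidden states to fp32 so they match the fp32 lm_head weights
    return tuple(a.float() if torch.is_floating_point(a) else a for a in args)

unwrapped_model.lm_head.register_forward_pre_hook(_upcast_lm_head_inputs)
```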
I don't think I need to register the pre-hook here again; let me know if I'm wrong.
@qgallouedec please take another look at this PR, and if all's good, feel free to merge it.




What does this PR do?
Per the ScaleRL paper, this PR adds the option to cast the LM head to full precision (fp32).
Fixes # (issue)
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.