
Conversation

@pramodith
Collaborator

What does this PR do?

Per ScaleRL

"FP32 Precision for LLM logits The generators and trainers rely on different kernels for inference and training, leading to small numerical mismatches in their token probabilities (He & Lab, 2025). RL training is highly sensitive to such discrepancies, since they directly affect the IS ratio in the surrogate objective. MiniMax et al. (2025) identified that these mismatches are especially pronounced at the language model head, and mitigate this by FP32 computations at the head for both the generator and trainer. As shown in Figure 5b, the precision fix dramatically improves the asymptotic performance A from 0.52 to 0.61. Given this clear benefit, we include the FP32 precision fix in our ScaleRL recipe."

This PR adds the option to cast the LM head to full precision.
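For illustration, here is a minimal sketch of what the cast amounts to (not the exact code added by this PR; the helper name is made up):

```python
# Minimal sketch only, not the PR's implementation.
import torch
from torch import nn


def cast_lm_head_to_fp32(model: nn.Module) -> nn.Module:
    """Keep the model body in bf16/fp16 but compute the logits in full precision."""
    # Caveat: with tie_word_embeddings=True the lm_head shares its weight with the
    # input embedding, so casting it upcasts the embedding too (see the
    # tie_word_embeddings discussion below). Untied models are the simple case.
    model.lm_head.to(torch.float32)

    # The hidden states coming out of the bf16/fp16 body must be upcast before the
    # fp32 matmul, otherwise the linear layer raises a dtype mismatch.
    def upcast_input(module, args):
        return (args[0].to(torch.float32),) + tuple(args[1:])

    model.lm_head.register_forward_pre_hook(upcast_input)
    return model
```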

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@pramodith
Collaborator Author

pramodith commented Oct 18, 2025

I wasn't able to find a config in vllm that forces the lm head to fp32, and likewise for the Liger loss. Let me know if I should just raise `Unsupported` exceptions when `cast_lm_head_to_fp32` is enabled together with `use_vllm` or `use_liger_loss`.
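Something like this, for concreteness (just a sketch; the exact exception type and placement are TBD):

```python
# Hypothetical guard at trainer/config init time (sketch only):
if args.cast_lm_head_to_fp32 and (args.use_vllm or args.use_liger_loss):
    raise ValueError(
        "`cast_lm_head_to_fp32` is currently not supported together with "
        "`use_vllm` or `use_liger_loss`."
    )
```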

@qgallouedec
Member

I would approach this a bit differently: if the user passes a string as the model ID, it means they’re relying on trl to choose the best way to initialize the model. In that case, we can safely cast the lm_head to float32:

GRPOTrainer(
    model="my_model",  # lm_head automatically cast to float32
    ...
)

However, if the user wants more fine-grained control over how the model is initialized, they should handle the initialization themselves, for example:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("my_model")

GRPOTrainer(
    model=model,
    ...
)

Consequently, we wouldn't need a new parameter in GRPOConfig.

  • Note 1: it's not very consistent with disable_dropout, but I think it's ok.
  • Note 2: it will need to be clearly documented in GRPOTrainer.model

What do you think?

@qgallouedec
Member

Also, do you mind adding a section to the paper index as well? Something like "currently, only the lm head to fp32 trick is supported and enabled by default".

@pramodith
Collaborator Author

pramodith commented Oct 19, 2025

Note 1: it's not very consistent with disable_dropout, but I think it's ok.
Note 2: it will need to be clearly documented in GRPOTrainer.model

I did consider the approach you're proposing and I see its merit. My concern with the library automatically setting the LM head to fp32, without the user explicitly asking for it, is that the memory consumed by the LM head, and subsequently by the logits, entropy, etc., can be fairly large and could lead to OOM issues.

The user would then need to spend time understanding why there's an OOM. If we document things really well, e.g. with a log message stating that the LM head is initialized in FP32 and that OOM errors can be mitigated by down-casting it, they might identify the cause quickly, but they would still have to add the two lines of code you pointed out, whereas with a config option it's just one line. They would also be much more aware of the cause of OOM errors, because they explicitly set the config to run the LM head in fp32.
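For a rough sense of the cost, a back-of-envelope with assumed (not measured) sizes:

```python
# Back-of-envelope only; vocab/hidden/batch/seq are illustrative assumptions.
vocab, hidden = 151_936, 4_096   # roughly Qwen3-8B-sized
batch, seq = 8, 1_024            # completion tokens scored per step

extra_weight_gb = vocab * hidden * (4 - 2) / 1e9       # fp32 vs bf16 lm_head weight
extra_logits_gb = batch * seq * vocab * (4 - 2) / 1e9  # fp32 vs bf16 logits
print(f"extra lm_head weight: ~{extra_weight_gb:.1f} GB")  # ~1.2 GB
print(f"extra logits:         ~{extra_logits_gb:.1f} GB")  # ~2.5 GB
```

And that's before any fp32 intermediates (log-probs, entropy) kept around during the loss computation.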

@pramodith
Collaborator Author

Also, do you mind adding a section to the paper index as well? Something like "currently, only the lm head to fp32 trick is supported and enabled by default".

I'm going to be adding some of the other parts of ScaleRL too, and I was planning on adding ScaleRL to the paper index once we have all of those, along with the CISPO PR, merged.

Member

@lewtun lewtun left a comment

Very clean PR @pramodith ! Overall LGTM. If time / compute permits, it could be quite interesting to see what training / inference probabilities look like with/without FP32, as done in the MiniMax M1 paper: https://arxiv.org/abs/2506.13585

[figure from the MiniMax M1 paper comparing training and inference token probabilities]

        },
    )
    cast_lm_head_to_fp32: bool = field(
        default=False,
Member

When/if we confirm the results from the paper, this could be changed to True in my opinion.

Collaborator Author

Lemme take a look at the paper that Lewis linked and get back to you on whether I can do something like that with the limited GPU resources I have.

Collaborator Author

Seems fairly easy to implement! All I have to do is take a dataset with a thousand prompts, load the model in bf16, cast its lm_head to fp32, and store the logits of the completion tokens, then do the same for a model loaded in vLLM.

I can then plot something similar to the Minimax paper. Will have this done in the next few hours.
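For reference, a minimal sketch of the comparison I have in mind (not the exact script behind the plots below; the model id, prompt count, and sampling settings are placeholders, and the vLLM logprobs API may differ slightly between versions):

```python
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

model_id = "Qwen/Qwen3-0.6B"  # placeholder
prompts = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")["prompt"][:100]

# 1) Generate with vLLM in bf16 and keep its per-token logprobs.
llm = LLM(model=model_id, dtype="bfloat16", gpu_memory_utilization=0.4)
outputs = llm.generate(prompts, SamplingParams(max_tokens=64, temperature=1.0, logprobs=1))

# 2) Re-score the same completions with the HF model in bf16, computing logits in fp32.
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda").eval()

for prompt, out in zip(prompts, outputs):
    completion_ids = list(out.outputs[0].token_ids)
    # The sampled token's logprob is always included in vLLM's returned dicts.
    vllm_probs = torch.tensor(
        [lps[tid].logprob for lps, tid in zip(out.outputs[0].logprobs, completion_ids)]
    ).exp()

    prompt_ids = tok(prompt, return_tensors="pt").input_ids.to("cuda")
    full_ids = torch.cat([prompt_ids, torch.tensor([completion_ids], device="cuda")], dim=1)
    with torch.no_grad():
        hidden = model.model(full_ids).last_hidden_state  # bf16 hidden states
        # fp32 LM head, without touching the (possibly tied) weight in place
        logits = F.linear(hidden.to(torch.float32), model.lm_head.weight.to(torch.float32))
    logits = logits[:, prompt_ids.shape[1] - 1 : -1]       # positions predicting the completion tokens
    idx = torch.arange(len(completion_ids), device="cuda")
    tokens = torch.tensor(completion_ids, device="cuda")
    hf_probs = torch.log_softmax(logits, dim=-1)[0, idx, tokens].exp()

    print(f"sum |p_train - p_infer| = {(hf_probs.cpu() - vllm_probs).abs().sum():.4f}")
```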

Member

wow so cool, if you have any trouble reproducing, I think it's perfectly fine to open a subsequent PR to change the default

Collaborator Author

Realized that there were a few edge cases, going to try to wrap it up as soon as possible. If I run into issues reproducing in the next couple of days I'll go ahead and merge this.

Member

Interesting! Do you have one or two models as examples?

Collaborator Author

@pramodith pramodith Oct 22, 2025

Any of the Qwen-3 (0.6B, 4B, 8B) or Qwen-2.5 (7.5B) models. It seems like the smaller models have tie_word_embeddings set to True and the larger models (like 32B) do not.

Collaborator Author

@pramodith pramodith Oct 23, 2025

Here are some results from Qwen3-32B (I had similar results for the smaller 4B, 8B, and 14B models). We see a marginal improvement in the correlation and in the summed absolute difference of the probabilities.

[plots: qwen32_bf16, qwen32_fp32]

Collaborator Author

Qwen3-30B-A3B-Thinking shows similar results.

[plots: qwen_moe_bf16, qwen_moe_fp32]

Collaborator Author

@qgallouedec I'm GPU resource constrained, so if you want to run it on larger MoE models, this is the code; you might need to tinker with tensor_parallel_size and enable_expert_parallel.
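Roughly, the vLLM side would be constructed along these lines (a sketch; the model id is just an example and the parallelism values depend on your hardware):

```python
from vllm import LLM

# Sketch: shard a large MoE model across GPUs and distribute its experts.
llm = LLM(
    model="Qwen/Qwen3-30B-A3B",   # example MoE checkpoint
    dtype="bfloat16",
    tensor_parallel_size=8,        # number of GPUs to shard across
    enable_expert_parallel=True,   # place the MoE experts across the TP group
)
```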

def test_training_with_cast_lm_head_to_fp32(self):
    dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
    training_args = GRPOConfig(
        fp16=True,
Member

@qgallouedec qgallouedec Oct 21, 2025

Suggested change:
- fp16=True,

for consistency with other tests. Unless there is something specific with fp16?

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

Member

maybe here it would make sense to assert that trainer.model.lm_head is indeed in fp32

Collaborator Author

Good call.
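For reference, the added check could be as simple as the following (a sketch; the attribute path is assumed):

```python
# Assert that the LM head was actually upcast to full precision.
assert trainer.model.lm_head.weight.dtype == torch.float32
```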

elif self.args.fp16:
    unwrapped_model.to(torch.float16)
if self.args.cast_lm_head_to_fp32:
    unwrapped_model.lm_head.to(torch.float32)
Collaborator Author

I don't think I need to register the pre-hook here again, let me know if I'm wrong.

@pramodith
Collaborator Author

@qgallouedec please take another look at this PR and if all's good feel free to merge it.
