Fix CI issue for vlm_gemma_3n model #4278
Conversation
Signed-off-by: Liu, Kaixuan <[email protected]>
@kashif, please help review, thanks very much.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for the catch and the fix.
I confirm that this PR fixes the test:
PASSED tests/test_sft_trainer.py::TestSFTTrainer::test_train_vlm_gemma_3n

Maybe we could add the reason why the vision/audio towers do not update, something like: they are frozen during training?
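The frozen-tower check under discussion could be expressed roughly as follows. This is a minimal pure-Python sketch, not the actual TRL test code; the parameter names and the `frozen_prefixes` tuple are illustrative assumptions:

```python
# Sketch: after training, parameters under frozen towers should be unchanged,
# while all other parameters should have moved.
frozen_prefixes = ("model.vision_tower.", "model.audio_tower.")  # illustrative names

def check_updates(params_before: dict, params_after: dict) -> None:
    for name, before in params_before.items():
        after = params_after[name]
        if name.startswith(frozen_prefixes):
            assert after == before, f"frozen param {name} changed"
        else:
            assert after != before, f"trainable param {name} did not update"

# Toy example: the vision tower weight stays put, the language-model head moves.
before = {"model.vision_tower.w": 0.5, "lm_head.w": 0.5}
after = {"model.vision_tower.w": 0.5, "lm_head.w": 0.4}
check_updates(before, after)  # passes silently
```

A check shaped like this makes the test's intent explicit: a frozen parameter that changes, or a trainable one that does not, both fail loudly with the offending name.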
I'd prefer keeping bf16, as it's usually the precision used. Instead we can simply increase precision
tests/test_sft_trainer.py (Outdated)

```diff
     per_device_train_batch_size=1,
     gradient_checkpointing=True,
-    model_init_kwargs={"dtype": "bfloat16"},
+    model_init_kwargs={"dtype": "float16"},
```
```diff
-    model_init_kwargs={"dtype": "float16"},
+    model_init_kwargs={"dtype": "bfloat16"},
```
```diff
     # Initialize the trainer
     training_args = SFTConfig(
         output_dir=self.tmp_dir,
         max_length=None,
```
```diff
-        max_length=None,
+        learning_rate=0.1,  # increase lr to ensure updates are not lost due to bf16 rounding
+        max_length=None,
```
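The bf16 rounding effect that motivates the higher learning rate can be illustrated in pure Python by emulating bfloat16's 7-bit mantissa. This is a minimal sketch, not TRL code; the update sizes (`1e-4`, `0.1`) are illustrative:

```python
import struct

def to_bf16(x: float) -> float:
    """Round a float to bfloat16 precision (1 sign + 8 exponent + 7 mantissa bits)."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    bits = (bits + 0x8000) & 0xFFFF0000  # round to nearest bf16 (ties round up)
    return struct.unpack("<f", struct.pack("<I", bits))[0]

w = 1.0
# Near 1.0 the bf16 spacing is about 2**-8 ~= 0.004, so a tiny optimizer step
# is smaller than the representable gap and the stored weight does not move.
print(to_bf16(w - 1e-4) == w)  # True: the update is rounded away
print(to_bf16(w - 0.1) == w)   # False: a larger update survives
```

This is why an equality check on weights before and after training can spuriously fail (or pass) in bf16 with a small learning rate: the updates vanish below the format's resolution, and bumping the learning rate makes them large enough to register.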
Co-authored-by: Albert Villanova del Moral <[email protected]>
When we run the test case
pytest -rA tests/test_sft_trainer.py::TestSFTTrainer::test_train_vlm_gemma_3n, it fails on both CUDA and Intel XPU. Further investigation shows there are two reasons. This PR fixes this bug.