Skip to content

Commit a9d33d0

Browse files
kaixuanliuqgallouedecalbertvillanova
authored
fix CI issue for vlm_gemma_3n model (#4278)
Signed-off-by: Liu, Kaixuan <[email protected]> Co-authored-by: Quentin Gallouédec <[email protected]> Co-authored-by: Albert Villanova del Moral <[email protected]> Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 34fdb61 commit a9d33d0

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tests/test_sft_trainer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1494,6 +1494,7 @@ def test_train_vlm_gemma_3n(self):
14941494
# Initialize the trainer
14951495
training_args = SFTConfig(
14961496
output_dir=self.tmp_dir,
1497+
learning_rate=0.1,
14971498
max_length=None,
14981499
per_device_train_batch_size=1,
14991500
gradient_checkpointing=True,
@@ -1514,8 +1515,8 @@ def test_train_vlm_gemma_3n(self):
15141515
# Check the params have changed
15151516
for n, param in previous_trainable_params.items():
15161517
new_param = trainer.model.get_parameter(n)
1517-
if "model.vision_tower" in n:
1518-
# The vision tower is not updated, not sure why at this point.
1518+
if "model.audio_tower" in n or "model.embed_audio" in n:
1519+
# The audio embedding parameters are not updated because this dataset contains no audio data
15191520
continue
15201521
assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
15211522

0 commit comments

Comments
 (0)