Skip to content

Commit 4352074

Browse files
Use explicit tiny-Qwen2_5_VL model_id parameter in CI tests (#4325)
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 928f589 commit 4352074

File tree

3 files changed

+88
-26
lines changed

3 files changed

+88
-26
lines changed

tests/test_grpo_trainer.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1305,8 +1305,14 @@ def reward_func(completions, **kwargs):
13051305
new_param = trainer.model.get_parameter(n)
13061306
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
13071307

1308+
@pytest.mark.parametrize(
1309+
"model_id",
1310+
[
1311+
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
1312+
],
1313+
)
13081314
@require_vision
1309-
def test_training_vlm_beta_non_zero(self):
1315+
def test_training_vlm_beta_non_zero(self, model_id):
13101316
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
13111317

13121318
def reward_func(completions, **kwargs):
@@ -1323,7 +1329,7 @@ def reward_func(completions, **kwargs):
13231329
report_to="none",
13241330
)
13251331
trainer = GRPOTrainer(
1326-
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
1332+
model=model_id,
13271333
reward_funcs=reward_func,
13281334
args=training_args,
13291335
train_dataset=dataset,
@@ -1345,12 +1351,16 @@ def reward_func(completions, **kwargs):
13451351
new_param = trainer.model.get_parameter(n)
13461352
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
13471353

1354+
@pytest.mark.parametrize(
1355+
"model_id",
1356+
[
1357+
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
1358+
],
1359+
)
13481360
@require_vision
13491361
@require_peft
1350-
def test_training_vlm_peft(self):
1351-
model = AutoModelForImageTextToText.from_pretrained(
1352-
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration"
1353-
)
1362+
def test_training_vlm_peft(self, model_id):
1363+
model = AutoModelForImageTextToText.from_pretrained(model_id)
13541364
base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
13551365
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
13561366

@@ -1388,8 +1398,14 @@ def reward_func(completions, **kwargs):
13881398
elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer)
13891399
assert not torch.allclose(param, new_param), f"Parameter {n} has not changed."
13901400

1401+
@pytest.mark.parametrize(
1402+
"model_id",
1403+
[
1404+
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
1405+
],
1406+
)
13911407
@require_vision
1392-
def test_training_vlm_and_importance_sampling(self):
1408+
def test_training_vlm_and_importance_sampling(self, model_id):
13931409
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
13941410

13951411
def reward_func(completions, **kwargs):
@@ -1406,7 +1422,7 @@ def reward_func(completions, **kwargs):
14061422
report_to="none",
14071423
)
14081424
trainer = GRPOTrainer(
1409-
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
1425+
model=model_id,
14101426
reward_funcs=reward_func,
14111427
args=training_args,
14121428
train_dataset=dataset,
@@ -1428,9 +1444,15 @@ def reward_func(completions, **kwargs):
14281444
new_param = trainer.model.get_parameter(n)
14291445
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
14301446

1447+
@pytest.mark.parametrize(
1448+
"model_id",
1449+
[
1450+
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
1451+
],
1452+
)
14311453
@require_vision
14321454
@require_liger_kernel
1433-
def test_training_vlm_and_liger(self):
1455+
def test_training_vlm_and_liger(self, model_id):
14341456
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
14351457

14361458
def reward_func(completions, **kwargs):
@@ -1448,7 +1470,7 @@ def reward_func(completions, **kwargs):
14481470
report_to="none",
14491471
)
14501472
trainer = GRPOTrainer(
1451-
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
1473+
model=model_id,
14521474
reward_funcs=reward_func,
14531475
args=training_args,
14541476
train_dataset=dataset,
@@ -1515,8 +1537,14 @@ def reward_func(completions, **kwargs):
15151537
new_param = trainer.model.get_parameter(n)
15161538
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
15171539

1540+
@pytest.mark.parametrize(
1541+
"model_id",
1542+
[
1543+
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
1544+
],
1545+
)
15181546
@require_vision
1519-
def test_training_vlm_multi_image(self):
1547+
def test_training_vlm_multi_image(self, model_id):
15201548
dataset = load_dataset("trl-internal-testing/zen-multi-image", "conversational_prompt_only", split="train")
15211549

15221550
def reward_func(completions, **kwargs):
@@ -1533,7 +1561,7 @@ def reward_func(completions, **kwargs):
15331561
report_to="none",
15341562
)
15351563
trainer = GRPOTrainer(
1536-
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
1564+
model=model_id,
15371565
reward_funcs=reward_func,
15381566
args=training_args,
15391567
train_dataset=dataset,

tests/test_rloo_trainer.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,8 +1129,14 @@ def reward_func(completions, **kwargs):
11291129
new_param = trainer.model.get_parameter(n)
11301130
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
11311131

1132+
@pytest.mark.parametrize(
1133+
"model_id",
1134+
[
1135+
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
1136+
],
1137+
)
11321138
@require_vision
1133-
def test_training_vlm_beta_non_zero(self):
1139+
def test_training_vlm_beta_non_zero(self, model_id):
11341140
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
11351141

11361142
def reward_func(completions, **kwargs):
@@ -1147,7 +1153,7 @@ def reward_func(completions, **kwargs):
11471153
report_to="none",
11481154
)
11491155
trainer = RLOOTrainer(
1150-
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
1156+
model=model_id,
11511157
reward_funcs=reward_func,
11521158
args=training_args,
11531159
train_dataset=dataset,
@@ -1169,12 +1175,16 @@ def reward_func(completions, **kwargs):
11691175
new_param = trainer.model.get_parameter(n)
11701176
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
11711177

1178+
@pytest.mark.parametrize(
1179+
"model_id",
1180+
[
1181+
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
1182+
],
1183+
)
11721184
@require_vision
11731185
@require_peft
1174-
def test_training_vlm_peft(self):
1175-
model = AutoModelForImageTextToText.from_pretrained(
1176-
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration"
1177-
)
1186+
def test_training_vlm_peft(self, model_id):
1187+
model = AutoModelForImageTextToText.from_pretrained(model_id)
11781188
base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
11791189
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
11801190

@@ -1257,8 +1267,14 @@ def reward_func(completions, **kwargs):
12571267
new_param = trainer.model.get_parameter(n)
12581268
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
12591269

1270+
@pytest.mark.parametrize(
1271+
"model_id",
1272+
[
1273+
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
1274+
],
1275+
)
12601276
@require_vision
1261-
def test_training_vlm_multi_image(self):
1277+
def test_training_vlm_multi_image(self, model_id):
12621278
dataset = load_dataset("trl-internal-testing/zen-multi-image", "conversational_prompt_only", split="train")
12631279

12641280
def reward_func(completions, **kwargs):
@@ -1275,7 +1291,7 @@ def reward_func(completions, **kwargs):
12751291
report_to="none",
12761292
)
12771293
trainer = RLOOTrainer(
1278-
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
1294+
model=model_id,
12791295
reward_funcs=reward_func,
12801296
args=training_args,
12811297
train_dataset=dataset,

tests/test_sft_trainer.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,13 +1381,19 @@ def test_train_vlm(self, model_id):
13811381
continue
13821382
assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
13831383

1384+
@pytest.mark.parametrize(
1385+
"model_id",
1386+
[
1387+
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
1388+
],
1389+
)
13841390
@pytest.mark.xfail(
13851391
parse_version(transformers.__version__) < parse_version("4.57.0"),
13861392
reason="Mixing text-only and image+text examples is only supported in transformers >= 4.57.0",
13871393
strict=False,
13881394
)
13891395
@require_vision
1390-
def test_train_vlm_multi_image(self):
1396+
def test_train_vlm_multi_image(self, model_id):
13911397
# Get the dataset
13921398
dataset = load_dataset(
13931399
"trl-internal-testing/zen-multi-image", "conversational_prompt_completion", split="train"
@@ -1400,7 +1406,7 @@ def test_train_vlm_multi_image(self):
14001406
report_to="none",
14011407
)
14021408
trainer = SFTTrainer(
1403-
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
1409+
model=model_id,
14041410
args=training_args,
14051411
train_dataset=dataset,
14061412
)
@@ -1419,8 +1425,14 @@ def test_train_vlm_multi_image(self):
14191425
new_param = trainer.model.get_parameter(n)
14201426
assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
14211427

1428+
@pytest.mark.parametrize(
1429+
"model_id",
1430+
[
1431+
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
1432+
],
1433+
)
14221434
@require_vision
1423-
def test_train_vlm_prompt_completion(self):
1435+
def test_train_vlm_prompt_completion(self, model_id):
14241436
# Get the dataset
14251437
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_completion", split="train")
14261438

@@ -1431,7 +1443,7 @@ def test_train_vlm_prompt_completion(self):
14311443
report_to="none",
14321444
)
14331445
trainer = SFTTrainer(
1434-
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
1446+
model=model_id,
14351447
args=training_args,
14361448
train_dataset=dataset,
14371449
)
@@ -1520,15 +1532,21 @@ def test_train_vlm_gemma_3n(self):
15201532
continue
15211533
assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
15221534

1535+
@pytest.mark.parametrize(
1536+
"model_id",
1537+
[
1538+
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
1539+
],
1540+
)
15231541
@require_vision
1524-
def test_train_vlm_text_only_data(self):
1542+
def test_train_vlm_text_only_data(self, model_id):
15251543
# Get the dataset
15261544
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
15271545

15281546
# Initialize the trainer
15291547
training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
15301548
trainer = SFTTrainer(
1531-
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
1549+
model=model_id,
15321550
args=training_args,
15331551
train_dataset=dataset,
15341552
)

0 commit comments

Comments (0)