Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 40 additions & 12 deletions tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1305,8 +1305,14 @@ def reward_func(completions, **kwargs):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

@pytest.mark.parametrize(
"model_id",
[
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
],
)
@require_vision
def test_training_vlm_beta_non_zero(self):
def test_training_vlm_beta_non_zero(self, model_id):
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

def reward_func(completions, **kwargs):
Expand All @@ -1323,7 +1329,7 @@ def reward_func(completions, **kwargs):
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
model=model_id,
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
Expand All @@ -1345,12 +1351,16 @@ def reward_func(completions, **kwargs):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

@pytest.mark.parametrize(
"model_id",
[
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
],
)
@require_vision
@require_peft
def test_training_vlm_peft(self):
model = AutoModelForImageTextToText.from_pretrained(
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration"
)
def test_training_vlm_peft(self, model_id):
model = AutoModelForImageTextToText.from_pretrained(model_id)
base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

Expand Down Expand Up @@ -1388,8 +1398,14 @@ def reward_func(completions, **kwargs):
elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer)
assert not torch.allclose(param, new_param), f"Parameter {n} has not changed."

@pytest.mark.parametrize(
"model_id",
[
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
],
)
@require_vision
def test_training_vlm_and_importance_sampling(self):
def test_training_vlm_and_importance_sampling(self, model_id):
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

def reward_func(completions, **kwargs):
Expand All @@ -1406,7 +1422,7 @@ def reward_func(completions, **kwargs):
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
model=model_id,
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
Expand All @@ -1428,9 +1444,15 @@ def reward_func(completions, **kwargs):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

@pytest.mark.parametrize(
"model_id",
[
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
],
)
@require_vision
@require_liger_kernel
def test_training_vlm_and_liger(self):
def test_training_vlm_and_liger(self, model_id):
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

def reward_func(completions, **kwargs):
Expand All @@ -1448,7 +1470,7 @@ def reward_func(completions, **kwargs):
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
model=model_id,
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
Expand Down Expand Up @@ -1515,8 +1537,14 @@ def reward_func(completions, **kwargs):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

@pytest.mark.parametrize(
"model_id",
[
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
],
)
@require_vision
def test_training_vlm_multi_image(self):
def test_training_vlm_multi_image(self, model_id):
dataset = load_dataset("trl-internal-testing/zen-multi-image", "conversational_prompt_only", split="train")

def reward_func(completions, **kwargs):
Expand All @@ -1533,7 +1561,7 @@ def reward_func(completions, **kwargs):
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
model=model_id,
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
Expand Down
32 changes: 24 additions & 8 deletions tests/test_rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1129,8 +1129,14 @@ def reward_func(completions, **kwargs):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

@pytest.mark.parametrize(
"model_id",
[
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
],
)
@require_vision
def test_training_vlm_beta_non_zero(self):
def test_training_vlm_beta_non_zero(self, model_id):
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

def reward_func(completions, **kwargs):
Expand All @@ -1147,7 +1153,7 @@ def reward_func(completions, **kwargs):
report_to="none",
)
trainer = RLOOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
model=model_id,
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
Expand All @@ -1169,12 +1175,16 @@ def reward_func(completions, **kwargs):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

@pytest.mark.parametrize(
"model_id",
[
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
],
)
@require_vision
@require_peft
def test_training_vlm_peft(self):
model = AutoModelForImageTextToText.from_pretrained(
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration"
)
def test_training_vlm_peft(self, model_id):
model = AutoModelForImageTextToText.from_pretrained(model_id)
base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

Expand Down Expand Up @@ -1257,8 +1267,14 @@ def reward_func(completions, **kwargs):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

@pytest.mark.parametrize(
"model_id",
[
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
],
)
@require_vision
def test_training_vlm_multi_image(self):
def test_training_vlm_multi_image(self, model_id):
dataset = load_dataset("trl-internal-testing/zen-multi-image", "conversational_prompt_only", split="train")

def reward_func(completions, **kwargs):
Expand All @@ -1275,7 +1291,7 @@ def reward_func(completions, **kwargs):
report_to="none",
)
trainer = RLOOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
model=model_id,
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
Expand Down
30 changes: 24 additions & 6 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,13 +1381,19 @@ def test_train_vlm(self, model_id):
continue
assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"

@pytest.mark.parametrize(
"model_id",
[
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
],
)
@pytest.mark.xfail(
parse_version(transformers.__version__) < parse_version("4.57.0"),
reason="Mixing text-only and image+text examples is only supported in transformers >= 4.57.0",
strict=False,
)
@require_vision
def test_train_vlm_multi_image(self):
def test_train_vlm_multi_image(self, model_id):
# Get the dataset
dataset = load_dataset(
"trl-internal-testing/zen-multi-image", "conversational_prompt_completion", split="train"
Expand All @@ -1400,7 +1406,7 @@ def test_train_vlm_multi_image(self):
report_to="none",
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
model=model_id,
args=training_args,
train_dataset=dataset,
)
Expand All @@ -1419,8 +1425,14 @@ def test_train_vlm_multi_image(self):
new_param = trainer.model.get_parameter(n)
assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"

@pytest.mark.parametrize(
"model_id",
[
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
],
)
@require_vision
def test_train_vlm_prompt_completion(self):
def test_train_vlm_prompt_completion(self, model_id):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_completion", split="train")

Expand All @@ -1431,7 +1443,7 @@ def test_train_vlm_prompt_completion(self):
report_to="none",
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
model=model_id,
args=training_args,
train_dataset=dataset,
)
Expand Down Expand Up @@ -1519,15 +1531,21 @@ def test_train_vlm_gemma_3n(self):
continue
assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"

@pytest.mark.parametrize(
"model_id",
[
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
],
)
@require_vision
def test_train_vlm_text_only_data(self):
def test_train_vlm_text_only_data(self, model_id):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")

# Initialize the trainer
training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
model=model_id,
args=training_args,
train_dataset=dataset,
)
Expand Down
Loading