diff --git a/tests/experimental/test_grpo_with_replay_buffer_trainer.py b/tests/experimental/test_grpo_with_replay_buffer_trainer.py
new file mode 100644
index 0000000000..cad66f8034
--- /dev/null
+++ b/tests/experimental/test_grpo_with_replay_buffer_trainer.py
@@ -0,0 +1,277 @@
+import pytest
+import torch
+from datasets import load_dataset
+
+from trl.experimental.grpo_with_replay_buffer import (
+    GRPOWithReplayBufferConfig,
+    GRPOWithReplayBufferTrainer,
+    ReplayBuffer,
+)
+
+from ..testing_utils import TrlTestCase
+
+
+@pytest.mark.low_priority
+class TestReplayBuffer:
+    def setup_method(self):
+        self.replay_buffer = ReplayBuffer(max_size=5)
+
+    def test_add(self):
+        # Add elements to the replay buffer
+        scores = [0.5, 0.8, 0.3, 0.9, 0.7]
+        data = [
+            {"id": 1},
+            {"id": 2},
+            {"id": 3},
+            {"id": 4},
+            {"id": 5},
+        ]
+        self.replay_buffer.add(scores, data)
+
+        # Check if the buffer contains the correct number of elements
+        assert len(self.replay_buffer.heap) == 5
+
+        # Check if the buffer maintains the min-heap property
+        heap_scores = [item[0] for item in self.replay_buffer.heap]
+        assert heap_scores[0] == min(heap_scores)
+        assert heap_scores[0] == 0.3
+
+    def test_add_more_than_maxlen(self):
+        # Add elements to the replay buffer
+        scores = [0.5, 0.8, 0.3, 0.9, 0.7, 0.6, 0.4]
+        data = [
+            {"id": 1},
+            {"id": 2},
+            {"id": 3},
+            {"id": 4},
+            {"id": 5},
+            {"id": 6},
+            {"id": 7},
+        ]
+        self.replay_buffer.add(scores, data)
+
+        # Check if the buffer contains the correct number of elements
+        assert len(self.replay_buffer.heap) == 5
+
+        # Check if the buffer maintains the min-heap property
+        heap_scores = [item[0] for item in self.replay_buffer.heap]
+        assert heap_scores[0] == min(heap_scores)
+        assert heap_scores[0] == 0.5  # 0.3 and 0.4 should be removed
+
+    def test_sample(self):
+        # Add elements to the replay buffer
+        scores = [0.5, 0.8, 0.3, 0.9, 0.7]
+        data = [
+            {"id": 1},
+            {"id": 2},
+            {"id": 3},
+            {"id": 4},
+            {"id": 5},
+        ]
+        self.replay_buffer.add(scores, data)
+
+        # Sample elements from the buffer
+        sampled = self.replay_buffer.sample(num_samples=3)
+
+        # Check if the sampled elements are from the buffer
+        assert len(sampled) == 3
+        for item in sampled:
+            assert item in [entry[1] for entry in self.replay_buffer.heap]
+
+
+@pytest.mark.low_priority
+class TestUpdateWithReplayBuffer:
+    def setup_method(self):
+        config = GRPOWithReplayBufferConfig(
+            replay_buffer_size=5,
+        )
+        self.trainer = GRPOWithReplayBufferTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=config,
+            train_dataset=None,
+        )
+        self.trainer.replay_buffer = ReplayBuffer(max_size=5)
+        self.trainer.num_generations = 2
+
+    def _prepopulate_buffer(self, with_pixels=False, with_logprobs=False):
+        scores = [0.1, 0.9]
+        data = [
+            {
+                "prompt_ids": torch.tensor([[100, 101], [102, 103]]),
+                "prompt_mask": torch.ones(2, 2, dtype=torch.long),
+                "completion_ids": torch.tensor([[5, 6], [7, 8]]),
+                "completion_mask": torch.ones(2, 2, dtype=torch.long),
+                "advantages": torch.tensor([[0.5, 0.6]]),
+                **({"pixel_values": torch.randn(2, 3, 224, 224)} if with_pixels else {}),
+                **({"old_per_token_logps": torch.randn(2, 2)} if with_logprobs else {}),
+            },
+            {
+                "prompt_ids": torch.tensor([[104, 105], [106, 107]]),
+                "prompt_mask": torch.ones(2, 2, dtype=torch.long),
+                "completion_ids": torch.tensor([[13, 14], [15, 16]]),
+                "completion_mask": torch.ones(2, 2, dtype=torch.long),
+                "advantages": torch.tensor([[0.8, 0.85]]),
+                **({"pixel_values": torch.randn(2, 3, 224, 224)} if with_pixels else {}),
+                **({"old_per_token_logps": torch.randn(2, 2)} if with_logprobs else {}),
+            },
+        ]
+        self.trainer.replay_buffer.add(scores, data)
+
+    def _make_inputs(self, group_advantages, with_pixels=False, with_logprobs=False):
+        inputs = {
+            "group_advantages": group_advantages,
+            "prompt_ids": torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]]),
+            "prompt_mask": torch.ones(4, 2, dtype=torch.long),
+            "completion_ids": torch.tensor([[9, 10], [11, 12], [13, 14], [15, 16]]),
+            "completion_mask": torch.ones(4, 2, dtype=torch.long),
+            "prompt_inputs": {"pixel_values": torch.randn(4, 3, 224, 224)} if with_pixels else {},
+            "old_per_token_logps": torch.randn(4, 2) if with_logprobs else None,
+        }
+        inputs["group_std_rewards"] = group_advantages.std(dim=1).expand_as(group_advantages)
+        return inputs
+
+    def test_update_with_replay_buffer_no_variance(self):
+        self._prepopulate_buffer(with_pixels=True, with_logprobs=True)
+        group_advantages = torch.tensor([[0.5, 0.5], [0.8, 0.8]])  # no variance
+        inputs = self._make_inputs(group_advantages, with_pixels=True, with_logprobs=True)
+        original_prompt_ids = inputs["prompt_ids"].clone()
+
+        outputs = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)
+
+        assert outputs is not None
+        assert "pixel_values" in outputs
+        assert "old_per_token_logps" in outputs
+        assert len(self.trainer.replay_buffer.heap) == 2
+        for pid in outputs["prompt_ids"]:
+            assert pid.tolist() not in original_prompt_ids.tolist()
+
+    def test_update_with_replay_buffer_with_variance(self):
+        self._prepopulate_buffer()
+        group_advantages = torch.tensor([[0.6, 0.4], [0.7, 1.2]])  # has variance
+        inputs = self._make_inputs(group_advantages)
+
+        sampled = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)
+
+        assert len(self.trainer.replay_buffer.heap) == 4  # grew
+        assert sampled is None
+
+    def test_update_with_mixed_variance(self):
+        self._prepopulate_buffer()
+        group_advantages = torch.tensor([[0.6, 0.6], [0.3, 0.45]])  # one no-variance, one variance
+        inputs = self._make_inputs(group_advantages)
+        original_prompt_ids = inputs["prompt_ids"].clone().view(-1, self.trainer.num_generations, 2).tolist()
+
+        outputs = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)
+
+        assert len(self.trainer.replay_buffer.heap) == 3  # grew by 1
+        output_prompt_ids = outputs["prompt_ids"].view(-1, self.trainer.num_generations, 2).tolist()
+
+        buffer_ids = [item[1]["prompt_ids"].tolist() for item in self.trainer.replay_buffer.heap]
+        found_from_buffer = any(pid in buffer_ids for pid in output_prompt_ids)
+        found_from_original = any(pid in original_prompt_ids for pid in output_prompt_ids)
+
+        assert found_from_buffer
+        assert found_from_original
+        assert [[1, 2], [3, 4]] not in output_prompt_ids  # excluded no-variance group
+
+    def test_update_with_inputs_different_seq_len(self):
+        """
+        Test with inputs where the sequence lengths are different from the prepopulated buffer.
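+
+        Buffered groups were stored with sequence length 2, while this batch uses sequence length 3, so
+        any group sampled from the buffer is expected to be left-padded with the pad token to length 3.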
+        """
+        self._prepopulate_buffer()
+        pad_token_id = self.trainer.processing_class.pad_token_id
+        group_advantages = torch.tensor([[0.6, 0.6], [0.3, 0.45]])  # one no-variance, one variance
+        inputs = {
+            "group_advantages": group_advantages,
+            "prompt_ids": torch.tensor(
+                [
+                    [1, 2, pad_token_id],
+                    [1, 2, pad_token_id],
+                    [3, 4, 5],
+                    [3, 4, 5],
+                ]
+            ),
+            "prompt_mask": torch.tensor([[1, 1, 0], [1, 1, 0], [1, 1, 1], [1, 1, 1]], dtype=torch.long),
+            "completion_ids": torch.tensor(
+                [
+                    [1009, 1010, pad_token_id],
+                    [1011, 1012, 1013],
+                    [1013, 1014, pad_token_id],
+                    [1015, 1016, 1017],
+                ]
+            ),
+            "completion_mask": torch.tensor([[1, 1, 0], [1, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.long),
+            "prompt_inputs": {},
+        }
+        inputs["group_std_rewards"] = group_advantages.std(dim=1).expand_as(group_advantages)
+
+        outputs_after_sampling = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)
+        # Seq length of current batch should be preserved
+        assert outputs_after_sampling["prompt_ids"].shape[-1] == 3
+        assert len(self.trainer.replay_buffer.heap) == 3
+        output_prompt_ids = outputs_after_sampling["prompt_ids"].view(-1, self.trainer.num_generations, 3).tolist()
+
+        buffered_prompt_completion_ids = [
+            (item[1]["prompt_ids"].tolist(), item[1]["completion_ids"].tolist())
+            for item in self.trainer.replay_buffer.heap
+        ]
+        buffered_prompt_ids, buffered_completion_ids = zip(*buffered_prompt_completion_ids)
+
+        # Check for the new entry with seq len 3 in the buffer: only the group with variance
+        # was added; the no-variance group was excluded
+        assert [[3, 4, 5], [3, 4, 5]] in buffered_prompt_ids
+        assert [
+            [1013, 1014, pad_token_id],
+            [1015, 1016, 1017],
+        ] in buffered_completion_ids
+
+        # Check that sampled outputs contain one group with prompt_ids starting with a pad token
+        assert [
+            [pad_token_id, 100, 101],
+            [pad_token_id, 102, 103],
+        ] in output_prompt_ids or [
+            [pad_token_id, 104, 105],
+            [pad_token_id, 106, 107],
+        ] in output_prompt_ids
+
+
+@pytest.mark.low_priority
+class TestGRPOWithReplayBufferTrainer(TrlTestCase):
+    def test_training_with_replay_buffer(self):
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        # Make it likely that some generation groups have 0-std rewards
+        def custom_reward_func(completions, **kwargs):
+            if torch.rand(1).item() < 0.25:
+                return [0] * len(completions)  # simulate a zero-variance reward group
+            else:
+                return torch.rand(len(completions)).tolist()
+
+        training_args = GRPOWithReplayBufferConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # increase the learning rate to speed up the test
+            per_device_train_batch_size=4,  # reduce the batch size to reduce memory usage
+            num_generations=4,  # reduce the number of generations to reduce memory usage
+            max_completion_length=8,  # reduce the completion length to reduce memory usage
+            replay_buffer_size=8,
+            report_to="none",
+        )
+        trainer = GRPOWithReplayBufferTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            reward_funcs=[custom_reward_func],
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index 50efa9e45c..d2b364c1fc 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -26,11 +26,6 @@
 from transformers.utils import is_peft_available
 
 from trl import GRPOConfig, GRPOTrainer
-from trl.experimental.grpo_with_replay_buffer.grpo_with_replay_buffer_config import GRPOWithReplayBufferConfig
-from trl.experimental.grpo_with_replay_buffer.grpo_with_replay_buffer_trainer import (
-    GRPOWithReplayBufferTrainer,
-    ReplayBuffer,
-)
 from trl.experimental.gspo_token import GRPOTrainer as GSPOTokenTrainer
 
 from .testing_utils import TrlTestCase, require_liger_kernel, require_peft, require_vision, require_vllm
@@ -1666,270 +1661,6 @@ def test_single_reward_model_with_single_processing_class(self):
         assert trainer.reward_processing_classes[0] == single_processing_class
 
 
-@pytest.mark.low_priority
-class TestReplayBuffer:
-    def setup_method(self):
-        self.replay_buffer = ReplayBuffer(max_size=5)
-
-    def test_add(self):
-        # Add elements to the replay buffer
-        scores = [0.5, 0.8, 0.3, 0.9, 0.7]
-        data = [
-            {"id": 1},
-            {"id": 2},
-            {"id": 3},
-            {"id": 4},
-            {"id": 5},
-        ]
-        self.replay_buffer.add(scores, data)
-
-        # Check if the buffer contains the correct number of elements
-        assert len(self.replay_buffer.heap) == 5
-
-        # Check if the buffer maintains the min-heap property
-        heap_scores = [item[0] for item in self.replay_buffer.heap]
-        assert heap_scores[0] == min(heap_scores)
-        assert heap_scores[0] == 0.3
-
-    def test_add_more_than_maxlen(self):
-        # Add elements to the replay buffer
-        scores = [0.5, 0.8, 0.3, 0.9, 0.7, 0.6, 0.4]
-        data = [
-            {"id": 1},
-            {"id": 2},
-            {"id": 3},
-            {"id": 4},
-            {"id": 5},
-            {"id": 6},
-            {"id": 7},
-        ]
-        self.replay_buffer.add(scores, data)
-
-        # Check if the buffer contains the correct number of elements
-        assert len(self.replay_buffer.heap) == 5
-
-        # Check if the buffer maintains the min-heap property
-        heap_scores = [item[0] for item in self.replay_buffer.heap]
-        assert heap_scores[0] == min(heap_scores)
-        assert heap_scores[0] == 0.5  # 0.3 and 0.4 should be removed
-
-    def test_sample(self):
-        # Add elements to the replay buffer
-        scores = [0.5, 0.8, 0.3, 0.9, 0.7]
-        data = [
-            {"id": 1},
-            {"id": 2},
-            {"id": 3},
-            {"id": 4},
-            {"id": 5},
-        ]
-        self.replay_buffer.add(scores, data)
-
-        # Sample elements from the buffer
-        sampled = self.replay_buffer.sample(num_samples=3)
-
-        # Check if the sampled elements are from the buffer
-        assert len(sampled) == 3
-        for item in sampled:
-            assert item in [entry[1] for entry in self.replay_buffer.heap]
-
-
-@pytest.mark.low_priority
-class TestUpdateWithReplayBuffer:
-    def setup_method(self):
-        config = GRPOWithReplayBufferConfig(
-            replay_buffer_size=5,
-        )
-        self.trainer = GRPOWithReplayBufferTrainer(
-            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
-            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
-            args=config,
-            train_dataset=None,
-        )
-        self.trainer.replay_buffer = ReplayBuffer(max_size=5)
-        self.trainer.num_generations = 2
-
-    def _prepopulate_buffer(self, with_pixels=False, with_logprobs=False):
-        scores = [0.1, 0.9]
-        data = [
-            {
-                "prompt_ids": torch.tensor([[100, 101], [102, 103]]),
-                "prompt_mask": torch.ones(2, 2, dtype=torch.long),
-                "completion_ids": torch.tensor([[5, 6], [7, 8]]),
-                "completion_mask": torch.ones(2, 2, dtype=torch.long),
-                "advantages": torch.tensor([[0.5, 0.6]]),
-                **({"pixel_values": torch.randn(2, 3, 224, 224)} if with_pixels else {}),
-                **({"old_per_token_logps": torch.randn(2, 2)} if with_logprobs else {}),
-            },
-            {
-                "prompt_ids": torch.tensor([[104, 105], [106, 107]]),
-                "prompt_mask": torch.ones(2, 2, dtype=torch.long),
-                "completion_ids": torch.tensor([[13, 14], [15, 16]]),
-                "completion_mask": torch.ones(2, 2, dtype=torch.long),
-                "advantages": torch.tensor([[0.8, 0.85]]),
-                **({"pixel_values": torch.randn(2, 3, 224, 224)} if with_pixels else {}),
-                **({"old_per_token_logps": torch.randn(2, 2)} if with_logprobs else {}),
-            },
-        ]
-        self.trainer.replay_buffer.add(scores, data)
-
-    def _make_inputs(self, group_advantages, with_pixels=False, with_logprobs=False):
-        inputs = {
-            "group_advantages": group_advantages,
-            "prompt_ids": torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]]),
-            "prompt_mask": torch.ones(4, 2, dtype=torch.long),
-            "completion_ids": torch.tensor([[9, 10], [11, 12], [13, 14], [15, 16]]),
-            "completion_mask": torch.ones(4, 2, dtype=torch.long),
-            "prompt_inputs": {"pixel_values": torch.randn(4, 3, 224, 224)} if with_pixels else {},
-            "old_per_token_logps": torch.randn(4, 2) if with_logprobs else None,
-        }
-        inputs["group_std_rewards"] = group_advantages.std(dim=1).expand_as(group_advantages)
-        return inputs
-
-    def test_update_with_replay_buffer_no_variance(self):
-        self._prepopulate_buffer(with_pixels=True, with_logprobs=True)
-        group_advantages = torch.tensor([[0.5, 0.5], [0.8, 0.8]])  # no variance
-        inputs = self._make_inputs(group_advantages, with_pixels=True, with_logprobs=True)
-        original_prompt_ids = inputs["prompt_ids"].clone()
-
-        outputs = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)
-
-        assert outputs is not None
-        assert "pixel_values" in outputs
-        assert "old_per_token_logps" in outputs
-        assert len(self.trainer.replay_buffer.heap) == 2
-        for pid in outputs["prompt_ids"]:
-            assert pid.tolist() not in original_prompt_ids.tolist()
-
-    def test_update_with_replay_buffer_with_variance(self):
-        self._prepopulate_buffer()
-        group_advantages = torch.tensor([[0.6, 0.4], [0.7, 1.2]])  # has variance
-        inputs = self._make_inputs(group_advantages)
-
-        sampled = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)
-
-        assert len(self.trainer.replay_buffer.heap) == 4  # grew
-        assert sampled is None
-
-    def test_update_with_mixed_variance(self):
-        self._prepopulate_buffer()
-        group_advantages = torch.tensor([[0.6, 0.6], [0.3, 0.45]])  # one no-variance, one variance
-        inputs = self._make_inputs(group_advantages)
-        original_prompt_ids = inputs["prompt_ids"].clone().view(-1, self.trainer.num_generations, 2).tolist()
-
-        outputs = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)
-
-        assert len(self.trainer.replay_buffer.heap) == 3  # grew by 1
-        output_prompt_ids = outputs["prompt_ids"].view(-1, self.trainer.num_generations, 2).tolist()
-
-        buffer_ids = [item[1]["prompt_ids"].tolist() for item in self.trainer.replay_buffer.heap]
-        found_from_buffer = any(pid in buffer_ids for pid in output_prompt_ids)
-        found_from_original = any(pid in original_prompt_ids for pid in output_prompt_ids)
-
-        assert found_from_buffer
-        assert found_from_original
-        assert [[1, 2], [3, 4]] not in output_prompt_ids  # excluded no-variance group
-
-    def test_update_with_inputs_different_seq_len(self):
-        """
-        Test with inputs where the sequence lengths are different from the prepopulated buffer.
-        """
-        self._prepopulate_buffer()
-        pad_token_id = self.trainer.processing_class.pad_token_id
-        group_advantages = torch.tensor([[0.6, 0.6], [0.3, 0.45]])  # one no-variance, one variance
-        inputs = {
-            "group_advantages": group_advantages,
-            "prompt_ids": torch.tensor(
-                [
-                    [1, 2, pad_token_id],
-                    [1, 2, pad_token_id],
-                    [3, 4, 5],
-                    [3, 4, 5],
-                ]
-            ),
-            "prompt_mask": torch.tensor([[1, 1, 0], [1, 1, 0], [1, 1, 1], [1, 1, 1]], dtype=torch.long),
-            "completion_ids": torch.tensor(
-                [
-                    [1009, 1010, pad_token_id],
-                    [1011, 1012, 1013],
-                    [1013, 1014, pad_token_id],
-                    [1015, 1016, 1017],
-                ]
-            ),
-            "completion_mask": torch.tensor([[1, 1, 0], [1, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.long),
-            "prompt_inputs": {},
-        }
-        inputs["group_std_rewards"] = group_advantages.std(dim=1).expand_as(group_advantages)
-
-        outputs_after_sampling = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)
-        # Seq length of current batch should be preserved
-        assert outputs_after_sampling["prompt_ids"].shape[-1] == 3
-        assert len(self.trainer.replay_buffer.heap) == 3
-        output_prompt_ids = outputs_after_sampling["prompt_ids"].view(-1, self.trainer.num_generations, 3).tolist()
-
-        buffered_prompt_completion_ids = [
-            (item[1]["prompt_ids"].tolist(), item[1]["completion_ids"].tolist())
-            for item in self.trainer.replay_buffer.heap
-        ]
-        buffered_prompt_ids, buffered_completion_ids = zip(*buffered_prompt_completion_ids)
-
-        # Check for new entry with seq len 3 in buffer
-        assert [[3, 4, 5], [3, 4, 5]] in buffered_prompt_ids  # excluded no-variance group
-        assert [
-            [1013, 1014, pad_token_id],
-            [1015, 1016, 1017],
-        ] in buffered_completion_ids  # excluded no-variance group
-
-        # Check that sampled outputs contain one group with prompt_ids starting with a pad token
-        assert [
-            [pad_token_id, 100, 101],
-            [pad_token_id, 102, 103],
-        ] in output_prompt_ids or [
-            [pad_token_id, 104, 105],
-            [pad_token_id, 106, 107],
-        ] in output_prompt_ids
-
-
-@pytest.mark.low_priority
-class TestGRPOWithReplayBufferTrainer(TrlTestCase):
-    def test_training_with_replay_buffer(self):
-        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
-
-        # Guarantee that some rewards have 0 std
-        def custom_reward_func(completions, **kwargs):
-            if torch.rand(1).item() < 0.25:
-                return [0] * len(completions)  # simulate a zero-variance reward group
-            else:
-                return torch.rand(len(completions)).tolist()
-
-        training_args = GRPOWithReplayBufferConfig(
-            output_dir=self.tmp_dir,
-            learning_rate=0.1,  # increase the learning rate to speed up the test
-            per_device_train_batch_size=4,  # reduce the batch size to reduce memory usage
-            num_generations=4,  # reduce the number of generations to reduce memory usage
-            max_completion_length=8,  # reduce the completion length to reduce memory usage
-            replay_buffer_size=8,
-            report_to="none",
-        )
-        trainer = GRPOTrainer(
-            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
-            reward_funcs=[custom_reward_func],
-            args=training_args,
-            train_dataset=dataset,
-        )
-
-        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
-
-        trainer.train()
-
-        assert trainer.state.log_history[-1]["train_loss"] is not None
-
-        # Check that the params have changed
-        for n, param in previous_trainable_params.items():
-            new_param = trainer.model.get_parameter(n)
-            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
-
-
 class TestGSPOTokenTrainer(TrlTestCase):
     def test_training(self):
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")