
Commit cb9bc2a

🚚 Move BCO to trl.experimental (#4312)
1 parent 475c732 commit cb9bc2a

12 files changed, +1782 −1704 lines changed


docs/source/_toctree.yml

Lines changed: 4 additions & 2 deletions
@@ -60,8 +60,6 @@
     title: Examples
 - sections:
   - sections: # Sorted alphabetically
-    - local: bco_trainer
-      title: BCO
     - local: cpo_trainer
       title: CPO
     - local: dpo_trainer
@@ -108,3 +106,7 @@
     - local: others
       title: Others
   title: API
+- sections:
+  - local: bco_trainer
+    title: BCO
+  title: Experimental

docs/source/bco_trainer.md

Lines changed: 4 additions & 4 deletions
@@ -8,8 +8,8 @@ For a full example have a look at [`examples/scripts/bco.py`].
 
 ## Expected dataset type
 
-The [`BCOTrainer`] requires an [unpaired preference dataset](dataset_formats#unpaired-preference).
-The [`BCOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
+The [`experimental.bco.BCOTrainer`] requires an [unpaired preference dataset](dataset_formats#unpaired-preference).
+The [`experimental.bco.BCOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
 
 ## Expected model format
 
@@ -93,11 +93,11 @@ To scale how much the auxiliary loss contributes to the total loss, use the hype
 
 ## BCOTrainer
 
-[[autodoc]] BCOTrainer
+[[autodoc]] experimental.bco.BCOTrainer
 - train
 - save_model
 - push_to_hub
 
 ## BCOConfig
 
-[[autodoc]] BCOConfig
+[[autodoc]] experimental.bco.BCOConfig

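With this change the docs reference the trainer through its experimental path. A minimal sketch of the updated import, assuming a TRL version that already contains this commit:

from trl.experimental.bco import BCOConfig, BCOTrainer

The remaining hunks in this commit update the old top-level `from trl import BCOConfig, BCOTrainer` import accordingly.
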
docs/source/dataset_formats.md

Lines changed: 1 addition & 1 deletion
@@ -389,7 +389,7 @@ Choosing the right dataset type depends on the task you are working on and the s
 
 | Trainer | Expected dataset type |
 | --- | --- |
-| [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
+| [`experimental.bco.BCOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
 | [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
 | [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
 | [`GKDTrainer`] | [Prompt-completion](#prompt-completion) |

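As a reminder of what the table row expects: an unpaired preference example pairs a single prompt and completion with a boolean desirability label, rather than a chosen/rejected pair. A small illustrative record in the standard format (the strings are placeholders):

# Illustrative unpaired-preference record (standard format); the text is made up.
example = {
    "prompt": "The sky is",
    "completion": " blue.",
    "label": True,  # whether this completion is desirable for the prompt
}
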
docs/source/index.md

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@
 TRL is a full stack library where we provide a set of tools to train transformer language models with methods like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO), Reward Modeling, and more.
 The library is integrated with 🤗 [transformers](https://github.com/huggingface/transformers).
 
-Below is the current list of TRL trainers, organized by method type (⚡️ = vLLM support).
+Below is the current list of TRL trainers, organized by method type (⚡️ = vLLM support; 🧪 = experimental).
 
 ## Taxonomy
 
@@ -36,7 +36,7 @@ Below is the current list of TRL trainers, organized by method type (⚡️ = vL
 - [`SFTTrainer`]
 - [`DPOTrainer`]
 - [`ORPOTrainer`]
-- [`BCOTrainer`]
+- [`experimental.bco.BCOTrainer`] 🧪
 - [`CPOTrainer`]
 - [`KTOTrainer`]
 

docs/source/paper_index.md

Lines changed: 1 addition & 1 deletion
@@ -338,7 +338,7 @@ training_args = DPOConfig(
 )
 ```
 
-For the unpaired version, the user should utilize [`BCOConfig`] and [`BCOTrainer`].
+For the unpaired version, the user should utilize [`experimental.bco.BCOConfig`] and [`experimental.bco.BCOTrainer`].
 
 ### Self-Play Preference Optimization for Language Model Alignment
 

examples/scripts/bco.py

Lines changed: 2 additions & 1 deletion
@@ -85,7 +85,8 @@
 from datasets import load_dataset
 from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, PreTrainedModel
 
-from trl import BCOConfig, BCOTrainer, ModelConfig, ScriptArguments, get_peft_config
+from trl import ModelConfig, ScriptArguments, get_peft_config
+from trl.experimental.bco import BCOConfig, BCOTrainer
 
 
 # Enable logging in a Hugging Face Space

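Beyond the import split, the script body is unchanged. For orientation, here is a condensed sketch of how the relocated classes are wired together; the model and dataset names are placeholders, and the keyword names assume the current BCOTrainer API (older versions used `tokenizer=` instead of `processing_class=`), so check the API reference before copying:

# Hedged sketch, not part of this commit: names and values are illustrative.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl.experimental.bco import BCOConfig, BCOTrainer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder model
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Any unpaired-preference dataset (prompt / completion / label) works here.
train_dataset = load_dataset("trl-lib/kto-mix-14k", split="train")

training_args = BCOConfig(output_dir="bco-model", max_length=1024, beta=0.1)
trainer = BCOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    processing_class=tokenizer,  # older TRL versions call this `tokenizer`
)
trainer.train()
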
tests/test_bco_trainer.py

Lines changed: 3 additions & 2 deletions
@@ -21,8 +21,8 @@
 from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
 from transformers.utils import is_peft_available
 
-from trl import BCOConfig, BCOTrainer
-from trl.trainer.bco_trainer import _process_tokens, _tokenize
+from trl.experimental.bco import BCOConfig, BCOTrainer
+from trl.experimental.bco.bco_trainer import _process_tokens, _tokenize
 
 from .testing_utils import TrlTestCase, require_no_wandb, require_peft, require_sklearn
 
@@ -31,6 +31,7 @@
     from peft import LoraConfig
 
 
+@pytest.mark.low_priority
 class TestBCOTrainer(TrlTestCase):
     @pytest.mark.parametrize(
         "config_name",

trl/experimental/bco/__init__.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .bco_config import BCOConfig
from .bco_trainer import BCOTrainer

trl/experimental/bco/bco_config.py

Lines changed: 212 additions & 0 deletions
@@ -0,0 +1,212 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Any, Optional

from transformers import TrainingArguments


@dataclass
class BCOConfig(TrainingArguments):
    r"""
    Configuration class for the [`BCOTrainer`].

    This class includes only the parameters that are specific to BCO training. For a full list of training arguments,
    please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
    differ from those in [`~transformers.TrainingArguments`].

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        max_length (`int` or `None`, *optional*, defaults to `1024`):
            Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
            to use the default data collator.
        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
            Maximum length of the prompt. This argument is required if you want to use the default data collator.
        max_completion_length (`int`, *optional*):
            Maximum length of the completion. This argument is required if you want to use the default data collator
            and your model is an encoder-decoder.
        beta (`float`, *optional*, defaults to `0.1`):
            Parameter controlling the deviation from the reference model. Higher β means less deviation from the
            reference model.
        label_pad_token_id (`int`, *optional*, defaults to `-100`):
            Label pad token id. This argument is required if you want to use the default data collator.
        padding_value (`int`, *optional*):
            Padding value to use. If `None`, the padding value of the tokenizer is used.
        truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
            Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
            This argument is required if you want to use the default data collator.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model and reference model.
        generate_during_eval (`bool`, *optional*, defaults to `False`):
            If `True`, generates and logs completions from both the model and the reference model to W&B or Comet
            during evaluation.
        is_encoder_decoder (`bool`, *optional*):
            When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
            you need to specify if the model returned by the callable is an encoder-decoder model.
        precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
            Whether to precompute reference model log probabilities for training and evaluation datasets. This is
            useful when training without the reference model to reduce the total GPU memory needed.
        model_init_kwargs (`dict[str, Any]`, *optional*):
            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
            string.
        ref_model_init_kwargs (`dict[str, Any]`, *optional*):
            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
            from a string.
        dataset_num_proc (`int`, *optional*):
            Number of processes to use for processing the dataset.
        prompt_sample_size (`int`, *optional*, defaults to `1024`):
            Number of prompts that are fed to density ratio classifier.
        min_density_ratio (`float`, *optional*, defaults to `0.5`):
            Minimum value of the density ratio. The estimated density ratio is clamped to this value.
        max_density_ratio (`float`, *optional*, defaults to `10.0`):
            Maximum value of the density ratio. The estimated density ratio is clamped to this value.
    """

    _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs", "ref_model_init_kwargs"]

    # Parameters whose default values are overridden from TrainingArguments
    logging_steps: float = field(
        default=10,
        metadata={
            "help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, "
            "will be interpreted as ratio of total training steps."
        },
    )
    gradient_checkpointing: bool = field(
        default=True,
        metadata={
            "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
        },
    )
    bf16: Optional[bool] = field(
        default=None,
        metadata={
            "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
            "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if "
            "`fp16` is not set."
        },
    )

    max_length: Optional[int] = field(
        default=1024,
        metadata={
            "help": "Maximum length of the sequences (prompt + completion) in the batch. "
            "This argument is required if you want to use the default data collator."
        },
    )
    max_prompt_length: Optional[int] = field(
        default=512,
        metadata={
            "help": "Maximum length of the prompt. "
            "This argument is required if you want to use the default data collator."
        },
    )
    max_completion_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "Maximum length of the completion. This argument is required if you want to use the "
            "default data collator and your model is an encoder-decoder."
        },
    )
    beta: float = field(
        default=0.1,
        metadata={
            "help": "Parameter controlling the deviation from the reference model. "
            "Higher β means less deviation from the reference model."
        },
    )
    label_pad_token_id: int = field(
        default=-100,
        metadata={
            "help": "Label pad token id. This argument is required if you want to use the default data collator."
        },
    )
    padding_value: Optional[int] = field(
        default=None,
        metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."},
    )
    truncation_mode: str = field(
        default="keep_end",
        metadata={
            "help": "Truncation mode to use when the prompt is too long. Possible values are "
            "`keep_end` or `keep_start`. This argument is required if you want to use the "
            "default data collator."
        },
    )
    disable_dropout: bool = field(
        default=True,
        metadata={"help": "Whether to disable dropout in the model and reference model."},
    )
    generate_during_eval: bool = field(
        default=False,
        metadata={
            "help": "If `True`, generates and logs completions from both the model and the reference model "
            "to W&B during evaluation."
        },
    )
    is_encoder_decoder: Optional[bool] = field(
        default=None,
        metadata={
            "help": "When using the `model_init` argument (callable) to instantiate the model instead of the "
            "`model` argument, you need to specify if the model returned by the callable is an "
            "encoder-decoder model."
        },
    )
    precompute_ref_log_probs: bool = field(
        default=False,
        metadata={
            "help": "Whether to precompute reference model log probabilities for training and evaluation datasets. "
            "This is useful when training without the reference model to reduce the total GPU memory "
            "needed."
        },
    )
    model_init_kwargs: Optional[dict[str, Any]] = field(
        default=None,
        metadata={
            "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
            "model from a string."
        },
    )
    ref_model_init_kwargs: Optional[dict[str, Any]] = field(
        default=None,
        metadata={
            "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
            "reference model from a string."
        },
    )
    dataset_num_proc: Optional[int] = field(
        default=None,
        metadata={"help": "Number of processes to use for processing the dataset."},
    )
    prompt_sample_size: int = field(
        default=1024,
        metadata={"help": "Number of prompts that are fed to density ratio classifier."},
    )
    min_density_ratio: float = field(
        default=0.5,
        metadata={"help": "Minimum value of the density ratio. The estimated density ratio is clamped to this value."},
    )
    max_density_ratio: float = field(
        default=10.0,
        metadata={"help": "Maximum value of the density ratio. The estimated density ratio is clamped to this value."},
    )

    def __post_init__(self):
        self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16

        super().__post_init__()

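Since the config is a plain dataclass on top of `TrainingArguments`, the BCO-specific knobs can be set directly at construction time. A small illustrative sketch (values are examples, not tuned recommendations; on hardware without bf16 support, pass `bf16=False` explicitly, because `__post_init__` otherwise enables it whenever `fp16` is off):

# Illustrative only: values are examples, not recommendations.
from trl.experimental.bco import BCOConfig

config = BCOConfig(
    output_dir="bco-model",
    beta=0.1,                 # deviation from the reference model (higher = less deviation)
    max_length=1024,          # prompt + completion budget for the default collator
    max_prompt_length=512,
    prompt_sample_size=1024,  # prompts fed to the density-ratio classifier
    min_density_ratio=0.5,    # clamping range for the estimated density ratio
    max_density_ratio=10.0,
)
# bf16 was left unset, so __post_init__ resolves it to `not fp16` (True here).
print(config.bf16)
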