|  | 
|  | 1 | +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. | 
|  | 2 | +# | 
|  | 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 4 | +# you may not use this file except in compliance with the License. | 
|  | 5 | +# You may obtain a copy of the License at | 
|  | 6 | +# | 
|  | 7 | +#     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 8 | +# | 
|  | 9 | +# Unless required by applicable law or agreed to in writing, software | 
|  | 10 | +# distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 12 | +# See the License for the specific language governing permissions and | 
|  | 13 | +# limitations under the License. | 
|  | 14 | + | 
|  | 15 | +from dataclasses import dataclass, field | 
|  | 16 | +from typing import Any, Optional | 
|  | 17 | + | 
|  | 18 | +from transformers import TrainingArguments | 
|  | 19 | + | 
|  | 20 | + | 
@dataclass
class BCOConfig(TrainingArguments):
    r"""
    Configuration class for the [`BCOTrainer`].

    This class includes only the parameters that are specific to BCO training. For a full list of training arguments,
    please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
    differ from those in [`~transformers.TrainingArguments`].

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        max_length (`int` or `None`, *optional*, defaults to `1024`):
            Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
            to use the default data collator.
        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
            Maximum length of the prompt. This argument is required if you want to use the default data collator.
        max_completion_length (`int` or `None`, *optional*):
            Maximum length of the completion. This argument is required if you want to use the default data collator
            and your model is an encoder-decoder.
        beta (`float`, *optional*, defaults to `0.1`):
            Parameter controlling the deviation from the reference model. Higher β means less deviation from the
            reference model.
        label_pad_token_id (`int`, *optional*, defaults to `-100`):
            Label pad token id. This argument is required if you want to use the default data collator.
        padding_value (`int` or `None`, *optional*):
            Padding value to use. If `None`, the padding value of the tokenizer is used.
        truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
            Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
            This argument is required if you want to use the default data collator.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model and reference model.
        generate_during_eval (`bool`, *optional*, defaults to `False`):
            If `True`, generates and logs completions from both the model and the reference model to W&B or Comet
            during evaluation.
        is_encoder_decoder (`bool` or `None`, *optional*):
            When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
            you need to specify if the model returned by the callable is an encoder-decoder model.
        precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
            Whether to precompute reference model log probabilities for training and evaluation datasets. This is
            useful when training without the reference model to reduce the total GPU memory needed.
        model_init_kwargs (`dict[str, Any]` or `None`, *optional*):
            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
            string.
        ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*):
            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
            from a string.
        dataset_num_proc (`int` or `None`, *optional*):
            Number of processes to use for processing the dataset.
        prompt_sample_size (`int`, *optional*, defaults to `1024`):
            Number of prompts that are fed to density ratio classifier.
        min_density_ratio (`float`, *optional*, defaults to `0.5`):
            Minimum value of the density ratio. The estimated density ratio is clamped to this value.
        max_density_ratio (`float`, *optional*, defaults to `10.0`):
            Maximum value of the density ratio. The estimated density ratio is clamped to this value.
    """

    # Extends the parent's list with the two dict-valued kwargs fields declared in this class.
    _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs", "ref_model_init_kwargs"]

    # Parameters whose default values are overridden from TrainingArguments
    logging_steps: float = field(
        default=10,
        metadata={
            "help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, "
            "will be interpreted as ratio of total training steps."
        },
    )
    gradient_checkpointing: bool = field(
        default=True,
        metadata={
            "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
        },
    )
    bf16: Optional[bool] = field(
        default=None,
        metadata={
            "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
            "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if "
            "`fp16` is not set."
        },
    )

    # Parameters specific to BCO training
    max_length: Optional[int] = field(
        default=1024,
        metadata={
            "help": "Maximum length of the sequences (prompt + completion) in the batch. "
            "This argument is required if you want to use the default data collator."
        },
    )
    max_prompt_length: Optional[int] = field(
        default=512,
        metadata={
            "help": "Maximum length of the prompt. "
            "This argument is required if you want to use the default data collator."
        },
    )
    max_completion_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "Maximum length of the completion. This argument is required if you want to use the "
            "default data collator and your model is an encoder-decoder."
        },
    )
    beta: float = field(
        default=0.1,
        metadata={
            "help": "Parameter controlling the deviation from the reference model. "
            "Higher β means less deviation from the reference model."
        },
    )
    label_pad_token_id: int = field(
        default=-100,
        metadata={
            "help": "Label pad token id. This argument is required if you want to use the default data collator."
        },
    )
    padding_value: Optional[int] = field(
        default=None,
        metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."},
    )
    truncation_mode: str = field(
        default="keep_end",
        metadata={
            "help": "Truncation mode to use when the prompt is too long. Possible values are "
            "`keep_end` or `keep_start`. This argument is required if you want to use the "
            "default data collator."
        },
    )
    disable_dropout: bool = field(
        default=True,
        metadata={"help": "Whether to disable dropout in the model and reference model."},
    )
    generate_during_eval: bool = field(
        default=False,
        metadata={
            # Help text kept in sync with the class docstring, which lists both W&B and Comet.
            "help": "If `True`, generates and logs completions from both the model and the reference model "
            "to W&B or Comet during evaluation."
        },
    )
    is_encoder_decoder: Optional[bool] = field(
        default=None,
        metadata={
            "help": "When using the `model_init` argument (callable) to instantiate the model instead of the "
            "`model` argument, you need to specify if the model returned by the callable is an "
            "encoder-decoder model."
        },
    )
    precompute_ref_log_probs: bool = field(
        default=False,
        metadata={
            "help": "Whether to precompute reference model log probabilities for training and evaluation datasets. "
            "This is useful when training without the reference model to reduce the total GPU memory "
            "needed."
        },
    )
    model_init_kwargs: Optional[dict[str, Any]] = field(
        default=None,
        metadata={
            "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
            "model from a string."
        },
    )
    ref_model_init_kwargs: Optional[dict[str, Any]] = field(
        default=None,
        metadata={
            "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
            "reference model from a string."
        },
    )
    dataset_num_proc: Optional[int] = field(
        default=None,
        metadata={"help": "Number of processes to use for processing the dataset."},
    )
    prompt_sample_size: int = field(
        default=1024,
        metadata={"help": "Number of prompts that are fed to density ratio classifier."},
    )
    min_density_ratio: float = field(
        default=0.5,
        metadata={"help": "Minimum value of the density ratio. The estimated density ratio is clamped to this value."},
    )
    max_density_ratio: float = field(
        default=10.0,
        metadata={"help": "Maximum value of the density ratio. The estimated density ratio is clamped to this value."},
    )

    def __post_init__(self):
        """Resolve the `bf16` default, then run the parent class's post-init logic."""
        # `bf16=None` means "not set by the user": fall back to the opposite of `fp16`. The value is resolved
        # before delegating so that `bf16` is a concrete bool when the parent's `__post_init__` runs.
        self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16

        super().__post_init__()
0 commit comments