diff --git a/trl/models/__init__.py b/trl/models/__init__.py
index c697c2c375..fbcfc677da 100644
--- a/trl/models/__init__.py
+++ b/trl/models/__init__.py
@@ -26,6 +26,7 @@
         "clone_chat_template",
         "prepare_deepspeed",
         "prepare_fsdp",
+        "prepare_model_for_kbit_training",
         "prepare_peft_model",
         "setup_chat_format",
         "unwrap_model_for_generation",
@@ -42,6 +43,7 @@
         clone_chat_template,
         prepare_deepspeed,
         prepare_fsdp,
+        prepare_model_for_kbit_training,
         prepare_peft_model,
         setup_chat_format,
         unwrap_model_for_generation,
diff --git a/trl/models/utils.py b/trl/models/utils.py
index b9e0da6868..276ec752ba 100644
--- a/trl/models/utils.py
+++ b/trl/models/utils.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import itertools
 import warnings
 from collections.abc import Callable
@@ -32,7 +33,7 @@
 
 if is_peft_available():
     import peft
-    from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
+    from peft import PeftConfig, PeftModel, get_peft_model
 
 
 if TYPE_CHECKING:
@@ -471,6 +472,51 @@ def on_after_outer_forward(self, wrapper_module: nn.Module, original_module: nn.
         pass
 
 
+def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None):
+    r"""
+    Prepare a k-bit quantized transformers model for training (PEFT/QLoRA).
+    """
+    loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
+    quant_methods = ["gptq", "aqlm", "eetq", "torchao", "hqq"]
+    is_quantized = getattr(model, "quantization_method", None) in quant_methods or getattr(
+        model, "hqq_quantized", False
+    )
+
+    if gradient_checkpointing_kwargs is None:
+        gradient_checkpointing_kwargs = {}
+
+    n_upcasted = 0
+    for name, param in model.named_parameters():
+        # freeze all parameters
+        param.requires_grad = False
+
+        # upcast LayerNorm / Norm to float32 for numerical stability
+        if (param.dtype in [torch.float16, torch.bfloat16]) and (
+            "norm" in name.lower() or "layernorm" in name.lower()
+        ):
+            param.data = param.data.to(torch.float32)
+            n_upcasted += 1
+
+    # Enable gradient checkpointing if needed
+    if (loaded_in_kbit or is_quantized) and use_gradient_checkpointing:
+        if hasattr(model, "enable_input_require_grads"):
+            model.enable_input_require_grads()
+        else:
+            # backward-compatible hook
+            def make_inputs_require_grad(module, input, output):
+                output.requires_grad_(True)
+
+            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+        supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
+            inspect.signature(model.gradient_checkpointing_enable).parameters
+        )
+        gc_kwargs = {"gradient_checkpointing_kwargs": gradient_checkpointing_kwargs} if supports_gc_kwargs else {}
+        model.gradient_checkpointing_enable(**gc_kwargs)
+
+    return model
+
+
 def enable_gradient_checkpointing(
     model: PreTrainedModel, gradient_checkpointing_kwargs: Optional[dict]
 ) -> PreTrainedModel:
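
For reviewers, a minimal usage sketch of the newly exported helper, assuming a bitsandbytes 4-bit load; the model id and the quantization/LoRA settings below are illustrative and not part of this change:

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

from trl.models import prepare_model_for_kbit_training

# Load a 4-bit quantized base model (model id and settings are placeholders).
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B", quantization_config=bnb_config)

# Freeze base weights, upcast norm layers to float32, and enable gradient checkpointing.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# Attach LoRA adapters; only the adapter weights remain trainable.
peft_model = get_peft_model(model, LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM"))
```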