| 12 | 12 | # See the License for the specific language governing permissions and | 
| 13 | 13 | # limitations under the License. | 
| 14 | 14 |  | 
|  | 15 | +import inspect | 
| 15 | 16 | import itertools | 
| 16 | 17 | import warnings | 
| 17 | 18 | from collections.abc import Callable | 
| 32 | 33 |  | 
| 33 | 34 | if is_peft_available(): | 
| 34 | 35 |     import peft | 
| 35 |  | -    from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training | 
|  | 36 | +    from peft import PeftConfig, PeftModel, get_peft_model | 
| 36 | 37 |  | 
| 37 | 38 |  | 
| 38 | 39 | if TYPE_CHECKING: | 
| @@ -471,6 +472,51 @@ def on_after_outer_forward(self, wrapper_module: nn.Module, original_module: nn. | 
| 471 | 472 |         pass | 
| 472 | 473 |  | 
| 473 | 474 |  | 
|  | 475 | +def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None): | 
|  | 476 | +    r"""Prepare a k-bit quantized transformers model for training (PEFT/QLoRA): freeze all | 
|  | 477 | +    parameters, upcast norm layers to float32, and optionally enable gradient checkpointing. | 
|  | 478 | +    """ | 
|  | 479 | +    loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False) | 
|  | 480 | +    quant_methods = ["gptq", "aqlm", "eetq", "torchao", "hqq"] | 
|  | 481 | +    is_quantized = getattr(model, "quantization_method", None) in quant_methods or getattr( | 
|  | 482 | +        model, "hqq_quantized", False | 
|  | 483 | +    ) | 
|  | 484 | + | 
|  | 485 | +    if gradient_checkpointing_kwargs is None: | 
|  | 486 | +        gradient_checkpointing_kwargs = {} | 
|  | 487 | + | 
|  | 488 | +    n_upcasted = 0  # informational counter for norm params upcast to float32 below | 
|  | 489 | +    for name, param in model.named_parameters(): | 
|  | 490 | +        # freeze all parameters | 
|  | 491 | +        param.requires_grad = False | 
|  | 492 | + | 
|  | 493 | +        # upcast LayerNorm / Norm to float32 for numerical stability | 
|  | 494 | +        if (param.dtype in [torch.float16, torch.bfloat16]) and ( | 
|  | 495 | +            "norm" in name.lower()  # "layernorm" names are matched by this check as well | 
|  | 496 | +        ): | 
|  | 497 | +            param.data = param.data.to(torch.float32) | 
|  | 498 | +            n_upcasted += 1 | 
|  | 499 | + | 
|  | 500 | +    # Enable gradient checkpointing if needed | 
|  | 501 | +    if (loaded_in_kbit or is_quantized) and use_gradient_checkpointing: | 
|  | 502 | +        if hasattr(model, "enable_input_require_grads"): | 
|  | 503 | +            model.enable_input_require_grads() | 
|  | 504 | +        else: | 
|  | 505 | +            # fallback for older models: hook the input embeddings so their outputs require grad | 
|  | 506 | +            def make_inputs_require_grad(module, input, output): | 
|  | 507 | +                output.requires_grad_(True) | 
|  | 508 | + | 
|  | 509 | +            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) | 
|  | 510 | + | 
|  | 511 | +        supports_gc_kwargs = "gradient_checkpointing_kwargs" in list( | 
|  | 512 | +            inspect.signature(model.gradient_checkpointing_enable).parameters | 
|  | 513 | +        ) | 
|  | 514 | +        gc_kwargs = {"gradient_checkpointing_kwargs": gradient_checkpointing_kwargs} if supports_gc_kwargs else {} | 
|  | 515 | +        model.gradient_checkpointing_enable(**gc_kwargs) | 
|  | 516 | + | 
|  | 517 | +    return model | 
|  | 518 | + | 
|  | 519 | + | 
| 474 | 520 | def enable_gradient_checkpointing( | 
| 475 | 521 |     model: PreTrainedModel, gradient_checkpointing_kwargs: Optional[dict] | 
| 476 | 522 | ) -> PreTrainedModel: | 
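
For reference, below is a minimal usage sketch of the vendored `prepare_model_for_kbit_training` added in this diff (the sketch itself is not part of the diff). It assumes a 4-bit model loaded through `transformers`' `BitsAndBytesConfig`; the checkpoint name, LoRA rank, and target modules are illustrative choices, not values taken from this change.

```python
# Minimal usage sketch. Assumptions: bitsandbytes is installed, and
# prepare_model_for_kbit_training refers to the function defined in this diff.
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Load a small causal LM in 4-bit (checkpoint name is illustrative).
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", quantization_config=bnb_config)

# Freeze base weights, upcast norm layers to float32, and enable gradient
# checkpointing before attaching LoRA adapters.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# Attach LoRA adapters; only the adapter weights remain trainable.
lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"])
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```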