
Commit fda88c6

Added custom prepare_model_for_kbit_training to save VRAM (#4335)
Co-authored-by: Kashif Rasul <[email protected]>
1 parent 2a138c7 commit fda88c6

2 files changed: +49 −1 lines changed

trl/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,7 @@
         "clone_chat_template",
         "prepare_deepspeed",
         "prepare_fsdp",
+        "prepare_model_for_kbit_training",
         "prepare_peft_model",
         "setup_chat_format",
         "unwrap_model_for_generation",
@@ -42,6 +43,7 @@
         clone_chat_template,
         prepare_deepspeed,
         prepare_fsdp,
+        prepare_model_for_kbit_training,
         prepare_peft_model,
         setup_chat_format,
         unwrap_model_for_generation,
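
Because the name is added to this module's public exports and imports, downstream code now resolves the TRL-local implementation rather than PEFT's. A minimal sketch of the resulting import path (assumed usage, not part of this commit):

# Resolves to the new function defined in trl/models/utils.py below,
# not to peft.prepare_model_for_kbit_training.
from trl.models import prepare_model_for_kbit_training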

trl/models/utils.py

Lines changed: 47 additions & 1 deletion
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import itertools
 import warnings
 from collections.abc import Callable
@@ -32,7 +33,7 @@
 
 if is_peft_available():
     import peft
-    from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
+    from peft import PeftConfig, PeftModel, get_peft_model
 
 
 if TYPE_CHECKING:
@@ -471,6 +472,51 @@ def on_after_outer_forward(self, wrapper_module: nn.Module, original_module: nn.
     pass
 
 
+def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None):
+    r"""
+    Prepare a k-bit quantized transformers model for training (PEFT/QLoRA).
+    """
+    loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
+    quant_methods = ["gptq", "aqlm", "eetq", "torchao", "hqq"]
+    is_quantized = getattr(model, "quantization_method", None) in quant_methods or getattr(
+        model, "hqq_quantized", False
+    )
+
+    if gradient_checkpointing_kwargs is None:
+        gradient_checkpointing_kwargs = {}
+
+    n_upcasted = 0
+    for name, param in model.named_parameters():
+        # freeze all parameters
+        param.requires_grad = False
+
+        # upcast LayerNorm / Norm to float32 for numerical stability
+        if (param.dtype in [torch.float16, torch.bfloat16]) and (
+            "norm" in name.lower() or "layernorm" in name.lower()
+        ):
+            param.data = param.data.to(torch.float32)
+            n_upcasted += 1
+
+    # Enable gradient checkpointing if needed
+    if (loaded_in_kbit or is_quantized) and use_gradient_checkpointing:
+        if hasattr(model, "enable_input_require_grads"):
+            model.enable_input_require_grads()
+        else:
+            # backward-compatible hook
+            def make_inputs_require_grad(module, input, output):
+                output.requires_grad_(True)
+
+            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+        supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
+            inspect.signature(model.gradient_checkpointing_enable).parameters
+        )
+        gc_kwargs = {"gradient_checkpointing_kwargs": gradient_checkpointing_kwargs} if supports_gc_kwargs else {}
+        model.gradient_checkpointing_enable(**gc_kwargs)
+
+    return model
+
+
 def enable_gradient_checkpointing(
     model: PreTrainedModel, gradient_checkpointing_kwargs: Optional[dict]
 ) -> PreTrainedModel:
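
The new helper replaces the `peft.prepare_model_for_kbit_training` import removed above; the main difference is that only norm parameters are upcast to float32 rather than every float16/bfloat16 parameter, which is where the VRAM saving in the commit title comes from. A minimal usage sketch under assumed settings (the model id, 4-bit config, and LoRA config below are illustrative, not part of this commit):

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

from trl.models import prepare_model_for_kbit_training

# Load a 4-bit quantized base model (model id is an illustrative assumption).
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

# Freeze all parameters, upcast norm layers to float32, and enable gradient
# checkpointing before attaching LoRA adapters.
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM"))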

0 commit comments
