
Commit 6f14353

Authored by pacman100, younesbelkada, and BenjaminBossan
Speed up the peft lora unload (#5741)
* Update peft_utils.py
* fix bug
* make the util backwards compatible.
  Co-Authored-By: Younes Belkada <[email protected]>
* fix import issue
* refactor the backward compatibility condition
* rename the conditional variable
* address comments
  Co-Authored-By: Benjamin Bossan <[email protected]>
* address comment
---------
Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: Benjamin Bossan <[email protected]>
1 parent c6f90da commit 6f14353

File tree

1 file changed: +66 -44 lines changed


src/diffusers/utils/peft_utils.py

Lines changed: 66 additions & 44 deletions
@@ -23,55 +23,77 @@
 from .import_utils import is_peft_available, is_torch_available
 
 
-def recurse_remove_peft_layers(model):
-    if is_torch_available():
-        import torch
+if is_torch_available():
+    import torch
+
 
+def recurse_remove_peft_layers(model):
     r"""
     Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`.
     """
-    from peft.tuners.lora import LoraLayer
-
-    for name, module in model.named_children():
-        if len(list(module.children())) > 0:
-            ## compound module, go inside it
-            recurse_remove_peft_layers(module)
-
-        module_replaced = False
-
-        if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
-            new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
-                module.weight.device
-            )
-            new_module.weight = module.weight
-            if module.bias is not None:
-                new_module.bias = module.bias
-
-            module_replaced = True
-        elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
-            new_module = torch.nn.Conv2d(
-                module.in_channels,
-                module.out_channels,
-                module.kernel_size,
-                module.stride,
-                module.padding,
-                module.dilation,
-                module.groups,
-            ).to(module.weight.device)
-
-            new_module.weight = module.weight
-            if module.bias is not None:
-                new_module.bias = module.bias
-
-            module_replaced = True
-
-        if module_replaced:
-            setattr(model, name, new_module)
-            del module
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+    from peft.tuners.tuners_utils import BaseTunerLayer
 
+    has_base_layer_pattern = False
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            has_base_layer_pattern = hasattr(module, "base_layer")
+            break
+
+    if has_base_layer_pattern:
+        from peft.utils import _get_submodules
+
+        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
+        for key in key_list:
+            try:
+                parent, target, target_name = _get_submodules(model, key)
+            except AttributeError:
+                continue
+            if hasattr(target, "base_layer"):
+                setattr(parent, target_name, target.get_base_layer())
+    else:
+        # This is for backwards compatibility with PEFT <= 0.6.2.
+        # TODO can be removed once that PEFT version is no longer supported.
+        from peft.tuners.lora import LoraLayer
+
+        for name, module in model.named_children():
+            if len(list(module.children())) > 0:
+                ## compound module, go inside it
+                recurse_remove_peft_layers(module)
+
+            module_replaced = False
+
+            if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
+                new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
+                    module.weight.device
+                )
+                new_module.weight = module.weight
+                if module.bias is not None:
+                    new_module.bias = module.bias
+
+                module_replaced = True
+            elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
+                new_module = torch.nn.Conv2d(
+                    module.in_channels,
+                    module.out_channels,
+                    module.kernel_size,
+                    module.stride,
+                    module.padding,
+                    module.dilation,
+                    module.groups,
+                ).to(module.weight.device)
+
+                new_module.weight = module.weight
+                if module.bias is not None:
+                    new_module.bias = module.bias
+
+                module_replaced = True
+
+            if module_replaced:
+                setattr(model, name, new_module)
+                del module
+
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
     return model
 
 
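The new fast path relies on PEFT's `base_layer` pattern (PEFT > 0.6.2): each tuner layer keeps the original module it wrapped, so unloading reduces to swapping every wrapper for `target.get_base_layer()` instead of allocating fresh `torch.nn.Linear`/`torch.nn.Conv2d` modules and reassigning their weights, as the fallback branch still does. A minimal sketch of how the utility might be exercised after this commit; the toy `Sequential` model and the `LoraConfig` below are illustrative assumptions, not part of the commit.

import torch
from peft import LoraConfig, inject_adapter_in_model

from diffusers.utils.peft_utils import recurse_remove_peft_layers

# Hypothetical stand-in for a diffusers block: two Linear submodules
# whose names inside the Sequential container are "0" and "1".
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 8))

# Inject LoRA wrappers into both Linear submodules (illustrative config).
config = LoraConfig(r=4, lora_alpha=4, target_modules=["0", "1"])
model = inject_adapter_in_model(config, model)
assert any("lora" in name for name, _ in model.named_modules())

# Unload: with PEFT > 0.6.2, each wrapper is swapped back for the base
# layer it already holds; no new modules are built, no weights are copied.
model = recurse_remove_peft_layers(model)
assert not any("lora" in name for name, _ in model.named_modules())

Because the wrappers already hold the original modules, the `base_layer` branch performs no allocation and no weight copies, which is where the speedup over the per-layer reconstruction kept in the PEFT <= 0.6.2 fallback comes from.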