Merged

Commits (29)
566ec1d
[WIP] Make it possible to mix adapter types
BenjaminBossan Nov 1, 2023
7a9e906
Make style
BenjaminBossan Nov 1, 2023
461c530
Use old type annotation style
BenjaminBossan Nov 1, 2023
0f90b32
Support more adapters
BenjaminBossan Nov 2, 2023
642ae09
Merge branch 'main' into mixed-adapter-types
BenjaminBossan Nov 3, 2023
6706ea4
Extend tests, some small fixes
BenjaminBossan Nov 3, 2023
f9c4d54
Fix isinstance check with Union
BenjaminBossan Nov 3, 2023
6f34b50
More type annotation shenanigans
BenjaminBossan Nov 3, 2023
ac1b3e4
Call set_adapter more consistently
BenjaminBossan Nov 7, 2023
ec2164b
Fix some issues with AdaLora, LoHa, LoKr
BenjaminBossan Nov 7, 2023
da119cc
Fix regression with setting the active adapter
BenjaminBossan Nov 7, 2023
0a2381a
Small test refactor, add active_adapter property
BenjaminBossan Nov 7, 2023
95adb3a
Make style
BenjaminBossan Nov 7, 2023
41b3fe5
Adjustments to testing
BenjaminBossan Nov 8, 2023
d91bfcd
Fix test
BenjaminBossan Nov 8, 2023
f55c5a3
Add tests for using the same adapter type twice
BenjaminBossan Nov 8, 2023
00531a4
Support merge_and_unload in mixed models
BenjaminBossan Nov 8, 2023
b342eeb
Add support for unloading
BenjaminBossan Nov 8, 2023
81748e2
Add tests for disabling adapters
BenjaminBossan Nov 8, 2023
bce8319
Make style -_-
BenjaminBossan Nov 8, 2023
6722f42
Fix a bug that prevented nesting more than 1 level
BenjaminBossan Nov 8, 2023
cbf1ca9
Remove everything related to mixed adapters
BenjaminBossan Nov 9, 2023
30c75a8
Refactor merge_and_unload to use base_layer
BenjaminBossan Nov 9, 2023
5de0654
Merge branch 'main' into refactor-base-layer-pattern
BenjaminBossan Nov 10, 2023
f54c4a3
Merge branch 'main' into refactor-base-layer-pattern
BenjaminBossan Nov 14, 2023
9e80efb
Address reviewer comments
BenjaminBossan Nov 14, 2023
b3662f4
Simplify merging code for IA3
BenjaminBossan Nov 14, 2023
eac560a
Merge branch 'main' into refactor-base-layer-pattern
BenjaminBossan Nov 16, 2023
dc5f282
Make style
BenjaminBossan Nov 16, 2023
7 changes: 0 additions & 7 deletions src/peft/tuners/__init__.py
@@ -27,10 +27,3 @@
from .prefix_tuning import PrefixEncoder, PrefixTuningConfig
from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit
from .multitask_prompt_tuning import MultitaskPromptEmbedding, MultitaskPromptTuningConfig, MultitaskPromptTuningInit

# Mapping of tuners that support direct plugging
Contributor:
Why has this been removed?

Member Author:
It's not being used anywhere. Removing this is not related to this PR specifically.

TUNERS_MAPPING = {
"LORA": LoraModel,
"IA3": IA3Model,
"ADALORA": AdaLoraModel,
}
56 changes: 18 additions & 38 deletions src/peft/tuners/adalora/bnb.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import bitsandbytes as bnb
import torch

from peft.import_utils import is_bnb_4bit_available, is_bnb_available
@@ -23,38 +22,28 @@

if is_bnb_available():

class SVDLinear8bitLt(bnb.nn.Linear8bitLt, AdaLoraLayer):
class SVDLinear8bitLt(torch.nn.Module, AdaLoraLayer):
# Low-rank matrix for SVD-based adaptation
def __init__(
self,
adapter_name,
in_features,
out_features,
base_layer: torch.nn.Module,
adapter_name: str,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
init_lora_weights: bool = True,
**kwargs,
) -> None:
bnb.nn.Linear8bitLt.__init__(
self,
in_features,
out_features,
bias=kwargs.get("bias", True),
has_fp16_weights=kwargs.get("has_fp16_weights", True),
memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
threshold=kwargs.get("threshold", 0.0),
index=kwargs.get("index", None),
)
AdaLoraLayer.__init__(self, in_features=in_features, out_features=out_features)
super().__init__()
AdaLoraLayer.__init__(self, base_layer)
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.get_base_layer().weight.requires_grad = False

init_lora_weights = kwargs.pop("init_lora_weights", True)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.set_adapter(adapter_name)

def forward(self, x: torch.Tensor) -> torch.Tensor:
result = super().forward(x)
# note: no check for self.merged because merging is not supported (yet)
result = self.base_layer.forward(x)

if self.disable_adapters:
return result
@@ -85,37 +74,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

if is_bnb_4bit_available():

class SVDLinear4bit(bnb.nn.Linear4bit, AdaLoraLayer):
class SVDLinear4bit(torch.nn.Module, AdaLoraLayer):
Contributor:
This might be breaking, no? Some setups might rely on checks such as isinstance(module, bnb.nn.Linear4bit); suddenly these layers will no longer inherit from the bnb layers.

Member Author:
Technically true, although I'm not sure how much we can consider the inheritance structure of our classes to be "public API". I think we can assume that very few users do this type of check (if any), and I hope those few expert users can quickly figure out what is wrong. If you are aware of anyone doing this, let me know and I can prepare something like a migration guide. (A sketch of how such a check can be adapted is shown after this file's diff.)

Contributor:
OK, sounds great. I think this is too much of an edge case; let's just be careful with the release notes and maybe include this PR in the next minor release (0.7.0).

# Low-rank matrix for SVD-based adaptation
def __init__(
self,
adapter_name,
in_features,
out_features,
base_layer: torch.nn.Module,
Contributor:
For backward compatibility we should maybe consider keeping the same argument order, ignoring the unused kwargs, and putting the new arguments at the end. This might be too much of an edge case, but we never know. What do you think?

Member Author:
Similar comment as above, though here we could change the order of adapter_name and base_layer, which would, however, be inconsistent with the other layers. Nevertheless, even if we changed the order, a caller would still get an error here because we now have an additional argument, the base_layer, that they don't pass. So either way, there would be an error.

Contributor:
Agreed!

adapter_name: str,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
init_lora_weights: bool = True,
**kwargs,
) -> None:
bnb.nn.Linear4bit.__init__(
self,
in_features,
out_features,
bias=kwargs.get("bias", True),
compute_dtype=kwargs.get("compute_dtype", torch.float32),
compress_statistics=kwargs.get("compress_statistics", True),
quant_type=kwargs.get("quant_type", "nf4"),
)
AdaLoraLayer.__init__(self, in_features=in_features, out_features=out_features)
super().__init__()
AdaLoraLayer.__init__(self, base_layer)
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.get_base_layer().weight.requires_grad = False

init_lora_weights = kwargs.pop("init_lora_weights", True)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.set_adapter(adapter_name)

def forward(self, x: torch.Tensor) -> torch.Tensor:
result = super().forward(x)
# note: no check for self.merged because merging is not supported (yet)
result = self.base_layer.forward(x)

if self.disable_adapters:
return result
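The inheritance change discussed above means that isinstance checks against bitsandbytes classes no longer match PEFT's layers directly. As a minimal illustrative sketch (not part of this PR's diff; the helper name is hypothetical, and bitsandbytes must be installed), downstream code can unwrap the adapter layer via get_base_layer() before checking the quantized type:

```python
# Illustrative sketch, not part of the diff: after this refactor the PEFT layer
# wraps the quantized module instead of inheriting from it, so downstream
# isinstance checks should unwrap the adapter layer first.
import bitsandbytes as bnb

from peft.tuners.tuners_utils import BaseTunerLayer


def is_8bit_linear(module) -> bool:
    if isinstance(module, BaseTunerLayer):
        # get_base_layer() returns the wrapped, original (possibly quantized) layer.
        module = module.get_base_layer()
    return isinstance(module, bnb.nn.Linear8bitLt)
```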
1 change: 0 additions & 1 deletion src/peft/tuners/adalora/gptq.py
@@ -35,7 +35,6 @@ def __init__(
self.weight = quant_linear_module.qweight
init_lora_weights = kwargs.pop("init_lora_weights", True)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.set_adapter(adapter_name)
Contributor:
Why is this not needed anymore?

Member Author:
set_adapter is already called inside of update_layer just above, so it was in fact being called twice in a row. (A stripped-down sketch of this follows this file's diff.)


def forward(self, x: torch.Tensor) -> torch.Tensor:
result = self.quant_linear_module(x)
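To make the redundancy discussed above concrete, here is a stripped-down toy sketch (illustrative only; the class and names are hypothetical, not PEFT code): the constructor's explicit activation repeats what update_layer already does at its end.

```python
# Toy illustration only (hypothetical class, not PEFT code): update_layer already
# activates the adapter, so an explicit set_adapter call right after it is a
# redundant repeat -- which is what this diff removes.
class ToyAdapterLayer:
    def __init__(self, adapter_name: str) -> None:
        self.active_adapter = None
        self.update_layer(adapter_name)
        # self.set_adapter(adapter_name)  # removed: update_layer did this already

    def update_layer(self, adapter_name: str) -> None:
        # ... create adapter parameters here ...
        self.set_adapter(adapter_name)

    def set_adapter(self, adapter_name: str) -> None:
        self.active_adapter = adapter_name


assert ToyAdapterLayer("default").active_adapter == "default"
```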
74 changes: 42 additions & 32 deletions src/peft/tuners/adalora/layer.py
@@ -14,10 +14,11 @@
# limitations under the License.

import warnings
from typing import Any

import torch
import torch.nn.functional as F
from torch import nn
from transformers.pytorch_utils import Conv1D

from peft.tuners.lora import LoraLayer
from peft.utils import transpose
@@ -29,17 +30,25 @@ class AdaLoraLayer(LoraLayer):
adapter_layer_names = ("lora_A", "lora_B", "lora_E", "lora_embedding_A", "lora_embedding_B")
# other_param_names is defined in LoraLayer

def __init__(
self,
in_features: int,
out_features: int,
):
super().__init__(in_features, out_features)
def __init__(self, base_layer: nn.Module) -> None:
super().__init__(base_layer)
self.lora_E = nn.ParameterDict({})
self.lora_A = nn.ParameterDict({})
self.lora_B = nn.ParameterDict({})
self.ranknum = nn.ParameterDict({})

base_layer = self.get_base_layer()
if isinstance(base_layer, nn.Linear):
in_features, out_features = base_layer.in_features, base_layer.out_features
elif isinstance(base_layer, Conv1D):
in_features, out_features = (
base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
)
else:
raise ValueError(f"Unsupported layer type {type(base_layer)}")
self.in_features = in_features
self.out_features = out_features

def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
self.r[adapter_name] = r
self.lora_alpha[adapter_name] = lora_alpha
@@ -63,7 +72,7 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig
self.scaling[adapter_name] = lora_alpha if lora_alpha > 0 else float(r)
if init_lora_weights:
self.reset_lora_parameters(adapter_name)
self.to(self.weight.device)
self.to(self.get_base_layer().weight.device)
self.set_adapter(self.active_adapters)

def reset_lora_parameters(self, adapter_name):
@@ -73,32 +82,31 @@ def reset_lora_parameters(self, adapter_name):
nn.init.normal_(self.lora_B[adapter_name], mean=0.0, std=0.02)


class SVDLinear(nn.Linear, AdaLoraLayer):
class SVDLinear(nn.Module, AdaLoraLayer):
# SVD-based adaptation by a dense layer
def __init__(
self,
base_layer: nn.Module,
Contributor:
Same comment as above about backward compatibility: I would slowly deprecate the old signatures instead of changing the signature of these public classes. (A sketch of the new calling convention follows this file's diff.)

adapter_name: str,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
fan_in_fan_out: bool = False,
init_lora_weights: bool = True,
**kwargs,
) -> None:
init_lora_weights = kwargs.pop("init_lora_weights", True)
nn.Linear.__init__(self, in_features, out_features, **kwargs)
AdaLoraLayer.__init__(self, in_features=in_features, out_features=out_features)
super().__init__()
AdaLoraLayer.__init__(self, base_layer)
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.get_base_layer().weight.requires_grad = False

self.fan_in_fan_out = fan_in_fan_out
if fan_in_fan_out:
self.weight.data = self.weight.data.T

nn.Linear.reset_parameters(self)
self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.set_adapter(adapter_name)

@property
def weight(self) -> torch.Tensor:
return self.get_base_layer().weight

def merge(self, safe_merge: bool = False) -> None:
"""
@@ -116,21 +124,22 @@ def merge(self, safe_merge: bool = False) -> None:
f"You are now additionally merging {','.join(self.active_adapters)}."
)
for active_adapter in self.active_adapters:
base_layer = self.get_base_layer()
if active_adapter in self.lora_A.keys():
if safe_merge:
# Note that safe_merge will be slower than the normal merge
# because of the copy operation.
orig_weights = self.weight.data.clone()
orig_weights = base_layer.weight.data.clone()
orig_weights += self.get_delta_weight(active_adapter)

if not torch.isfinite(orig_weights).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)

self.weight.data = orig_weights
base_layer.weight.data = orig_weights
else:
self.weight.data += self.get_delta_weight(active_adapter)
base_layer.weight.data += self.get_delta_weight(active_adapter)
self.merged_adapters.append(active_adapter)

def unmerge(self) -> None:
@@ -140,7 +149,7 @@ def unmerge(self) -> None:
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.lora_A.keys():
self.weight.data -= self.get_delta_weight(active_adapter)
self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)

def get_delta_weight(self, adapter) -> torch.Tensor:
return (
@@ -149,19 +158,16 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
/ (self.ranknum[adapter] + 1e-5)
)

def _linear(self, input: torch.Tensor) -> torch.Tensor:
return F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
Contributor:
Nice!

# TODO: SVDLinear does not convert dtype, unlike lora linear, is that correct?
if self.disable_adapters:
if self.merged:
self.unmerge()
result = self._linear(x)
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self._linear(x)
result = self.base_layer(x, *args, **kwargs)
else:
result = self._linear(x)
result = self.base_layer(x, *args, **kwargs)
for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
continue
@@ -176,8 +182,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

return result

def __repr__(self) -> str:
rep = super().__repr__()
return "adalora." + rep
Comment on lines +180 to +182

Contributor:
Can you maybe add a few comments here elaborating on this point from the PR description: "During debugging, I got very annoyed with the fact that the reprs of adapter layers and normal PyTorch layers are hard to distinguish, e.g. the type is just 'Linear'. Now, for adapter layers, it is prefixed by the adapter type, e.g. 'lora.Linear'."

Member Author:
I see, yes, I could. The only issue is that we now have this method more than a dozen times, so I would have to add the same comment very often. If you still think it's good to have, or have a better idea, let me know and I'll change it! (An illustrative repr example is shown after this file's diff.)



class RankAllocator(object):
class RankAllocator:
"""
The RankAllocator for AdaLoraModel. Paper: https://openreview.net/pdf?id=lq62uWRJjiY

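Tying together the two threads above (the changed constructor signature and the new __repr__), here is a minimal sketch, illustrative only: the variable names are hypothetical and it assumes SVDLinear is importable from peft.tuners.adalora.layer as defined in this diff.

```python
# Illustrative only, not part of the diff. The layer now wraps an existing
# module (base_layer comes first, then adapter_name), and its repr is prefixed
# with the adapter type so it stands out from plain PyTorch layers.
import torch

from peft.tuners.adalora.layer import SVDLinear  # assumed import path

base = torch.nn.Linear(16, 16)

# Old (pre-PR) call, which no longer works:
#   SVDLinear("default", 16, 16, r=4)
layer = SVDLinear(base, adapter_name="default", r=4)

print(repr(layer))  # begins with "adalora.SVDLinear(" rather than a bare "Linear("
```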
35 changes: 17 additions & 18 deletions src/peft/tuners/adalora/model.py
@@ -20,6 +20,7 @@

from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.tuners.lora import LoraConfig, LoraModel
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import (
TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
_freeze_adapter,
@@ -121,7 +122,7 @@ def _create_and_replace(
loaded_in_4bit = optional_kwargs.get("loaded_in_4bit", False)
if (loaded_in_8bit or loaded_in_4bit) and not is_bnb_available():
raise ImportError(
"To use Lora with 8-bit quantization, please install the `bitsandbytes` package. "
"To use AdaLora with 8-bit quantization, please install the `bitsandbytes` package. "
"You can install it with `pip install bitsandbytes`."
)
kwargs = {
@@ -138,7 +139,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
if quantization_config is not None:
kwargs["gptq_quantization_config"] = quantization_config

# If it is not a LoraLayer, create a new module, else update it with new adapters
# If it is not an AdaLoraLayer, create a new module, else update it with new adapters
if not isinstance(target, AdaLoraLayer):
new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
if adapter_name != self.active_adapter:
@@ -159,11 +160,16 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)

bias = target.bias is not None
loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)

if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target
bias = target_base_layer.bias is not None

if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt):
kwargs.update(
{
"has_fp16_weights": target.state.has_fp16_weights,
@@ -172,8 +178,8 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
"index": target.index,
}
)
new_module = SVDLinear8bitLt(adapter_name, target.in_features, target.out_features, bias=bias, **kwargs)
elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit):
new_module = SVDLinear8bitLt(target, adapter_name, **kwargs)
elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit):
fourbit_kwargs = kwargs.copy()
fourbit_kwargs.update(
{
Expand All @@ -182,25 +188,18 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
"quant_type": target.weight.quant_type,
}
)
new_module = SVDLinear4bit(
adapter_name, target.in_features, target.out_features, bias=bias, **fourbit_kwargs
)
new_module = SVDLinear4bit(target, adapter_name, **fourbit_kwargs)
elif AutoGPTQQuantLinear is not None and isinstance(target, AutoGPTQQuantLinear):
new_module = SVDQuantLinear(adapter_name, target, **kwargs)
target.weight = target.qweight
new_module = SVDQuantLinear(target, adapter_name, **kwargs)
else:
if isinstance(target, torch.nn.Linear):
in_features, out_features = target.in_features, target.out_features
if isinstance(target_base_layer, torch.nn.Linear):
if kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
"Setting fan_in_fan_out to False."
)
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
elif isinstance(target, Conv1D):
in_features, out_features = (
target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
)
elif isinstance(target_base_layer, Conv1D):
if not kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to False but the target module is `Conv1D`. "
@@ -212,7 +211,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
f"Target module {target} is not supported. "
f"Currently, only `torch.nn.Linear` and `Conv1D` are supported."
)
new_module = SVDLinear(adapter_name, in_features, out_features, bias=bias, **kwargs)
new_module = SVDLinear(target, adapter_name, bias=bias, **kwargs)

return new_module
