Refactor base layer pattern #1106
```diff
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import bitsandbytes as bnb
 import torch

 from peft.import_utils import is_bnb_4bit_available, is_bnb_available
```
```diff
@@ -23,38 +22,28 @@
 if is_bnb_available():

-    class SVDLinear8bitLt(bnb.nn.Linear8bitLt, AdaLoraLayer):
+    class SVDLinear8bitLt(torch.nn.Module, AdaLoraLayer):
         # Low-rank matrix for SVD-based adaptation
         def __init__(
             self,
-            adapter_name,
-            in_features,
-            out_features,
+            base_layer: torch.nn.Module,
+            adapter_name: str,
             r: int = 0,
             lora_alpha: int = 1,
             lora_dropout: float = 0.0,
+            init_lora_weights: bool = True,
             **kwargs,
         ) -> None:
-            bnb.nn.Linear8bitLt.__init__(
-                self,
-                in_features,
-                out_features,
-                bias=kwargs.get("bias", True),
-                has_fp16_weights=kwargs.get("has_fp16_weights", True),
-                memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
-                threshold=kwargs.get("threshold", 0.0),
-                index=kwargs.get("index", None),
-            )
-            AdaLoraLayer.__init__(self, in_features=in_features, out_features=out_features)
+            super().__init__()
+            AdaLoraLayer.__init__(self, base_layer)
             # Freezing the pre-trained weight matrix
-            self.weight.requires_grad = False
+            self.get_base_layer().weight.requires_grad = False

-            init_lora_weights = kwargs.pop("init_lora_weights", True)
             self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-            self.set_adapter(adapter_name)

         def forward(self, x: torch.Tensor) -> torch.Tensor:
-            result = super().forward(x)
+            # note: no check for self.merged because merging is not supported (yet)
+            result = self.base_layer.forward(x)

             if self.disable_adapters:
                 return result
```
```diff
@@ -85,37 +74,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

 if is_bnb_4bit_available():

-    class SVDLinear4bit(bnb.nn.Linear4bit, AdaLoraLayer):
+    class SVDLinear4bit(torch.nn.Module, AdaLoraLayer):
```
Contributor:
This might be breaking, no? Some setups might rely on checks such as isinstance checks against bnb.nn.Linear8bitLt.

Member (Author):
Technically true, although I'm not sure how much we can consider the inheritance structure of our classes to be "public API". I think we can assume that very few users do this type of check (if any), and I hope those few expert users can quickly figure out what is wrong. If you are aware of anyone doing this, let me know and I can prepare something like a migration guide.

Contributor:
OK, sounds great. I think this is too much of an edge case; let's just be careful with the release notes and maybe include this PR in the next minor release (0.7.0).
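To illustrate the concern, here is a minimal sketch (hypothetical downstream code, not part of the PR) of the kind of type check that changes behavior once the adapter layer wraps the bitsandbytes module instead of subclassing it; it assumes bitsandbytes is installed and uses the get_base_layer() accessor introduced in this PR:

```python
import bitsandbytes as bnb


def is_8bit_layer(module) -> bool:
    # Before this PR, SVDLinear8bitLt *was* a bnb.nn.Linear8bitLt, so this matched.
    # After the refactor it is a plain torch.nn.Module wrapper, so this returns False.
    return isinstance(module, bnb.nn.Linear8bitLt)


def is_8bit_layer_after_refactor(module) -> bool:
    # Inspect the wrapped layer instead, via the new get_base_layer() accessor.
    base = module.get_base_layer() if hasattr(module, "get_base_layer") else module
    return isinstance(base, bnb.nn.Linear8bitLt)
```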
```diff
         # Low-rank matrix for SVD-based adaptation
         def __init__(
             self,
-            adapter_name,
-            in_features,
-            out_features,
+            base_layer: torch.nn.Module,
```
Contributor:
For backward compatibility we should maybe consider keeping the same argument order, ignoring the unused kwargs, and putting the new arguments at the end. This might be too much of an edge case, but we never know. What do you think?

Member (Author):
Similar comment as above, though here we could change the order of base_layer and adapter_name.

Contributor:
Agreed!
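For reference, a rough, self-contained sketch of the kind of backward-compatible constructor the reviewer is describing (hypothetical class name and parameter handling, not what the PR implements): the old positional arguments are kept but ignored, and the new base_layer argument is appended at the end:

```python
import warnings
from typing import Optional

import torch.nn as nn


class BackCompatSVDLinear(nn.Module):
    """Illustration only: keep the old call signature alive during a deprecation period."""

    def __init__(
        self,
        adapter_name: str,
        in_features: Optional[int] = None,   # old positional args, now unused
        out_features: Optional[int] = None,
        r: int = 8,
        *,
        base_layer: Optional[nn.Module] = None,  # new argument added at the end
        **kwargs,
    ) -> None:
        super().__init__()
        if in_features is not None or out_features is not None:
            # Old-style calls still work, but the shape now comes from the base layer.
            warnings.warn(
                "in_features/out_features are ignored; the shape is taken from base_layer.",
                DeprecationWarning,
            )
        if base_layer is None:
            raise ValueError("base_layer must be provided")
        self.adapter_name = adapter_name
        self.base_layer = base_layer


# Old-style positional call keeps working; the shape arguments are simply ignored:
layer = BackCompatSVDLinear("default", 16, 8, base_layer=nn.Linear(16, 8))
```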
```diff
+            adapter_name: str,
             r: int = 0,
             lora_alpha: int = 1,
             lora_dropout: float = 0.0,
+            init_lora_weights: bool = True,
             **kwargs,
         ) -> None:
-            bnb.nn.Linear4bit.__init__(
-                self,
-                in_features,
-                out_features,
-                bias=kwargs.get("bias", True),
-                compute_dtype=kwargs.get("compute_dtype", torch.float32),
-                compress_statistics=kwargs.get("compress_statistics", True),
-                quant_type=kwargs.get("quant_type", "nf4"),
-            )
-            AdaLoraLayer.__init__(self, in_features=in_features, out_features=out_features)
+            super().__init__()
+            AdaLoraLayer.__init__(self, base_layer)
             # Freezing the pre-trained weight matrix
-            self.weight.requires_grad = False
+            self.get_base_layer().weight.requires_grad = False

-            init_lora_weights = kwargs.pop("init_lora_weights", True)
             self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-            self.set_adapter(adapter_name)

         def forward(self, x: torch.Tensor) -> torch.Tensor:
-            result = super().forward(x)
+            # note: no check for self.merged because merging is not supported (yet)
+            result = self.base_layer.forward(x)

             if self.disable_adapters:
                 return result
```
Next file:
```diff
@@ -35,7 +35,6 @@ def __init__(
         self.weight = quant_linear_module.qweight
         init_lora_weights = kwargs.pop("init_lora_weights", True)
         self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-        self.set_adapter(adapter_name)
```
Contributor:
Why is this not needed anymore?

Member (Author):
The adapter is already activated through update_layer, which ends with self.set_adapter(self.active_adapters) (see the AdaLoraLayer diff below), so the explicit call here was redundant.
```diff

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         result = self.quant_linear_module(x)
```
Next file:
```diff
@@ -14,10 +14,11 @@
 # limitations under the License.

 import warnings
+from typing import Any

 import torch
-import torch.nn.functional as F
 from torch import nn
+from transformers.pytorch_utils import Conv1D

 from peft.tuners.lora import LoraLayer
 from peft.utils import transpose
```
```diff
@@ -29,17 +30,25 @@ class AdaLoraLayer(LoraLayer):
     adapter_layer_names = ("lora_A", "lora_B", "lora_E", "lora_embedding_A", "lora_embedding_B")
     # other_param_names is defined in LoraLayer

-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-    ):
-        super().__init__(in_features, out_features)
+    def __init__(self, base_layer: nn.Module) -> None:
+        super().__init__(base_layer)
         self.lora_E = nn.ParameterDict({})
         self.lora_A = nn.ParameterDict({})
         self.lora_B = nn.ParameterDict({})
         self.ranknum = nn.ParameterDict({})

+        base_layer = self.get_base_layer()
+        if isinstance(base_layer, nn.Linear):
+            in_features, out_features = base_layer.in_features, base_layer.out_features
+        elif isinstance(base_layer, Conv1D):
+            in_features, out_features = (
+                base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
+            )
+        else:
+            raise ValueError(f"Unsupported layer type {type(base_layer)}")
+        self.in_features = in_features
+        self.out_features = out_features
+
     def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
         self.r[adapter_name] = r
         self.lora_alpha[adapter_name] = lora_alpha
```
```diff
@@ -63,7 +72,7 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
         self.scaling[adapter_name] = lora_alpha if lora_alpha > 0 else float(r)
         if init_lora_weights:
             self.reset_lora_parameters(adapter_name)
-        self.to(self.weight.device)
+        self.to(self.get_base_layer().weight.device)
         self.set_adapter(self.active_adapters)

     def reset_lora_parameters(self, adapter_name):
```
```diff
@@ -73,32 +82,31 @@ def reset_lora_parameters(self, adapter_name):
             nn.init.normal_(self.lora_B[adapter_name], mean=0.0, std=0.02)


-class SVDLinear(nn.Linear, AdaLoraLayer):
+class SVDLinear(nn.Module, AdaLoraLayer):
     # SVD-based adaptation by a dense layer
     def __init__(
         self,
+        base_layer: nn.Module,
```
Contributor:
Same comment as above about backward compatibility: I would slowly deprecate them instead of changing the signature of these public classes.
```diff
         adapter_name: str,
-        in_features: int,
-        out_features: int,
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
         fan_in_fan_out: bool = False,
+        init_lora_weights: bool = True,
         **kwargs,
     ) -> None:
-        init_lora_weights = kwargs.pop("init_lora_weights", True)
-        nn.Linear.__init__(self, in_features, out_features, **kwargs)
-        AdaLoraLayer.__init__(self, in_features=in_features, out_features=out_features)
+        super().__init__()
+        AdaLoraLayer.__init__(self, base_layer)
         # Freezing the pre-trained weight matrix
-        self.weight.requires_grad = False
+        self.get_base_layer().weight.requires_grad = False

         self.fan_in_fan_out = fan_in_fan_out
-        if fan_in_fan_out:
-            self.weight.data = self.weight.data.T

-        nn.Linear.reset_parameters(self)
+        self._active_adapter = adapter_name
         self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-        self.set_adapter(adapter_name)

+    @property
+    def weight(self) -> torch.Tensor:
```
```diff
+        return self.get_base_layer().weight

     def merge(self, safe_merge: bool = False) -> None:
         """
```
```diff
@@ -116,21 +124,22 @@ def merge(self, safe_merge: bool = False) -> None:
                 f"You are now additionally merging {','.join(self.active_adapters)}."
             )
         for active_adapter in self.active_adapters:
+            base_layer = self.get_base_layer()
             if active_adapter in self.lora_A.keys():
                 if safe_merge:
                     # Note that safe_merge will be slower than the normal merge
                     # because of the copy operation.
-                    orig_weights = self.weight.data.clone()
+                    orig_weights = base_layer.weight.data.clone()
                     orig_weights += self.get_delta_weight(active_adapter)

                     if not torch.isfinite(orig_weights).all():
                         raise ValueError(
                             f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                         )

-                    self.weight.data = orig_weights
+                    base_layer.weight.data = orig_weights
                 else:
-                    self.weight.data += self.get_delta_weight(active_adapter)
+                    base_layer.weight.data += self.get_delta_weight(active_adapter)
                 self.merged_adapters.append(active_adapter)

     def unmerge(self) -> None:
```
```diff
@@ -140,7 +149,7 @@ def unmerge(self) -> None:
         while len(self.merged_adapters) > 0:
             active_adapter = self.merged_adapters.pop()
             if active_adapter in self.lora_A.keys():
-                self.weight.data -= self.get_delta_weight(active_adapter)
+                self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)

     def get_delta_weight(self, adapter) -> torch.Tensor:
         return (
```
|
```diff
@@ -149,19 +158,16 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
             / (self.ranknum[adapter] + 1e-5)
         )

-    def _linear(self, input: torch.Tensor) -> torch.Tensor:
-        return F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
```
Contributor:
Nice!
```diff
+        # TODO: SVDLinear does not convert dtype, unlike lora linear, is that correct?
         if self.disable_adapters:
             if self.merged:
                 self.unmerge()
-            result = self._linear(x)
+            result = self.base_layer(x, *args, **kwargs)
         elif self.merged:
-            result = self._linear(x)
+            result = self.base_layer(x, *args, **kwargs)
         else:
-            result = self._linear(x)
+            result = self.base_layer(x, *args, **kwargs)
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.lora_A.keys():
                     continue
```
```diff
@@ -176,8 +182,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

         return result

+    def __repr__(self) -> str:
+        rep = super().__repr__()
+        return "adalora." + rep
+
```
Comment on lines +180 to +182

Contributor:
Can you maybe add a few comments here elaborating on why this is needed?

Member (Author):
I see, yes, I could. The only issue is that we now have this method more than a dozen times, so I would have to add the same comment very often. If you still think it's good to have, or have a better idea, let me know and I'll change it!
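For readers of the diff, a small self-contained sketch (toy class name, not the actual PEFT code) of what the __repr__ override accomplishes: the printed module tree gets an "adalora." prefix, which makes the wrapped layers easy to spot when printing a model:

```python
import torch.nn as nn


class SVDLinearDemo(nn.Module):
    """Toy stand-in for the SVDLinear wrapper, for illustration only."""

    def __init__(self, base_layer: nn.Module) -> None:
        super().__init__()
        self.base_layer = base_layer

    def __repr__(self) -> str:
        # Same trick as in the PR: prefix the default nn.Module repr.
        rep = super().__repr__()
        return "adalora." + rep


print(SVDLinearDemo(nn.Linear(16, 8)))
# adalora.SVDLinearDemo(
#   (base_layer): Linear(in_features=16, out_features=8, bias=True)
# )
```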
```diff


-class RankAllocator(object):
+class RankAllocator:
     """
     The RankAllocator for AdaLoraModel. Paper: https://openreview.net/pdf?id=lq62uWRJjiY
```
Why has this been removed?

It's not being used anywhere. Removing this is not related to this PR specifically.
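Taken together, the change in this PR can be summarized with a short self-contained sketch (simplified, hypothetical names; not the actual PEFT implementation): instead of inheriting from the concrete linear class, the adapter layer is a plain nn.Module that stores the frozen base layer, delegates the base computation to it, and exposes the original weight through get_base_layer() and a weight property:

```python
import torch
import torch.nn as nn


class WrappedAdapterLinear(nn.Module):
    """Simplified illustration of the base-layer composition pattern."""

    def __init__(self, base_layer: nn.Module, r: int = 8) -> None:
        super().__init__()
        self.base_layer = base_layer                   # wrap instead of inherit
        self.base_layer.weight.requires_grad = False   # freeze the pre-trained weight
        in_features, out_features = base_layer.in_features, base_layer.out_features
        self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.02)
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))  # zero init: no-op at start

    def get_base_layer(self) -> nn.Module:
        return self.base_layer

    @property
    def weight(self) -> torch.Tensor:
        # Keep `module.weight` working for code that expects it on the wrapper.
        return self.get_base_layer().weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = self.base_layer(x)  # delegate to the wrapped layer
        return result + x @ self.lora_A.T @ self.lora_B.T


wrapped = WrappedAdapterLinear(nn.Linear(16, 8))
print(wrapped(torch.randn(2, 16)).shape)  # torch.Size([2, 8])
```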