diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index dd4c94b947..b357d47dc1 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -27,10 +27,3 @@ from .prefix_tuning import PrefixEncoder, PrefixTuningConfig from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit from .multitask_prompt_tuning import MultitaskPromptEmbedding, MultitaskPromptTuningConfig, MultitaskPromptTuningInit - -# Mapping of tuners that support direct plugging -TUNERS_MAPPING = { - "LORA": LoraModel, - "IA3": IA3Model, - "ADALORA": AdaLoraModel, -} diff --git a/src/peft/tuners/adalora/bnb.py b/src/peft/tuners/adalora/bnb.py index 3ccfd91b2b..a37745569a 100644 --- a/src/peft/tuners/adalora/bnb.py +++ b/src/peft/tuners/adalora/bnb.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import bitsandbytes as bnb +from typing import Any + import torch from peft.import_utils import is_bnb_4bit_available, is_bnb_available @@ -23,38 +24,28 @@ if is_bnb_available(): - class SVDLinear8bitLt(bnb.nn.Linear8bitLt, AdaLoraLayer): + class SVDLinear8bitLt(torch.nn.Module, AdaLoraLayer): # Low-rank matrix for SVD-based adaptation def __init__( self, - adapter_name, - in_features, - out_features, + base_layer: torch.nn.Module, + adapter_name: str, r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, + init_lora_weights: bool = True, **kwargs, ) -> None: - bnb.nn.Linear8bitLt.__init__( - self, - in_features, - out_features, - bias=kwargs.get("bias", True), - has_fp16_weights=kwargs.get("has_fp16_weights", True), - memory_efficient_backward=kwargs.get("memory_efficient_backward", False), - threshold=kwargs.get("threshold", 0.0), - index=kwargs.get("index", None), - ) - AdaLoraLayer.__init__(self, in_features=in_features, out_features=out_features) + super().__init__() + AdaLoraLayer.__init__(self, base_layer) # Freezing the pre-trained weight matrix - self.weight.requires_grad = False + self.get_base_layer().weight.requires_grad = False - init_lora_weights = kwargs.pop("init_lora_weights", True) self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) - self.set_adapter(adapter_name) def forward(self, x: torch.Tensor) -> torch.Tensor: - result = super().forward(x) + # note: no check for self.merged because merging is not supported (yet) + result = self.base_layer(x) if self.disable_adapters: return result @@ -82,40 +73,35 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: result += output return result + def __repr__(self) -> str: + rep = super().__repr__() + return "adalora." 
+ rep + if is_bnb_4bit_available(): - class SVDLinear4bit(bnb.nn.Linear4bit, AdaLoraLayer): + class SVDLinear4bit(torch.nn.Module, AdaLoraLayer): # Low-rank matrix for SVD-based adaptation def __init__( self, - adapter_name, - in_features, - out_features, + base_layer: torch.nn.Module, + adapter_name: str, r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, + init_lora_weights: bool = True, **kwargs, ) -> None: - bnb.nn.Linear4bit.__init__( - self, - in_features, - out_features, - bias=kwargs.get("bias", True), - compute_dtype=kwargs.get("compute_dtype", torch.float32), - compress_statistics=kwargs.get("compress_statistics", True), - quant_type=kwargs.get("quant_type", "nf4"), - ) - AdaLoraLayer.__init__(self, in_features=in_features, out_features=out_features) + super().__init__() + AdaLoraLayer.__init__(self, base_layer) # Freezing the pre-trained weight matrix - self.weight.requires_grad = False + self.get_base_layer().weight.requires_grad = False - init_lora_weights = kwargs.pop("init_lora_weights", True) self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) - self.set_adapter(adapter_name) - def forward(self, x: torch.Tensor) -> torch.Tensor: - result = super().forward(x) + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + # note: no check for self.merged because merging is not supported (yet) + result = self.base_layer(x, *args, **kwargs) if self.disable_adapters: return result @@ -151,3 +137,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: output = output * scaling / ranknum result += output return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "adalora." + rep diff --git a/src/peft/tuners/adalora/gptq.py b/src/peft/tuners/adalora/gptq.py index 92de32ac15..1c14ea9c44 100644 --- a/src/peft/tuners/adalora/gptq.py +++ b/src/peft/tuners/adalora/gptq.py @@ -20,22 +20,21 @@ class SVDQuantLinear(torch.nn.Module, AdaLoraLayer): def __init__( self, + base_layer, adapter_name, - quant_linear_module, r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, + init_lora_weights: bool = True, **kwargs, ) -> None: - torch.nn.Module.__init__(self) - AdaLoraLayer.__init__( - self, in_features=quant_linear_module.infeatures, out_features=quant_linear_module.outfeatures - ) - self.quant_linear_module = quant_linear_module - self.weight = quant_linear_module.qweight - init_lora_weights = kwargs.pop("init_lora_weights", True) + super().__init__() + AdaLoraLayer.__init__(self, base_layer) + + # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter + # for backwards compatibility + self.quant_linear_module = base_layer self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) - self.set_adapter(adapter_name) def forward(self, x: torch.Tensor) -> torch.Tensor: result = self.quant_linear_module(x) @@ -67,3 +66,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: output = output.to(expected_dtype) result += output return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "adalora." + rep diff --git a/src/peft/tuners/adalora/layer.py b/src/peft/tuners/adalora/layer.py index d9fbf903c9..b4a98de039 100644 --- a/src/peft/tuners/adalora/layer.py +++ b/src/peft/tuners/adalora/layer.py @@ -14,10 +14,9 @@ # limitations under the License. 
import warnings -from typing import List, Optional +from typing import Any, List, Optional import torch -import torch.nn.functional as F from torch import nn from peft.tuners.lora import LoraLayer @@ -30,12 +29,8 @@ class AdaLoraLayer(LoraLayer): adapter_layer_names = ("lora_A", "lora_B", "lora_E", "lora_embedding_A", "lora_embedding_B") # other_param_names is defined in LoraLayer - def __init__( - self, - in_features: int, - out_features: int, - ): - super().__init__(in_features, out_features) + def __init__(self, base_layer: nn.Module) -> None: + super().__init__(base_layer) self.lora_E = nn.ParameterDict({}) self.lora_A = nn.ParameterDict({}) self.lora_B = nn.ParameterDict({}) @@ -64,7 +59,12 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig self.scaling[adapter_name] = lora_alpha if lora_alpha > 0 else float(r) if init_lora_weights: self.reset_lora_parameters(adapter_name) - self.to(self.weight.device) + + if hasattr(self.get_base_layer(), "qweight"): + # QuantLinear + self.to(self.get_base_layer().qweight.device) + else: + self.to(self.get_base_layer().weight.device) self.set_adapter(self.active_adapters) def reset_lora_parameters(self, adapter_name): @@ -74,32 +74,27 @@ def reset_lora_parameters(self, adapter_name): nn.init.normal_(self.lora_B[adapter_name], mean=0.0, std=0.02) -class SVDLinear(nn.Linear, AdaLoraLayer): +class SVDLinear(nn.Module, AdaLoraLayer): # SVD-based adaptation by a dense layer def __init__( self, + base_layer: nn.Module, adapter_name: str, - in_features: int, - out_features: int, r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, fan_in_fan_out: bool = False, + init_lora_weights: bool = True, **kwargs, ) -> None: - init_lora_weights = kwargs.pop("init_lora_weights", True) - nn.Linear.__init__(self, in_features, out_features, **kwargs) - AdaLoraLayer.__init__(self, in_features=in_features, out_features=out_features) + super().__init__() + AdaLoraLayer.__init__(self, base_layer) # Freezing the pre-trained weight matrix - self.weight.requires_grad = False + self.get_base_layer().weight.requires_grad = False self.fan_in_fan_out = fan_in_fan_out - if fan_in_fan_out: - self.weight.data = self.weight.data.T - - nn.Linear.reset_parameters(self) + self._active_adapter = adapter_name self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) - self.set_adapter(adapter_name) def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: """ @@ -119,15 +114,17 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N f"Already following adapters were merged {','.join(self.merged_adapters)}. " f"You are now additionally merging {','.join(self.active_adapters)}." ) + if adapter_names is None: adapter_names = self.active_adapters for active_adapter in adapter_names: + base_layer = self.get_base_layer() if active_adapter in self.lora_A.keys(): if safe_merge: # Note that safe_merge will be slower than the normal merge # because of the copy operation. - orig_weights = self.weight.data.clone() + orig_weights = base_layer.weight.data.clone() orig_weights += self.get_delta_weight(active_adapter) if not torch.isfinite(orig_weights).all(): @@ -135,9 +132,9 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N f"NaNs detected in the merged weights. 
The adapter {active_adapter} seems to be broken" ) - self.weight.data = orig_weights + base_layer.weight.data = orig_weights else: - self.weight.data += self.get_delta_weight(active_adapter) + base_layer.weight.data += self.get_delta_weight(active_adapter) self.merged_adapters.append(active_adapter) def unmerge(self) -> None: @@ -147,7 +144,7 @@ def unmerge(self) -> None: while len(self.merged_adapters) > 0: active_adapter = self.merged_adapters.pop() if active_adapter in self.lora_A.keys(): - self.weight.data -= self.get_delta_weight(active_adapter) + self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) def get_delta_weight(self, adapter) -> torch.Tensor: return ( @@ -156,19 +153,16 @@ def get_delta_weight(self, adapter) -> torch.Tensor: / (self.ranknum[adapter] + 1e-5) ) - def _linear(self, input: torch.Tensor) -> torch.Tensor: - return F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) - - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: # TODO: SVDLinear does not convert dtype, unlike lora linear, is that correct? if self.disable_adapters: if self.merged: self.unmerge() - result = self._linear(x) + result = self.base_layer(x, *args, **kwargs) elif self.merged: - result = self._linear(x) + result = self.base_layer(x, *args, **kwargs) else: - result = self._linear(x) + result = self.base_layer(x, *args, **kwargs) for active_adapter in self.active_adapters: if active_adapter not in self.lora_A.keys(): continue @@ -183,8 +177,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return result + def __repr__(self) -> str: + rep = super().__repr__() + return "adalora." + rep + -class RankAllocator(object): +class RankAllocator: """ The RankAllocator for AdaLoraModel. Paper: https://openreview.net/pdf?id=lq62uWRJjiY diff --git a/src/peft/tuners/adalora/model.py b/src/peft/tuners/adalora/model.py index a863acce31..71f2ed7579 100644 --- a/src/peft/tuners/adalora/model.py +++ b/src/peft/tuners/adalora/model.py @@ -20,6 +20,7 @@ from peft.import_utils import is_bnb_4bit_available, is_bnb_available from peft.tuners.lora import LoraConfig, LoraModel +from peft.tuners.tuners_utils import BaseTunerLayer from peft.utils import ( TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, _freeze_adapter, @@ -67,6 +68,8 @@ class AdaLoraModel(LoraModel): - **peft_config** ([`AdaLoraConfig`]): The configuration of the AdaLora model. """ + # Note: don't redefine prefix here, it should be inherited from LoraModel + def __init__(self, model, config, adapter_name): super().__init__(model, config, adapter_name) @@ -121,7 +124,7 @@ def _create_and_replace( loaded_in_4bit = optional_kwargs.get("loaded_in_4bit", False) if (loaded_in_8bit or loaded_in_4bit) and not is_bnb_available(): raise ImportError( - "To use Lora with 8-bit quantization, please install the `bitsandbytes` package. " + "To use AdaLora with 8-bit quantization, please install the `bitsandbytes` package. " "You can install it with `pip install bitsandbytes`." 
) kwargs = { @@ -138,7 +141,7 @@ def _create_and_replace( if quantization_config is not None: kwargs["gptq_quantization_config"] = quantization_config - # If it is not a LoraLayer, create a new module, else update it with new adapters + # If it is not an AdaLoraLayer, create a new module, else update it with new adapters if not isinstance(target, AdaLoraLayer): new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs) if adapter_name != self.active_adapter: @@ -159,11 +162,15 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): gptq_quantization_config = kwargs.get("gptq_quantization_config", None) AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config) - bias = target.bias is not None loaded_in_8bit = kwargs.pop("loaded_in_8bit", False) loaded_in_4bit = kwargs.pop("loaded_in_4bit", False) - if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt): + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt): kwargs.update( { "has_fp16_weights": target.state.has_fp16_weights, @@ -172,8 +179,8 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): "index": target.index, } ) - new_module = SVDLinear8bitLt(adapter_name, target.in_features, target.out_features, bias=bias, **kwargs) - elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit): + new_module = SVDLinear8bitLt(target, adapter_name, **kwargs) + elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit): fourbit_kwargs = kwargs.copy() fourbit_kwargs.update( { @@ -182,25 +189,18 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): "quant_type": target.weight.quant_type, } ) - new_module = SVDLinear4bit( - adapter_name, target.in_features, target.out_features, bias=bias, **fourbit_kwargs - ) + new_module = SVDLinear4bit(target, adapter_name, **fourbit_kwargs) elif AutoGPTQQuantLinear is not None and isinstance(target, AutoGPTQQuantLinear): - new_module = SVDQuantLinear(adapter_name, target, **kwargs) - target.weight = target.qweight + new_module = SVDQuantLinear(target, adapter_name, **kwargs) else: - if isinstance(target, torch.nn.Linear): - in_features, out_features = target.in_features, target.out_features + if isinstance(target_base_layer, torch.nn.Linear): if kwargs["fan_in_fan_out"]: warnings.warn( "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " "Setting fan_in_fan_out to False." ) kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False - elif isinstance(target, Conv1D): - in_features, out_features = ( - target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape - ) + elif isinstance(target_base_layer, Conv1D): if not kwargs["fan_in_fan_out"]: warnings.warn( "fan_in_fan_out is set to False but the target module is `Conv1D`. " @@ -212,7 +212,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): f"Target module {target} is not supported. " f"Currently, only `torch.nn.Linear` and `Conv1D` are supported." 
) - new_module = SVDLinear(adapter_name, in_features, out_features, bias=bias, **kwargs) + new_module = SVDLinear(target, adapter_name, **kwargs) return new_module diff --git a/src/peft/tuners/ia3/bnb.py b/src/peft/tuners/ia3/bnb.py index 2aa37c1d5c..2666b3ab6e 100644 --- a/src/peft/tuners/ia3/bnb.py +++ b/src/peft/tuners/ia3/bnb.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import bitsandbytes as bnb +from typing import Any + import torch from peft.import_utils import is_bnb_4bit_available, is_bnb_available @@ -23,39 +24,27 @@ if is_bnb_available(): - class Linear8bitLt(bnb.nn.Linear8bitLt, IA3Layer): + class Linear8bitLt(torch.nn.Module, IA3Layer): # (IA)^3 implemented in a dense layer def __init__( self, - adapter_name, - in_features, - out_features, - is_feedforward, + base_layer: torch.nn.Module, + adapter_name: str, + is_feedforward: bool, + init_ia3_weights: bool = True, **kwargs, ) -> None: - bnb.nn.Linear8bitLt.__init__( - self, - in_features, - out_features, - bias=kwargs.get("bias", True), - has_fp16_weights=kwargs.get("has_fp16_weights", True), - memory_efficient_backward=kwargs.get("memory_efficient_backward", False), - threshold=kwargs.get("threshold", 0.0), - index=kwargs.get("index", None), - ) - IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward) - self.is_feedforward = is_feedforward + super().__init__() + IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward) # Freezing the pre-trained weight matrix - self.weight.requires_grad = False - - init_ia3_weights = kwargs.pop("init_ia3_weights", True) + self.get_base_layer().weight.requires_grad = False self.update_layer(adapter_name, init_ia3_weights) - self.set_adapter(adapter_name) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + # note: no check for self.merged because merging is not supported (yet) if self.disable_adapters: - return super().forward(x) + return self.base_layer(x) ia3_scaling = 1 for active_adapter in self.active_adapters: @@ -67,10 +56,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if requires_conversion: x = x.float() if self.is_feedforward: - result = super().forward(x * ia3_scaling) + result = self.base_layer(x * ia3_scaling) expected_dtype = result.dtype else: - result = super().forward(x) + result = self.base_layer(x) expected_dtype = result.dtype result = result * ia3_scaling @@ -79,41 +68,34 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return result + def __repr__(self) -> str: + rep = super().__repr__() + return "ia3." 
+ rep + if is_bnb_4bit_available(): - class Linear4bit(bnb.nn.Linear4bit, IA3Layer): + class Linear4bit(torch.nn.Module, IA3Layer): # IA3 implemented in a dense layer def __init__( self, - adapter_name, - in_features, - out_features, - is_feedforward, + base_layer: torch.nn.Module, + adapter_name: str, + is_feedforward: bool, + init_ia3_weights: bool = True, **kwargs, ) -> None: - bnb.nn.Linear4bit.__init__( - self, - in_features, - out_features, - bias=kwargs.get("bias", True), - compute_dtype=kwargs.get("compute_dtype", torch.float32), - compress_statistics=kwargs.get("compress_statistics", True), - quant_type=kwargs.get("quant_type", "nf4"), - ) - IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward) - self.is_feedforward = is_feedforward + super().__init__() + IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward) # Freezing the pre-trained weight matrix - self.weight.requires_grad = False - - init_ia3_weights = kwargs.pop("init_ia3_weights", True) + self.get_base_layer().weight.requires_grad = False self.update_layer(adapter_name, init_ia3_weights) - self.set_adapter(adapter_name) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + # note: no check for self.merged because merging is not supported (yet) if self.disable_adapters: - return super().forward(x) + return self.base_layer(x) ia3_scaling = 1 for active_adapter in self.active_adapters: @@ -125,10 +107,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if requires_conversion: x = x.float() if self.is_feedforward: - result = super().forward(x * ia3_scaling) + result = self.base_layer(x * ia3_scaling) expected_dtype = result.dtype else: - result = super().forward(x) + result = self.base_layer(x) expected_dtype = result.dtype result = result * ia3_scaling @@ -140,3 +122,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: result = result.to(expected_dtype) return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "ia3." + rep diff --git a/src/peft/tuners/ia3/layer.py b/src/peft/tuners/ia3/layer.py index 50696a0e08..45ef388399 100644 --- a/src/peft/tuners/ia3/layer.py +++ b/src/peft/tuners/ia3/layer.py @@ -14,11 +14,11 @@ # limitations under the License. 
import warnings -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional import torch import torch.nn as nn -import torch.nn.functional as F +from transformers.pytorch_utils import Conv1D from peft.tuners.tuners_utils import BaseTunerLayer from peft.utils import transpose @@ -30,20 +30,30 @@ class IA3Layer(BaseTunerLayer): # All names of other parameters that may contain adapter-related parameters other_layer_names = ("scaling",) - def __init__( - self, - in_features: int, - out_features: int, - is_feedforward: bool, - ): + def __init__(self, base_layer: nn.Module, is_feedforward: bool, **kwargs) -> None: + self.base_layer = base_layer self.scaling = {} self.ia3_l = nn.ParameterDict({}) # Mark the weight as unmerged self._disable_adapters = False self.merged_adapters = [] + self.is_feedforward = is_feedforward + + base_layer = self.get_base_layer() + if isinstance(base_layer, nn.Linear): + in_features, out_features = base_layer.in_features, base_layer.out_features + elif isinstance(base_layer, nn.Conv2d): + in_features, out_features = base_layer.in_channels, base_layer.out_channels + elif isinstance(base_layer, nn.Embedding): + in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim + elif isinstance(base_layer, Conv1D): + in_features, out_features = ( + base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape + ) + else: + raise ValueError(f"Unsupported layer type {type(base_layer)}") self.in_features = in_features self.out_features = out_features - self.is_feedforward = is_feedforward def update_layer(self, adapter_name, init_ia3_weights): # Actual trainable parameters @@ -54,7 +64,7 @@ def update_layer(self, adapter_name, init_ia3_weights): self.ia3_l[adapter_name] = nn.Parameter(weight) if init_ia3_weights: self.reset_ia3_parameters(adapter_name) - self.to(self.weight.device) + self.to(self.get_base_layer().weight.device) self.set_adapter(self.active_adapters) def reset_ia3_parameters(self, adapter_name): @@ -63,35 +73,24 @@ def reset_ia3_parameters(self, adapter_name): nn.init.constant_(self.ia3_l[adapter_name], 1.0) -class Linear(nn.Linear, IA3Layer): +class Linear(nn.Module, IA3Layer): # (IA)^3 implemented in a dense layer def __init__( self, + base_layer: nn.Module, adapter_name: str, - in_features: int, - out_features: int, fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) is_feedforward: bool = False, # Set to True if the layer is treated as a feedforward layer is_target_conv_1d_layer: bool = False, # whether target module is a conv1d layer. 
useful while unloading later + init_ia3_weights: bool = True, # whether to initialize IA3 weights **kwargs, ) -> None: - init_ia3_weights = kwargs.pop("init_ia3_weights", True) - - nn.Linear.__init__(self, in_features, out_features, **kwargs) - IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward) - self.is_feedforward = is_feedforward - # Freezing the pre-trained weight matrix - self.weight.requires_grad = False - + super().__init__() + IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward) self.fan_in_fan_out = fan_in_fan_out - if fan_in_fan_out: - self.weight.data = self.weight.data.T - self.is_target_conv_1d_layer = is_target_conv_1d_layer - - nn.Linear.reset_parameters(self) + self._active_adapter = adapter_name self.update_layer(adapter_name, init_ia3_weights) - self.set_adapter(adapter_name) def update_layer(self, adapter_name, init_ia3_weights): # Actual trainable parameters @@ -102,7 +101,7 @@ def update_layer(self, adapter_name, init_ia3_weights): self.ia3_l[adapter_name] = nn.Parameter(weight) if init_ia3_weights: self.reset_ia3_parameters(adapter_name) - self.to(self.weight.device) + self.to(self.get_base_layer().weight.device) self.set_adapter(self.active_adapters) def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: @@ -129,24 +128,23 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N for active_adapter in adapter_names: if active_adapter in self.ia3_l.keys(): + base_layer = self.get_base_layer() + ia3_l = transpose(self.ia3_l[active_adapter].data, self.fan_in_fan_out) if safe_merge: - orig_weights = transpose(self.weight, self.fan_in_fan_out).clone() - orig_weights = torch.mul(orig_weights.data, self.ia3_l[active_adapter].data) + orig_weights = base_layer.weight.data + orig_weights = torch.mul(orig_weights, ia3_l) if not torch.isfinite(orig_weights).all(): raise ValueError( f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" ) - self.weight.data = orig_weights - self.weight = transpose(self.weight, self.fan_in_fan_out) + base_layer.weight.data = orig_weights else: - self.weight = transpose(self.weight, self.fan_in_fan_out) - self.weight.data = torch.mul(self.weight.data, self.ia3_l[active_adapter].data) - self.weight = transpose(self.weight, self.fan_in_fan_out) + base_layer.weight.data = torch.mul(base_layer.weight.data, ia3_l) - if not self.is_feedforward and (self.bias is not None): - scaling = self.ia3_l[active_adapter].reshape(self.bias.shape) - self.bias.data = torch.mul(self.bias.data, scaling.data) + if not self.is_feedforward and (base_layer.bias is not None): + scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape) + base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data) self.merged_adapters.append(active_adapter) @@ -159,27 +157,24 @@ def unmerge(self) -> None: while len(self.merged_adapters) > 0: active_adapter = self.merged_adapters.pop() if active_adapter in self.ia3_l.keys(): - self.weight = transpose(self.weight, self.fan_in_fan_out) - # divide by (IA)^3 vector. 
Add tolerace to avoid division by zero - self.weight.data = torch.div(self.weight.data, self.ia3_l[active_adapter].data + 1e-8) - self.weight = transpose(self.weight, self.fan_in_fan_out) + base_layer = self.get_base_layer() + # Add tolerace to avoid division by zero + ia3_l = transpose(self.ia3_l[active_adapter].data, self.fan_in_fan_out) + 1e-8 + base_layer.weight.data = torch.div(base_layer.weight.data, ia3_l) - if not self.is_feedforward and (self.bias is not None): - scaling = self.ia3_l[active_adapter].reshape(self.bias.shape) - self.bias.data = torch.div(self.bias.data, scaling.data + 1e-8) + if not self.is_feedforward and (base_layer.bias is not None): + scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape) + base_layer.bias.data = torch.div(base_layer.bias.data, scaling.data + 1e-8) - def _linear(self, input: torch.Tensor) -> torch.Tensor: - return F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) - - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: dtype = previous_dtype = x.dtype if self.disable_adapters: if self.merged: self.unmerge() - result = self._linear(x) + result = self.base_layer(x, *args, **kwargs) elif self.merged: - result = self._linear(x) + result = self.base_layer(x, *args, **kwargs) else: ia3_scaling = 1 for active_adapter in self.active_adapters: @@ -190,46 +185,34 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.is_feedforward: x = x.to(dtype) - # TODO: self.weight.dtype can be != self.ia3_l[self.active_adapters].dtype + # TODO: weight.dtype can be != self.ia3_l[self.active_adapters].dtype # e.g. bf16 vs fp32. Is that okay? - interm = (x * ia3_scaling).to(self.weight.dtype) - result = self._linear(interm) + interm = (x * ia3_scaling).to(self.get_base_layer().weight.dtype) + result = self.base_layer(interm, *args, **kwargs) else: - result = self._linear(x) + result = self.base_layer(x, *args, **kwargs) result = result.to(dtype) * ia3_scaling result = result.to(previous_dtype) return result -class Conv2d(nn.Conv2d, IA3Layer): +class Conv2d(nn.Module, IA3Layer): def __init__( self, + base_layer: nn.Module, adapter_name: str, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int]], - stride: Union[int, Tuple[int]] = 1, - padding: Union[int, Tuple[int]] = 0, fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) is_feedforward: bool = False, # Set to True if the layer is treated as a feedforward layer + init_ia3_weights: bool = True, **kwargs, ) -> None: - init_ia3_weights = kwargs.pop("init_ia3_weights", True) - - nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) - IA3Layer.__init__(self, in_features=in_channels, out_features=out_channels, is_feedforward=is_feedforward) - self.is_feedforward = is_feedforward - # Freezing the pre-trained weight matrix - self.weight.requires_grad = False - + super().__init__() + IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward) self.fan_in_fan_out = fan_in_fan_out - if fan_in_fan_out: - self.weight.data = self.weight.data.T + self._active_adapter = adapter_name - nn.Conv2d.reset_parameters(self) self.update_layer(adapter_name, init_ia3_weights) - self.set_adapter(adapter_name) def update_layer(self, adapter_name, init_ia3_weights): # Actual trainable parameters @@ -240,7 +223,7 @@ def update_layer(self, adapter_name, init_ia3_weights): 
self.ia3_l[adapter_name] = nn.Parameter(weight) if init_ia3_weights: self.reset_ia3_parameters(adapter_name) - self.to(self.weight.device) + self.to(self.get_base_layer().weight.device) self.set_adapter(self.active_adapters) def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: @@ -267,25 +250,26 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N for active_adapter in adapter_names: if active_adapter in self.ia3_l.keys(): + base_layer = self.get_base_layer() ia3_scaling = self.ia3_l[active_adapter].data if not self.is_feedforward: ia3_scaling = ia3_scaling.permute(1, 0, 2, 3) if safe_merge: - output_weight = torch.mul(self.weight.data, ia3_scaling).clone() + output_weight = torch.mul(base_layer.weight.data, ia3_scaling).clone() if not torch.isfinite(output_weight).all(): raise ValueError( f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" ) - self.weight.data = output_weight + base_layer.weight.data = output_weight else: - self.weight.data = torch.mul(self.weight.data, ia3_scaling) + base_layer.weight.data = torch.mul(base_layer.weight.data, ia3_scaling) - if not self.is_feedforward and (self.bias is not None): - scaling = self.ia3_l[active_adapter].reshape(self.bias.shape) - self.bias.data = torch.mul(self.bias.data, scaling.data) + if not self.is_feedforward and (base_layer.bias is not None): + scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape) + base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data) self.merged_adapters.append(active_adapter) @@ -298,36 +282,26 @@ def unmerge(self) -> None: while len(self.merged_adapters) > 0: active_adapter = self.merged_adapters.pop() if active_adapter in self.ia3_l.keys(): + base_layer = self.get_base_layer() # divide by (IA)^3 vector. Add tolerace to avoid division by zero ia3_scaling = self.ia3_l[active_adapter].data if not self.is_feedforward: ia3_scaling = ia3_scaling.permute(1, 0, 2, 3) - self.weight.data = torch.div(self.weight.data, ia3_scaling + 1e-8) - - if not self.is_feedforward and (self.bias is not None): - scaling = self.ia3_l[active_adapter].reshape(self.bias.shape) - self.bias.data = torch.mul(self.bias.data, scaling.data) - - def _conv2d(self, input: torch.Tensor) -> torch.Tensor: - return F.conv2d( - input, - self.weight, - bias=self.bias, - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - groups=self.groups, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: + base_layer.weight.data = torch.div(base_layer.weight.data, ia3_scaling + 1e-8) + + if not self.is_feedforward and (base_layer.bias is not None): + scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape) + base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data) + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: previous_dtype = x.dtype if self.disable_adapters: if self.merged: self.unmerge() - result = self._conv2d(x) + result = self.base_layer(x, *args, **kwargs) elif self.merged: - result = self._conv2d(x) + result = self.base_layer(x, *args, **kwargs) else: ia3_scaling = 1 for active_adapter in self.active_adapters: @@ -338,12 +312,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.is_feedforward: x = x.to(dtype) - # TODO: self.weight.dtype can be != self.ia3_l[self.active_adapters].dtype + # TODO: weight.dtype can be != self.ia3_l[self.active_adapters].dtype # e.g. bf16 vs fp32. Is that okay? 
- interm = (x * ia3_scaling).to(self.weight.dtype) - result = self._conv2d(interm) + interm = (x * ia3_scaling).to(self.get_base_layer().weight.dtype) + result = self.base_layer(interm, *args, **kwargs) else: - result = self._conv2d(x) + result = self.base_layer(x, *args, **kwargs) result = result.to(dtype) * ia3_scaling result = result.to(previous_dtype) diff --git a/src/peft/tuners/ia3/model.py b/src/peft/tuners/ia3/model.py index 29802359f7..7b2f9d19d9 100644 --- a/src/peft/tuners/ia3/model.py +++ b/src/peft/tuners/ia3/model.py @@ -23,7 +23,7 @@ from transformers.pytorch_utils import Conv1D from peft.import_utils import is_bnb_4bit_available, is_bnb_available -from peft.tuners.tuners_utils import BaseTuner, check_target_module_exists +from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists from peft.utils import ( TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING, @@ -83,12 +83,16 @@ def __init__(self, model, config, adapter_name): @staticmethod def _create_new_module(ia3_config, adapter_name, target, **kwargs): - bias = hasattr(target, "bias") and target.bias is not None loaded_in_8bit = kwargs.pop("loaded_in_8bit", False) loaded_in_4bit = kwargs.pop("loaded_in_4bit", False) is_feedforward = kwargs.pop("is_feedforward", False) - if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt): + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt): eightbit_kwargs = kwargs.copy() eightbit_kwargs.update( { @@ -98,15 +102,8 @@ def _create_new_module(ia3_config, adapter_name, target, **kwargs): "index": target.index, } ) - new_module = Linear8bitLt( - adapter_name, - target.in_features, - target.out_features, - is_feedforward, - bias=bias, - **eightbit_kwargs, - ) - elif loaded_in_4bit and isinstance(target, bnb.nn.Linear4bit): + new_module = Linear8bitLt(target, adapter_name, is_feedforward=is_feedforward, **eightbit_kwargs) + elif loaded_in_4bit and isinstance(target_base_layer, bnb.nn.Linear4bit): fourbit_kwargs = kwargs.copy() fourbit_kwargs.update( { @@ -115,56 +112,31 @@ def _create_new_module(ia3_config, adapter_name, target, **kwargs): "quant_type": target.weight.quant_type, } ) - new_module = Linear4bit( - adapter_name, - target.in_features, - target.out_features, - is_feedforward, - bias=bias, - **fourbit_kwargs, - ) + new_module = Linear4bit(target, adapter_name, is_feedforward=is_feedforward, **fourbit_kwargs) elif isinstance(target, torch.nn.Conv2d): - out_channels, in_channels = target.weight.size()[:2] - kernel_size = target.weight.size()[2:] - stride = target.stride - padding = target.padding - new_module = Conv2d( - adapter_name=adapter_name, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - is_feedforward=is_feedforward, - **kwargs, - ) - else: - if isinstance(target, torch.nn.Linear): - in_features, out_features = target.in_features, target.out_features - if kwargs["fan_in_fan_out"]: - warnings.warn( - "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " - "Setting fan_in_fan_out to False." 
- ) - kwargs["fan_in_fan_out"] = ia3_config.fan_in_fan_out = False - elif isinstance(target, Conv1D): - in_features, out_features = ( - target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape + new_module = Conv2d(target, adapter_name, is_feedforward=is_feedforward, **kwargs) + elif isinstance(target_base_layer, torch.nn.Linear): + if kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " + "Setting fan_in_fan_out to False." ) - kwargs["is_target_conv_1d_layer"] = True # useful for unloading later - if not kwargs["fan_in_fan_out"]: - warnings.warn( - "fan_in_fan_out is set to False but the target module is `Conv1D`. " - "Setting fan_in_fan_out to True." - ) - kwargs["fan_in_fan_out"] = ia3_config.fan_in_fan_out = True - else: - raise ValueError( - f"Target module {target} is not supported. " - f"Currently, only `torch.nn.Linear`, `torch.nn.Conv2d`, and `Conv1D` are supported." + kwargs["fan_in_fan_out"] = ia3_config.fan_in_fan_out = False + new_module = Linear(target, adapter_name, is_feedforward=is_feedforward, **kwargs) + elif isinstance(target_base_layer, Conv1D): + if not kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to False but the target module is `Conv1D`. " + "Setting fan_in_fan_out to True." ) + kwargs["fan_in_fan_out"] = ia3_config.fan_in_fan_out = True new_module = Linear( - adapter_name, in_features, out_features, is_feedforward=is_feedforward, bias=bias, **kwargs + target, adapter_name, is_feedforward=is_feedforward, is_target_conv_1d_layer=True, **kwargs + ) + else: + raise ValueError( + f"Target module {target} is not supported. " + f"Currently, only `torch.nn.Linear`, `torch.nn.Conv2d`, and `Conv1D` are supported." ) return new_module @@ -201,21 +173,16 @@ def _create_and_replace( "is_feedforward": is_feedforward, } - if isinstance(target, IA3Layer): - if target.is_feedforward != is_feedforward: - raise ValueError( - "New adapter should have the same value for `is_feedforward` as previously added adapter." 
- ) - if isinstance(target, torch.nn.Conv2d): - target.update_layer( - adapter_name, - ia3_config.init_ia3_weights, - ) - else: # Linear - target.update_layer( - adapter_name, - ia3_config.init_ia3_weights, - ) + if isinstance(target, Conv2d): + target.update_layer( + adapter_name, + ia3_config.init_ia3_weights, + ) + elif isinstance(target, Linear): + target.update_layer( + adapter_name, + ia3_config.init_ia3_weights, + ) else: new_module = self._create_new_module(ia3_config, adapter_name, target, **kwargs) if adapter_name != self.active_adapter: @@ -238,11 +205,22 @@ def _check_target_module_feedforward(ia3_config, key) -> bool: @staticmethod def _replace_module(parent, child_name, new_module, child): setattr(parent, child_name, new_module) - new_module.weight = child.weight - if child.bias is not None: - new_module.bias = child.bias + + # child layer wraps the original module, unpack it + if hasattr(child, "base_layer"): + child = child.base_layer + + # layers with base_layer don't need the weight to be copied, as they have a reference already + if not hasattr(new_module, "base_layer"): + new_module.weight = child.weight + if hasattr(child, "bias"): + new_module.bias = child.bias + if getattr(child, "state", None) is not None: - new_module.state = child.state + if hasattr(new_module, "base_layer"): + new_module.base_layer.state = child.state + else: + new_module.state = child.state new_module.to(child.weight.device) # dispatch to correct device @@ -298,7 +276,9 @@ def _prepare_adapter_config(self, peft_config, model_config): ] return peft_config - def merge_and_unload(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None): + def _unload_and_optionally_merge( + self, merge: bool = True, safe_merge: bool = False, adapter_names: Optional[List[str]] = None + ): r""" This method merges the (IA)^3 layers into the base model. This is needed if someone wants to use the base model as a standalone model. @@ -325,31 +305,46 @@ def merge_and_unload(self, safe_merge: bool = False, adapter_names: Optional[Lis except AttributeError: continue - # save any additional trainable modules part of `modules_to_save` - if isinstance(target, ModulesToSaveWrapper): + if hasattr(target, "base_layer"): + if merge: + target.merge(safe_merge=safe_merge, adapter_names=adapter_names) + self._replace_module(parent, target_name, target.get_base_layer(), target) + elif isinstance(target, ModulesToSaveWrapper): + # save any additional trainable modules part of `modules_to_save` setattr(parent, target_name, target.modules_to_save[target.active_adapter]) - continue - if not isinstance(target, IA3Layer): - continue + return self.model - if isinstance(target, torch.nn.Conv2d): - new_module = torch.nn.Conv2d( - target.in_channels, - target.out_channels, - kernel_size=target.kernel_size, - stride=target.stride, - padding=target.padding, - dilation=target.dilation, - ) - else: - bias = target.bias is not None - if getattr(target, "is_target_conv_1d_layer", False): - new_module = Conv1D(target.out_features, target.in_features) - else: - new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias) + def merge_and_unload(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None): + r""" + This method merges the IA³ layers into the base model. This is needed if someone wants to use the base model as + a standalone model. 
- target.merge(safe_merge=safe_merge, adapter_names=adapter_names) - self._replace_module(parent, target_name, new_module, target) + Args: + safe_merge (`bool`): + whether to activate the safe merging check to check if there is any potential Nan in the adapter + weights + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. - return self.model + Example: + + ```py + >>> from transformers import AutoModelForCausalLM + >>> from peft import PeftModel + + >>> base_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-40b") + >>> peft_model_id = "smangrul/falcon-40B-int4-peft-lora-sfttrainer-sample" + >>> model = PeftModel.from_pretrained(base_model, peft_model_id) + >>> merged_model = model.merge_and_unload() + ``` + """ + return self._unload_and_optionally_merge(safe_merge=safe_merge, adapter_names=adapter_names) + + def unload(self): + """ + Gets back the base model by removing all the IA³ modules without merging. This gives back the original base + model. + """ + return self._unload_and_optionally_merge(merge=False) diff --git a/src/peft/tuners/loha/layer.py b/src/peft/tuners/loha/layer.py index 2a8a205b02..4733336419 100644 --- a/src/peft/tuners/loha/layer.py +++ b/src/peft/tuners/loha/layer.py @@ -14,7 +14,7 @@ # limitations under the License. import math -from typing import Optional, Set, Tuple, Union +from typing import Any, Set, Tuple import torch import torch.nn as nn @@ -23,14 +23,14 @@ from peft.tuners.lycoris_utils import LycorisLayer -class LoHaLayer(LycorisLayer, nn.Module): +class LoHaLayer(nn.Module, LycorisLayer): # All names of layers that may contain adapter weights adapter_layer_names = ("hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b", "hada_t1", "hada_t2") # other_param_names is defined on parent class - def __init__(self): - LycorisLayer.__init__(self) - super(nn.Module, self).__init__() + def __init__(self, base_layer: nn.Module): + super().__init__() + LycorisLayer.__init__(self, base_layer) # LoHa info self.hada_w1_a = nn.ParameterDict({}) @@ -76,6 +76,21 @@ def reset_adapter_parameters(self, adapter_name: str): nn.init.kaiming_uniform_(self.hada_t1[adapter_name], a=math.sqrt(5)) nn.init.kaiming_uniform_(self.hada_t2[adapter_name], a=math.sqrt(5)) + def reset_adapter_parameters_random(self, adapter_name: str): + # Original implementation performs initialization with normal distribution + # https://github.com/KohakuBlueleaf/LyCORIS/blob/3549fdef8f564761d68b695a08ef88b1122fdedc/lycoris/modules/loha.py#L158 + + # FedPara paper proposes to perform He initialization, let's stick with it + # It is enough to initialize only single matrix with zeros to make adapter do nothing after initialization + if adapter_name in self.hada_w1_a.keys(): + nn.init.kaiming_uniform_(self.hada_w1_a[adapter_name], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.hada_w1_b[adapter_name], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.hada_w2_a[adapter_name], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.hada_w2_b[adapter_name], a=math.sqrt(5)) + if adapter_name in self.hada_t1.keys(): + nn.init.kaiming_uniform_(self.hada_t1[adapter_name], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.hada_t2[adapter_name], a=math.sqrt(5)) + def update_layer( self, adapter_name: str, @@ -107,16 +122,20 @@ def update_layer( self.module_dropout[adapter_name] = module_dropout # Determine shape of LoHa weights - if isinstance(self, nn.Linear): - shape = tuple(self.weight.shape) - elif isinstance(self, 
nn.Conv2d): - use_effective_conv2d = use_effective_conv2d and self.kernel_size != (1, 1) + base_layer = self.get_base_layer() + if isinstance(base_layer, nn.Linear): + shape = tuple(base_layer.weight.shape) + elif isinstance(base_layer, nn.Conv2d): + use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size != (1, 1) if use_effective_conv2d: - shape = (self.out_channels, self.in_channels, *self.kernel_size) + shape = (base_layer.out_channels, base_layer.in_channels, *base_layer.kernel_size) else: - shape = (self.out_channels, self.in_channels * self.kernel_size[0] * self.kernel_size[1]) + shape = ( + base_layer.out_channels, + base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1], + ) else: - raise TypeError(f"LoHa is not implemented for {type(self).__name__} layer") + raise TypeError(f"LoHa is not implemented for base layers of type {type(base_layer).__name__}") # Create weights with provided shape self.create_adapter_parameters(adapter_name, r, shape) @@ -124,9 +143,11 @@ def update_layer( # Initialize weights if init_weights: self.reset_adapter_parameters(adapter_name) + else: + self.reset_adapter_parameters_random(adapter_name) # Move new weights to device - weight = getattr(self, "weight", None) + weight = getattr(self.get_base_layer(), "weight", None) if weight is not None: # the layer is already completely initialized, this is an update if weight.dtype.is_floating_point or weight.dtype.is_complex: @@ -156,7 +177,8 @@ def get_delta_weight(self, adapter_name: str) -> torch.Tensor: scale=torch.tensor(self.scaling[adapter_name]), ) - weight = weight.reshape(self.weight.shape) + base_layer = self.get_base_layer() + weight = weight.reshape(base_layer.weight.shape) # Perform rank dropout during training - drop rows of addition weights rank_dropout = self.rank_dropout[adapter_name] @@ -171,96 +193,107 @@ def get_delta_weight(self, adapter_name: str) -> torch.Tensor: return weight + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + + # Execute all the adapters + for active_adapter in self.active_adapters: + if active_adapter not in self._available_adapters: + continue + + module_dropout = self.module_dropout[active_adapter] + + # Modify current execution weights + if (not self.training) or (self.training and torch.rand(1) > module_dropout): + result = result + self._get_delta_activations(active_adapter, x, *args, **kwargs) + + result = result.to(previous_dtype) + return result -class Linear(LoHaLayer, nn.Linear): + +class Linear(LoHaLayer): """LoHa implemented in Linear layer""" def __init__( self, - in_features: int, - out_features: int, - bias: bool = True, - device: Optional[Union[str, torch.device]] = None, - dtype: Optional[torch.dtype] = None, + base_layer: nn.Module, adapter_name: str = "default", r: int = 0, alpha: float = 0.0, rank_dropout: float = 0.0, module_dropout: float = 0.0, + init_weights: bool = True, **kwargs, ): - init_weights = kwargs.pop("init_weights", True) - self._init_empty_weights(nn.Linear, in_features, out_features, bias, device=device, dtype=dtype) - - LoHaLayer.__init__(self) + super().__init__(base_layer) # Create adapter and set it active + self._active_adapter = adapter_name self.update_layer(adapter_name, r, alpha, rank_dropout, 
module_dropout, init_weights, **kwargs) - self.set_adapter(adapter_name) - def _op(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: - return F.linear(input, weight, bias=self.bias) + def _get_delta_activations( + self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any + ) -> torch.Tensor: + delta_weight = self.get_delta_weight(adapter_name) + # don't add bias here, because the bias is already included in the output of the base_layer + return F.linear(input, delta_weight) + + def __repr__(self) -> str: + rep = super().__repr__() + return "loha." + rep -class Conv2d(LoHaLayer, nn.Conv2d): +class Conv2d(LoHaLayer): """LoHa implemented in Conv2d layer""" def __init__( self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int]], - stride: Union[int, Tuple[int]] = 1, - padding: Union[int, Tuple[int]] = 0, - dilation: int = 1, - groups: int = 1, - bias: bool = True, - padding_mode: str = "zeros", - device: Optional[Union[str, torch.device]] = None, - dtype: Optional[torch.dtype] = None, + base_layer: nn.Module, adapter_name: str = "default", r: int = 0, alpha: float = 0.0, rank_dropout: float = 0.0, module_dropout: float = 0.0, use_effective_conv2d: bool = False, + init_weights: bool = True, **kwargs, ): - init_weights = kwargs.pop("init_weights", True) - self._init_empty_weights( - nn.Conv2d, - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=padding_mode, - device=device, - dtype=dtype, - ) - - LoHaLayer.__init__(self) + super().__init__(base_layer) # Create adapter and set it active + self._active_adapter = adapter_name self.update_layer( adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, use_effective_conv2d, **kwargs ) - self.set_adapter(adapter_name) - def _op(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + def _get_delta_activations( + self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any + ) -> torch.Tensor: + delta_weight = self.get_delta_weight(adapter_name) + # don't add bias here, because the bias is already included in the output of the base_layer + base_layer = self.get_base_layer() return F.conv2d( input, - weight, - bias=self.bias, - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - groups=self.groups, + delta_weight, + stride=base_layer.stride, + padding=base_layer.padding, + dilation=base_layer.dilation, + groups=base_layer.groups, ) + def __repr__(self) -> str: + rep = super().__repr__() + return "loha." + rep + # Below code is a direct copy from https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/loha.py#L9 diff --git a/src/peft/tuners/loha/model.py b/src/peft/tuners/loha/model.py index 92d5b887ef..e641fdbac7 100644 --- a/src/peft/tuners/loha/model.py +++ b/src/peft/tuners/loha/model.py @@ -13,11 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict, Type +import re +from itertools import chain +from typing import Dict, Type, Union import torch +from torch import nn + +from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner -from ..lycoris_utils import LycorisTuner from .layer import Conv2d, Linear, LoHaLayer @@ -82,3 +86,31 @@ class LoHaModel(LycorisTuner): torch.nn.Conv2d: Conv2d, torch.nn.Linear: Linear, } + + def _create_and_replace( + self, + config: LycorisConfig, + adapter_name: str, + target: Union[LoHaLayer, nn.Module], + target_name: str, + parent: nn.Module, + current_key: str, + **optional_kwargs, + ) -> None: + """ + A private method to create and replace the target module with the adapter module. + """ + + # Regexp matching - Find key which matches current target_name in patterns provided + pattern_keys = list(chain(config.rank_pattern.keys(), config.alpha_pattern.keys())) + target_name_key = next(filter(lambda key: re.match(f"(.*\.)?{key}$", current_key), pattern_keys), target_name) + + kwargs = config.to_dict() + kwargs["r"] = config.rank_pattern.get(target_name_key, config.r) + kwargs["alpha"] = config.alpha_pattern.get(target_name_key, config.alpha) + + if isinstance(target, LoHaLayer): + target.update_layer(adapter_name, **kwargs) + else: + new_module = self._create_new_module(config, adapter_name, target, **kwargs) + self._replace_module(parent, target_name, new_module, target) diff --git a/src/peft/tuners/lokr/layer.py b/src/peft/tuners/lokr/layer.py index 97f3afb6fd..c733f4f4a5 100644 --- a/src/peft/tuners/lokr/layer.py +++ b/src/peft/tuners/lokr/layer.py @@ -14,7 +14,7 @@ # limitations under the License. import math -from typing import Optional, Set, Tuple, Union +from typing import Any, Optional, Set, Tuple, Union import torch import torch.nn as nn @@ -23,7 +23,7 @@ from peft.tuners.lycoris_utils import LycorisLayer -class LoKrLayer(LycorisLayer, nn.Module): +class LoKrLayer(nn.Module, LycorisLayer): # All names of layers that may contain adapter weights adapter_layer_names = ( "lokr_w1", @@ -36,9 +36,9 @@ class LoKrLayer(LycorisLayer, nn.Module): ) # other_param_names is defined on parent class - def __init__(self): - LycorisLayer.__init__(self) - super(nn.Module, self).__init__() + def __init__(self, base_layer: nn.Module) -> None: + super().__init__() + LycorisLayer.__init__(self, base_layer) # LoKr info self.lokr_w1 = nn.ParameterDict({}) @@ -111,6 +111,22 @@ def reset_adapter_parameters(self, adapter_name: str): if adapter_name in self.lokr_t2: nn.init.kaiming_uniform_(self.lokr_t2[adapter_name], a=math.sqrt(5)) + def reset_adapter_parameters_random(self, adapter_name: str): + if adapter_name in self.lokr_w1: + nn.init.kaiming_uniform_(self.lokr_w1[adapter_name], a=math.sqrt(5)) + else: + nn.init.kaiming_uniform_(self.lokr_w1_a[adapter_name], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.lokr_w1_b[adapter_name], a=math.sqrt(5)) + + if adapter_name in self.lokr_w2: + nn.init.kaiming_uniform_(self.lokr_w2[adapter_name], a=math.sqrt(5)) + else: + nn.init.kaiming_uniform_(self.lokr_w2_a[adapter_name], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.lokr_w2_b[adapter_name], a=math.sqrt(5)) + + if adapter_name in self.lokr_t2: + nn.init.kaiming_uniform_(self.lokr_t2[adapter_name], a=math.sqrt(5)) + def update_layer( self, adapter_name: str, @@ -143,10 +159,11 @@ def update_layer( self.scaling[adapter_name] = alpha / r self.rank_dropout[adapter_name] = rank_dropout self.module_dropout[adapter_name] = module_dropout + base_layer = self.get_base_layer() # Determine shape of LoKr 
weights - if isinstance(self, nn.Linear): - in_dim, out_dim = self.in_features, self.out_features + if isinstance(base_layer, nn.Linear): + in_dim, out_dim = base_layer.in_features, base_layer.out_features in_m, in_n = factorization(in_dim, decompose_factor) out_l, out_k = factorization(out_dim, decompose_factor) @@ -155,9 +172,9 @@ def update_layer( use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2) use_w2 = not (r < max(shape[0][1], shape[1][1]) / 2) use_effective_conv2d = False - elif isinstance(self, nn.Conv2d): - in_dim, out_dim = self.in_channels, self.out_channels - k_size = self.kernel_size + elif isinstance(base_layer, nn.Conv2d): + in_dim, out_dim = base_layer.in_channels, base_layer.out_channels + k_size = base_layer.kernel_size in_m, in_n = factorization(in_dim, decompose_factor) out_l, out_k = factorization(out_dim, decompose_factor) @@ -165,9 +182,9 @@ def update_layer( use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2) use_w2 = r >= max(shape[0][1], shape[1][1]) / 2 - use_effective_conv2d = use_effective_conv2d and self.kernel_size != (1, 1) + use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size != (1, 1) else: - raise TypeError(f"LoKr is not implemented for {type(self).__name__} layer") + raise TypeError(f"LoKr is not implemented for base layers of type {type(base_layer).__name__}") # Create weights with provided shape self.create_adapter_parameters(adapter_name, r, shape, use_w1, use_w2, use_effective_conv2d) @@ -175,9 +192,11 @@ def update_layer( # Initialize weights if init_weights: self.reset_adapter_parameters(adapter_name) + else: + self.reset_adapter_parameters_random(adapter_name) # Move new weights to device - weight = getattr(self, "weight", None) + weight = getattr(self.get_base_layer(), "weight", None) if weight is not None: # the layer is already completely initialized, this is an update if weight.dtype.is_floating_point or weight.dtype.is_complex: @@ -202,7 +221,7 @@ def get_delta_weight(self, adapter_name: str) -> torch.Tensor: # Make weights with Kronecker product weight = make_kron(w1, w2) - weight = weight.reshape(self.weight.shape) + weight = weight.reshape(self.get_base_layer().weight.shape) # Perform rank dropout during training - drop rows of addition weights rank_dropout = self.rank_dropout[adapter_name] @@ -214,15 +233,39 @@ def get_delta_weight(self, adapter_name: str) -> torch.Tensor: return weight + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + + # Execute all the adapters + for active_adapter in self.active_adapters: + if active_adapter not in self._available_adapters: + continue + + module_dropout = self.module_dropout[active_adapter] + + # Modify current execution weights + if (not self.training) or (self.training and torch.rand(1) > module_dropout): + result = result + self._get_delta_activations(active_adapter, x, *args, **kwargs) + + result = result.to(previous_dtype) + return result -class Linear(LoKrLayer, nn.Linear): + +class Linear(LoKrLayer): """LoKr implemented in Linear layer""" def __init__( self, - in_features: int, - out_features: int, - bias: bool = True, + base_layer: nn.Module, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, adapter_name: str = 
"default", @@ -230,35 +273,33 @@ def __init__( alpha: float = 0.0, rank_dropout: float = 0.0, module_dropout: float = 0.0, + init_weights: bool = True, **kwargs, ): - init_weights = kwargs.pop("init_weights", True) - self._init_empty_weights(nn.Linear, in_features, out_features, bias, device=device, dtype=dtype) - - LoKrLayer.__init__(self) + super().__init__(base_layer) # Create adapter and set it active + self._active_adapter = adapter_name self.update_layer(adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, **kwargs) - self.set_adapter(adapter_name) - def _op(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: - return F.linear(input, weight, bias=self.bias) + def _get_delta_activations( + self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any + ) -> torch.Tensor: + delta_weight = self.get_delta_weight(adapter_name) + # don't add bias here, because the bias is already included in the output of the base_layer + return F.linear(input, delta_weight) + + def __repr__(self) -> str: + rep = super().__repr__() + return "lokr." + rep -class Conv2d(LoKrLayer, nn.Conv2d): +class Conv2d(LoKrLayer): """LoKr implemented in Conv2d layer""" def __init__( self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int]], - stride: Union[int, Tuple[int]] = 1, - padding: Union[int, Tuple[int]] = 0, - dilation: int = 1, - groups: int = 1, - bias: bool = True, - padding_mode: str = "zeros", + base_layer: nn.Module, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, adapter_name: str = "default", @@ -267,43 +308,36 @@ def __init__( rank_dropout: float = 0.0, module_dropout: float = 0.0, use_effective_conv2d: bool = False, + init_weights: bool = True, **kwargs, ): - init_weights = kwargs.pop("init_weights", True) - self._init_empty_weights( - nn.Conv2d, - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=padding_mode, - device=device, - dtype=dtype, - ) - - LoKrLayer.__init__(self) + super().__init__(base_layer) # Create adapter and set it active + self._active_adapter = adapter_name self.update_layer( adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, use_effective_conv2d, **kwargs ) - self.set_adapter(adapter_name) - def _op(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + def _get_delta_activations( + self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any + ) -> torch.Tensor: + delta_weight = self.get_delta_weight(adapter_name) + # don't add bias here, because the bias is already included in the output of the base_layer + base_layer = self.get_base_layer() return F.conv2d( input, - weight, - bias=self.bias, - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - groups=self.groups, + delta_weight, + stride=base_layer.stride, + padding=base_layer.padding, + dilation=base_layer.dilation, + groups=base_layer.groups, ) + def __repr__(self) -> str: + rep = super().__repr__() + return "lokr." 
+ rep + # Below code is a direct copy from https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/lokr.py#L11 diff --git a/src/peft/tuners/lokr/model.py b/src/peft/tuners/lokr/model.py index e08b7a7c48..61535b28b3 100644 --- a/src/peft/tuners/lokr/model.py +++ b/src/peft/tuners/lokr/model.py @@ -13,11 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Type +import re +from itertools import chain +from typing import Dict, Type, Union import torch +from torch import nn + +from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner -from ..lycoris_utils import LycorisTuner from .layer import Conv2d, Linear, LoKrLayer @@ -83,3 +87,31 @@ class LoKrModel(LycorisTuner): torch.nn.Conv2d: Conv2d, torch.nn.Linear: Linear, } + + def _create_and_replace( + self, + config: LycorisConfig, + adapter_name: str, + target: Union[LoKrLayer, nn.Module], + target_name: str, + parent: nn.Module, + current_key: str, + **optional_kwargs, + ) -> None: + """ + A private method to create and replace the target module with the adapter module. + """ + + # Regexp matching - Find key which matches current target_name in patterns provided + pattern_keys = list(chain(config.rank_pattern.keys(), config.alpha_pattern.keys())) + target_name_key = next(filter(lambda key: re.match(f"(.*\.)?{key}$", current_key), pattern_keys), target_name) + + kwargs = config.to_dict() + kwargs["r"] = config.rank_pattern.get(target_name_key, config.r) + kwargs["alpha"] = config.alpha_pattern.get(target_name_key, config.alpha) + + if isinstance(target, LoKrLayer): + target.update_layer(adapter_name, **kwargs) + else: + new_module = self._create_new_module(config, adapter_name, target, **kwargs) + self._replace_module(parent, target_name, new_module, target) diff --git a/src/peft/tuners/lora/bnb.py b/src/peft/tuners/lora/bnb.py index 4bd8151ed3..1c42a9e8e3 100644 --- a/src/peft/tuners/lora/bnb.py +++ b/src/peft/tuners/lora/bnb.py @@ -30,20 +30,18 @@ class Linear8bitLt(torch.nn.Module, LoraLayer): # Lora implemented in a dense layer def __init__( self, - adapter_name, - base_layer, + base_layer: torch.nn.Module, + adapter_name: str, r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, + init_lora_weights: bool = True, **kwargs, ) -> None: super().__init__() - LoraLayer.__init__(self, in_features=base_layer.in_features, out_features=base_layer.out_features) - self.base_layer = base_layer + LoraLayer.__init__(self, base_layer) - init_lora_weights = kwargs.pop("init_lora_weights", True) self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) - self.set_adapter(adapter_name) def merge(self, safe_merge: bool = False): """ @@ -69,8 +67,8 @@ def merge(self, safe_merge: bool = False): ) lora_data = self.get_delta_weight(active_adapter) - weight = self.base_layer.weight - state = self.base_layer.state + weight = self.get_base_layer().weight + state = self.get_base_layer().state if state.SCB is None: state.SCB = weight.SCB @@ -90,7 +88,7 @@ def merge(self, safe_merge: bool = False): f"NaNs detected in the merged weights. 
The adapter {active_adapter} seems to be broken" ) - self.base_layer.weight = bnb.nn.Int8Params( + self.get_base_layer().weight = bnb.nn.Int8Params( w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights ).to(weight.device) state.reset_grads() @@ -110,8 +108,8 @@ def unmerge(self): ) lora_data = self.get_delta_weight(active_adapter) - weight = self.base_layer.weight - state = self.base_layer.state + weight = self.get_base_layer().weight + state = self.get_base_layer().state if state.SCB is None: state.SCB = weight.SCB im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device) @@ -124,7 +122,7 @@ def unmerge(self): output = bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t() w_data = output.to(lora_data.dtype).to(lora_data.device) - lora_data - self.base_layer.weight = bnb.nn.Int8Params( + self.get_base_layer().weight = bnb.nn.Int8Params( w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights ).to(weight.device) state.reset_grads() @@ -169,6 +167,10 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: return result + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." + rep + if is_bnb_4bit_available(): @@ -176,20 +178,18 @@ class Linear4bit(torch.nn.Module, LoraLayer): # Lora implemented in a dense layer def __init__( self, - adapter_name, - base_layer, + base_layer: torch.nn.Module, + adapter_name: str, r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, + init_lora_weights: bool = True, **kwargs, ) -> None: super().__init__() - LoraLayer.__init__(self, in_features=base_layer.in_features, out_features=base_layer.out_features) - self.base_layer = base_layer + LoraLayer.__init__(self, base_layer) - init_lora_weights = kwargs.pop("init_lora_weights", True) self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) - self.set_adapter(adapter_name) def merge(self, safe_merge: bool = False): """ @@ -214,7 +214,7 @@ def merge(self, safe_merge: bool = False): "Merge lora module to 4-bit linear may get different generations due to rounding errors." ) # Refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930 - weight = self.base_layer.weight + weight = self.get_base_layer().weight kwargs = weight.__dict__ lora_data = self.get_delta_weight(active_adapter) @@ -224,7 +224,7 @@ def merge(self, safe_merge: bool = False): f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" ) - self.base_layer.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to( + self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to( weight.device ) self.merged_adapters.append(active_adapter) @@ -241,11 +241,11 @@ def unmerge(self): warnings.warn( "Unmerge lora module to 4-bit linear may get different generations due to rounding errors." 
) - weight = self.base_layer.weight + weight = self.get_base_layer().weight kwargs = weight.__dict__ lora_data = self.get_delta_weight(active_adapter) w_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) - lora_data - self.base_layer.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to( + self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to( weight.device ) @@ -262,11 +262,11 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: if self.disable_adapters: if self.merged: self.unmerge() - result = self.base_layer.forward(x, *args, **kwargs) + result = self.base_layer(x, *args, **kwargs) elif self.merged: - result = self.base_layer.forward(x, *args, **kwargs) + result = self.base_layer(x, *args, **kwargs) else: - result = self.base_layer.forward(x, *args, **kwargs) + result = self.base_layer(x, *args, **kwargs) # As per Tim Dettmers, for 4bit, we need to defensively clone here. # The reason is that in some cases, an error can occur that backprop # does not work on a manipulated view. This issue may be solved with @@ -294,3 +294,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: result += output return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." + rep diff --git a/src/peft/tuners/lora/gptq.py b/src/peft/tuners/lora/gptq.py index 1505045a3e..75c853184c 100644 --- a/src/peft/tuners/lora/gptq.py +++ b/src/peft/tuners/lora/gptq.py @@ -21,22 +21,21 @@ class QuantLinear(torch.nn.Module, LoraLayer): def __init__( self, - adapter_name, - quant_linear_module, + base_layer, + adapter_name: str, r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, + init_lora_weights: bool = True, **kwargs, ): - torch.nn.Module.__init__(self) - LoraLayer.__init__( - self, in_features=quant_linear_module.infeatures, out_features=quant_linear_module.outfeatures - ) - self.quant_linear_module = quant_linear_module - self.weight = quant_linear_module.qweight - init_lora_weights = kwargs.pop("init_lora_weights", True) + super().__init__() + LoraLayer.__init__(self, base_layer) + + # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter + # for backwards compatibility + self.quant_linear_module = base_layer self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) - self.set_adapter(adapter_name) def forward(self, x: torch.Tensor): # note: logic differs from default Linear because merging is not supported @@ -65,6 +64,10 @@ def forward(self, x: torch.Tensor): result += output return result + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." 
+ rep + # TODO: Check if it is better as suggested by users https://github.com/PanQiWei/AutoGPTQ/pull/102 # def reset_lora_parameters(self, adapter_name): # if adapter_name in self.lora_A.keys(): diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index e2ced1eee9..c263053183 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -15,11 +15,12 @@ import math import warnings -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional import torch import torch.nn as nn import torch.nn.functional as F +from transformers.pytorch_utils import Conv1D from peft.tuners.tuners_utils import BaseTunerLayer from peft.utils.other import transpose @@ -31,7 +32,8 @@ class LoraLayer(BaseTunerLayer): # All names of other parameters that may contain adapter-related parameters other_param_names = ("r", "lora_alpha", "scaling", "lora_dropout") - def __init__(self, in_features: int, out_features: int, **kwargs): + def __init__(self, base_layer: nn.Module, **kwargs) -> None: + self.base_layer = base_layer self.r = {} self.lora_alpha = {} self.scaling = {} @@ -44,21 +46,26 @@ def __init__(self, in_features: int, out_features: int, **kwargs): # Mark the weight as unmerged self._disable_adapters = False self.merged_adapters = [] + + base_layer = self.get_base_layer() + if isinstance(base_layer, nn.Linear): + in_features, out_features = base_layer.in_features, base_layer.out_features + elif isinstance(base_layer, nn.Conv2d): + in_features, out_features = base_layer.in_channels, base_layer.out_channels + elif isinstance(base_layer, nn.Embedding): + in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim + elif isinstance(base_layer, Conv1D): + in_features, out_features = ( + base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape + ) + elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"): + # QuantLinear + in_features, out_features = base_layer.infeatures, base_layer.outfeatures + else: + raise ValueError(f"Unsupported layer type {type(base_layer)}") + self.in_features = in_features self.out_features = out_features - self.kwargs = kwargs - - def _init_empty_weights(self, cls, *args, **kwargs) -> None: - # A helper method that allows to initialize the layer of the given class without spending time to initialize the - # model weights. The implementation is inspired by - # https://pytorch.org/docs/stable/generated/torch.nn.utils.skip_init.html but this function cannot be used - # directly. - # Instead of this approach, it would be possible to bypass the __init__ of the class but that runs the risk of - # omitting important logic inside that __init__. 
- kwargs = kwargs.copy() - final_device = kwargs.pop("device", "cpu") - cls.__init__(self, *args, device="meta", **kwargs) - self.to_empty(device=final_device) def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): if r <= 0: @@ -79,7 +86,7 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig if init_lora_weights: self.reset_lora_parameters(adapter_name) - weight = getattr(self, "weight", None) + weight = getattr(self.get_base_layer(), "weight", None) if weight is not None: # the layer is already completely initialized, this is an update if weight.dtype.is_floating_point or weight.dtype.is_complex: @@ -100,20 +107,22 @@ def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lo self.lora_dropout[adapter_name] = lora_dropout_layer # Actual trainable parameters + base_layer = self.get_base_layer() if r > 0: - kernel_size = self.kwargs["kernel_size"] - stride = self.kwargs["stride"] - padding = self.kwargs["padding"] + kernel_size = base_layer.kernel_size + stride = base_layer.stride + padding = base_layer.padding self.lora_A[adapter_name] = nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False) self.lora_B[adapter_name] = nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False) self.scaling[adapter_name] = lora_alpha / r if init_lora_weights: self.reset_lora_parameters(adapter_name) - weight = getattr(self, "weight", None) + weight = getattr(base_layer, "weight", None) if weight is not None: # the layer is already completely initialized, this is an update - self.to(self.weight.device, dtype=weight.dtype) + self.to(base_layer.weight.device, dtype=weight.dtype) + self.set_adapter(self.active_adapters) def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): if r <= 0: @@ -136,10 +145,12 @@ def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init if init_lora_weights: self.reset_lora_parameters(adapter_name) - weight = getattr(self, "weight", None) + base_layer = self.get_base_layer() + weight = getattr(base_layer, "weight", None) if weight is not None: # the layer is already completely initialized, this is an update - self.to(self.weight.device, dtype=weight.dtype) + self.to(base_layer.weight.device, dtype=weight.dtype) + self.set_adapter(self.active_adapters) def reset_lora_parameters(self, adapter_name): if adapter_name in self.lora_A.keys(): @@ -188,35 +199,27 @@ def unscale_layer(self, scale=None) -> None: # ------------------------------------------------------------------------------------------ -class Linear(nn.Linear, LoraLayer): +class Linear(nn.Module, LoraLayer): # Lora implemented in a dense layer def __init__( self, + base_layer, adapter_name: str, - in_features: int, - out_features: int, r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) is_target_conv_1d_layer: bool = False, + init_lora_weights: bool = True, **kwargs, ) -> None: - init_lora_weights = kwargs.pop("init_lora_weights", True) - # this gets the init from nn.Linear's super perspective, i.e. - # nn.Module.__init__, which should always be called - super(nn.Linear, self).__init__() - # Note that we don't use self._init_empty_weights() for Linear because it is a bit slower and the benefit of - # added robustness is not big enough for Linear. 
- - LoraLayer.__init__(self, in_features=in_features, out_features=out_features) - # Freezing the pre-trained weight matrix - + super().__init__() + LoraLayer.__init__(self, base_layer) self.fan_in_fan_out = fan_in_fan_out + self._active_adapter = adapter_name self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) self.is_target_conv_1d_layer = is_target_conv_1d_layer - self.set_adapter(adapter_name) def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: """ @@ -242,10 +245,11 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N for active_adapter in adapter_names: if active_adapter in self.lora_A.keys(): + base_layer = self.get_base_layer() if safe_merge: # Note that safe_merge will be slower than the normal merge # because of the copy operation. - orig_weights = self.weight.data.clone() + orig_weights = base_layer.weight.data.clone() orig_weights += self.get_delta_weight(active_adapter) if not torch.isfinite(orig_weights).all(): @@ -253,9 +257,9 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" ) - self.weight.data = orig_weights + base_layer.weight.data = orig_weights else: - self.weight.data += self.get_delta_weight(active_adapter) + base_layer.weight.data += self.get_delta_weight(active_adapter) self.merged_adapters.append(active_adapter) def unmerge(self) -> None: @@ -265,7 +269,7 @@ def unmerge(self) -> None: while len(self.merged_adapters) > 0: active_adapter = self.merged_adapters.pop() if active_adapter in self.lora_A.keys(): - self.weight.data -= self.get_delta_weight(active_adapter) + self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) def get_delta_weight(self, adapter) -> torch.Tensor: """ @@ -301,20 +305,17 @@ def get_delta_weight(self, adapter) -> torch.Tensor: return output_tensor - def _linear(self, input: torch.Tensor) -> torch.Tensor: - return F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) - - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: previous_dtype = x.dtype if self.disable_adapters: if self.merged: self.unmerge() - result = self._linear(x) + result = self.base_layer(x, *args, **kwargs) elif self.merged: - result = self._linear(x) + result = self.base_layer(x, *args, **kwargs) else: - result = self._linear(x) + result = self.base_layer(x, *args, **kwargs) for active_adapter in self.active_adapters: if active_adapter not in self.lora_A.keys(): continue @@ -328,24 +329,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: result = result.to(previous_dtype) return result + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." 
+ rep -class Embedding(nn.Embedding, LoraLayer): + +class Embedding(nn.Module, LoraLayer): # LoRA implemented in a Embedding layer def __init__( self, + base_layer: nn.Module, adapter_name: str, - num_embeddings: int, - embedding_dim: int, r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, + init_lora_weights: bool = True, **kwargs, ) -> None: - init_lora_weights = kwargs.pop("init_lora_weights", True) - self._init_empty_weights(nn.Embedding, num_embeddings, embedding_dim, **kwargs) - LoraLayer.__init__(self, in_features=num_embeddings, out_features=embedding_dim) + super().__init__() + LoraLayer.__init__(self, base_layer) + + self._active_adapter = adapter_name self.update_layer_embedding(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) - self.set_adapter(adapter_name) def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: """ @@ -371,10 +376,11 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N for active_adapter in adapter_names: if active_adapter in self.lora_embedding_A.keys(): + base_layer = self.get_base_layer() if safe_merge: # Note that safe_merge will be slower than the normal merge # because of the copy operation. - orig_weights = self.weight.data.copy() + orig_weights = base_layer.weight.data.copy() orig_weights += self.get_delta_weight(active_adapter) if not torch.isfinite(orig_weights).all(): @@ -382,9 +388,9 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" ) - self.weight.data = orig_weights + base_layer.weight.data = orig_weights else: - self.weight.data += self.get_delta_weight(active_adapter) + base_layer.weight.data += self.get_delta_weight(active_adapter) self.merged_adapters.append(active_adapter) def unmerge(self) -> None: @@ -394,7 +400,7 @@ def unmerge(self) -> None: while len(self.merged_adapters) > 0: active_adapter = self.merged_adapters.pop() if active_adapter in self.lora_embedding_A.keys(): - self.weight.data -= self.get_delta_weight(active_adapter) + self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) def get_delta_weight(self, adapter) -> torch.Tensor: """ @@ -430,28 +436,28 @@ def get_delta_weight(self, adapter) -> torch.Tensor: return output_tensor - def _embed(self, input: torch.Tensor, weight: Optional[torch.Tensor] = None) -> torch.Tensor: - weight = self.weight if weight is None else weight + def _embed(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + base_layer = self.get_base_layer() return F.embedding( input, weight, - padding_idx=self.padding_idx, - max_norm=self.max_norm, - norm_type=self.norm_type, - scale_grad_by_freq=self.scale_grad_by_freq, - sparse=self.sparse, + padding_idx=base_layer.padding_idx, + max_norm=base_layer.max_norm, + norm_type=base_layer.norm_type, + scale_grad_by_freq=base_layer.scale_grad_by_freq, + sparse=base_layer.sparse, ) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: # TODO: no dtype conversion here, unlike in Linear, is that correct? 
if self.disable_adapters: if self.merged: self.unmerge() - result = self._embed(x) + result = self.base_layer(x, *args, **kwargs) elif self.merged: - result = self._embed(x) + result = self.base_layer(x, *args, **kwargs) else: - result = self._embed(x) + result = self.base_layer(x, *args, **kwargs) for active_adapter in self.active_adapters: if active_adapter not in self.lora_embedding_A: continue @@ -463,36 +469,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return result + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." + rep -class Conv2d(nn.Conv2d, LoraLayer): + +class Conv2d(nn.Module, LoraLayer): # Lora implemented in a conv2d layer def __init__( self, + base_layer: nn.Module, adapter_name: str, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int]], - stride: Union[int, Tuple[int]] = 1, - padding: Union[int, Tuple[int]] = 0, r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, + init_lora_weights: bool = True, **kwargs, ) -> None: - init_lora_weights = kwargs.pop("init_lora_weights", True) - self._init_empty_weights(nn.Conv2d, in_channels, out_channels, kernel_size, stride=stride, padding=padding) - - LoraLayer.__init__( - self, - in_features=in_channels, - out_features=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - ) + super().__init__() + LoraLayer.__init__(self, base_layer) + self._active_adapter = adapter_name self.update_layer_conv2d(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) - self.set_adapter(adapter_name) def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: """ @@ -518,19 +516,20 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N for active_adapter in adapter_names: if active_adapter in self.lora_A.keys(): + base_layer = self.get_base_layer() if safe_merge: # Note that safe_merge will be slower than the normal merge # because of the copy operation. - orig_weights = self.weight.data.copy() + orig_weights = base_layer.weight.data.copy() orig_weights += self.get_delta_weight(active_adapter) if not torch.isfinite(orig_weights).all(): raise ValueError( f"NaNs detected in the merged weights. 
The adapter {active_adapter} seems to be broken" ) - self.weight.data = orig_weights + base_layer.weight.data = orig_weights else: - self.weight.data += self.get_delta_weight(active_adapter) + base_layer.weight.data += self.get_delta_weight(active_adapter) self.merged_adapters.append(active_adapter) def unmerge(self) -> None: @@ -540,7 +539,7 @@ def unmerge(self) -> None: while len(self.merged_adapters) > 0: active_adapter = self.merged_adapters.pop() if active_adapter in self.lora_A.keys(): - self.weight.data -= self.get_delta_weight(active_adapter) + self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) def get_delta_weight(self, adapter) -> torch.Tensor: """ @@ -566,7 +565,7 @@ def get_delta_weight(self, adapter) -> torch.Tensor: weight_B = weight_B.float() # https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117 - if self.weight.size()[2:4] == (1, 1): + if self.get_base_layer().weight.size()[2:4] == (1, 1): # conv2d 1x1 output_tensor = (weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze( 3 @@ -590,28 +589,17 @@ def get_delta_weight(self, adapter) -> torch.Tensor: return output_tensor - def _conv2d(self, input: torch.Tensor) -> torch.Tensor: - return F.conv2d( - input, - self.weight, - bias=self.bias, - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - groups=self.groups, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: previous_dtype = x.dtype if self.disable_adapters: if self.merged: self.unmerge() - result = self._conv2d(x) + result = self.base_layer(x, *args, **kwargs) elif self.merged: - result = self._conv2d(x) + result = self.base_layer(x, *args, **kwargs) else: - result = self._conv2d(x) + result = self.base_layer(x, *args, **kwargs) for active_adapter in self.active_adapters: if active_adapter not in self.lora_A.keys(): continue @@ -624,3 +612,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: result = result.to(previous_dtype) return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." + rep diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 6b76ad9b6c..a5b7735ce3 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -22,7 +22,6 @@ from typing import List, Optional import torch -from torch import nn from tqdm import tqdm from transformers.pytorch_utils import Conv1D @@ -108,6 +107,8 @@ class LoraModel(BaseTuner): - **peft_config** ([`LoraConfig`]): The configuration of the Lora model. 
""" + prefix: str = "lora_" + def __init__(self, model, config, adapter_name) -> None: super().__init__(model, config, adapter_name) @@ -165,7 +166,7 @@ def _create_and_replace( kwargs["gptq_quantization_config"] = quantization_config # TODO: better deal with that - if isinstance(target, LoraLayer) and isinstance(target, torch.nn.Conv2d): + if isinstance(target, Conv2d): target.update_layer_conv2d( adapter_name, r, @@ -173,7 +174,7 @@ def _create_and_replace( lora_config.lora_dropout, lora_config.init_lora_weights, ) - elif isinstance(target, LoraLayer) and isinstance(target, torch.nn.Embedding): + elif isinstance(target, Embedding): target.update_layer_embedding( adapter_name, r, @@ -181,8 +182,7 @@ def _create_and_replace( lora_config.lora_dropout, lora_config.init_lora_weights, ) - - elif isinstance(target, LoraLayer): + elif isinstance(target, Linear): target.update_layer( adapter_name, r, @@ -197,8 +197,7 @@ def _create_and_replace( new_module.requires_grad_(False) self._replace_module(parent, target_name, new_module, target) - @staticmethod - def _replace_module(parent, child_name, new_module, child): + def _replace_module(self, parent, child_name, new_module, child): setattr(parent, child_name, new_module) # It's not necessary to set requires_grad here, as that is handled by # _mark_only_adapters_as_trainable @@ -206,10 +205,7 @@ def _replace_module(parent, child_name, new_module, child): # child layer wraps the original module, unpack it if hasattr(child, "base_layer"): child = child.base_layer - elif hasattr(child, "quant_linear_module"): - child = child.quant_linear_module - # TODO: layers with base_layer don't need the weight to be copied, as they have a reference already if not hasattr(new_module, "base_layer"): new_module.weight = child.weight if hasattr(child, "bias"): @@ -224,14 +220,13 @@ def _replace_module(parent, child_name, new_module, child): # dispatch to correct device for name, module in new_module.named_modules(): - if "lora_" in name: - module.to(child.weight.device) - if "ranknum" in name: - module.to(child.weight.device) + if (self.prefix in name) or ("ranknum" in name): + weight = child.qweight if hasattr(child, "qweight") else child.weight + module.to(weight.device) def _mark_only_adapters_as_trainable(self) -> None: for n, p in self.model.named_parameters(): - if "lora_" not in n: + if self.prefix not in n: p.requires_grad = False for active_adapter in self.active_adapters: @@ -257,9 +252,13 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): loaded_in_8bit = kwargs.pop("loaded_in_8bit", False) loaded_in_4bit = kwargs.pop("loaded_in_4bit", False) - bias = kwargs.pop("bias", False) - if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt): + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt): eightbit_kwargs = kwargs.copy() eightbit_kwargs.update( { @@ -269,8 +268,8 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): "index": target.index, } ) - new_module = Linear8bitLt(adapter_name, target, **eightbit_kwargs) - elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit): + new_module = Linear8bitLt(target, adapter_name, **eightbit_kwargs) + elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit): fourbit_kwargs = kwargs.copy() fourbit_kwargs.update( { @@ -279,47 +278,37 @@ def 
_create_new_module(lora_config, adapter_name, target, **kwargs): "quant_type": target.weight.quant_type, } ) - new_module = Linear4bit(adapter_name, target, **fourbit_kwargs) - elif AutoGPTQQuantLinear is not None and isinstance(target, AutoGPTQQuantLinear): - new_module = QuantLinear(adapter_name, target, **kwargs) + new_module = Linear4bit(target, adapter_name, **fourbit_kwargs) + elif AutoGPTQQuantLinear is not None and isinstance(target_base_layer, AutoGPTQQuantLinear): + new_module = QuantLinear(target, adapter_name, **kwargs) target.weight = target.qweight - elif isinstance(target, torch.nn.Embedding): + elif isinstance(target_base_layer, torch.nn.Embedding): embedding_kwargs = kwargs.copy() embedding_kwargs.pop("fan_in_fan_out", None) - in_features, out_features = target.num_embeddings, target.embedding_dim - new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs) - elif isinstance(target, torch.nn.Conv2d): - out_channels, in_channels = target.weight.size()[:2] - kernel_size = target.weight.size()[2:] - stride = target.stride - padding = target.padding - new_module = Conv2d(adapter_name, in_channels, out_channels, kernel_size, stride, padding, **kwargs) - else: - if isinstance(target, torch.nn.Linear): - in_features, out_features = target.in_features, target.out_features - if kwargs["fan_in_fan_out"]: - warnings.warn( - "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " - "Setting fan_in_fan_out to False." - ) - kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False - elif isinstance(target, Conv1D): - in_features, out_features = ( - target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape + new_module = Embedding(target, adapter_name, **embedding_kwargs) + elif isinstance(target_base_layer, torch.nn.Conv2d): + new_module = Conv2d(target, adapter_name, **kwargs) + elif isinstance(target_base_layer, torch.nn.Linear): + if kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " + "Setting fan_in_fan_out to False." ) - kwargs["is_target_conv_1d_layer"] = True - if not kwargs["fan_in_fan_out"]: - warnings.warn( - "fan_in_fan_out is set to False but the target module is `Conv1D`. " - "Setting fan_in_fan_out to True." - ) - kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True - else: - raise ValueError( - f"Target module {target} is not supported. Currently, only the following modules are supported: " - "`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`." + kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False + new_module = Linear(target, adapter_name, **kwargs) + elif isinstance(target_base_layer, Conv1D): + if not kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to False but the target module is `Conv1D`. " + "Setting fan_in_fan_out to True." ) - new_module = Linear(adapter_name, in_features, out_features, bias=bias, **kwargs) + kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True + new_module = Linear(target, adapter_name, is_target_conv_1d_layer=True, **kwargs) + else: + raise ValueError( + f"Target module {target} is not supported. Currently, only the following modules are supported: " + "`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`." 
+ ) return new_module @@ -388,60 +377,20 @@ def _unload_and_optionally_merge( if getattr(self.model, "quantization_method", None) == "gptq": raise ValueError("Cannot merge LORA layers when the model is gptq quantized") - key_list = [key for key, _ in self.model.named_modules() if "lora" not in key] + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] desc = "Unloading " + ("and merging " if merge else "") + "model" for key in tqdm(key_list, disable=not progressbar, desc=desc): try: parent, target, target_name = _get_submodules(self.model, key) except AttributeError: continue - if isinstance(target, LoraLayer): - if isinstance(target, nn.Embedding): - new_module = torch.nn.Embedding(target.in_features, target.out_features) - elif isinstance(target, nn.Conv2d): - new_module = torch.nn.Conv2d( - target.in_channels, - target.out_channels, - kernel_size=target.kernel_size, - stride=target.stride, - padding=target.padding, - dilation=target.dilation, - ) - elif is_bnb_available() and isinstance(target, Linear8bitLt): - bias = target.base_layer.bias is not None - new_module = bnb.nn.Linear8bitLt( - target.in_features, - target.out_features, - bias=bias, - has_fp16_weights=target.base_layer.state.has_fp16_weights, - memory_efficient_backward=target.base_layer.state.memory_efficient_backward, - threshold=target.base_layer.state.threshold, - index=target.base_layer.index, - device=target.base_layer.weight.device, - ) - elif is_bnb_4bit_available() and isinstance(target, Linear4bit): - bias = target.base_layer.bias is not None - new_module = bnb.nn.Linear4bit( - target.in_features, - target.out_features, - bias=bias, - compute_dtype=target.base_layer.compute_dtype, - compress_statistics=target.base_layer.weight.compress_statistics, - quant_type=target.base_layer.weight.quant_type, - device=target.base_layer.weight.device, - ) - else: - bias = target.bias is not None - if getattr(target, "is_target_conv_1d_layer", False): - new_module = Conv1D(target.out_features, target.in_features) - else: - new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias) + + if hasattr(target, "base_layer"): if merge: target.merge(safe_merge=safe_merge, adapter_names=adapter_names) - self._replace_module(parent, target_name, new_module, target) - - # save any additional trainable modules part of `modules_to_save` - if isinstance(target, ModulesToSaveWrapper): + self._replace_module(parent, target_name, target.get_base_layer(), target) + elif isinstance(target, ModulesToSaveWrapper): + # save any additional trainable modules part of `modules_to_save` setattr(parent, target_name, target.modules_to_save[target.active_adapter]) return self.model @@ -543,7 +492,7 @@ def add_weighted_adapter( # Do we really need that? 
_freeze_adapter(self.model, adapter_name) - key_list = [key for key, _ in self.model.named_modules() if "lora" not in key] + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] for key in key_list: _, target, _ = _get_submodules(self.model, key) if isinstance(target, LoraLayer): @@ -667,7 +616,7 @@ def delete_adapter(self, adapter_name: str): raise ValueError(f"Adapter {adapter_name} does not exist") del self.peft_config[adapter_name] - key_list = [key for key, _ in self.model.named_modules() if "lora" not in key] + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] new_adapter = None for key in key_list: _, target, _ = _get_submodules(self.model, key) diff --git a/src/peft/tuners/lycoris_utils.py b/src/peft/tuners/lycoris_utils.py index d3085c4831..5865887506 100644 --- a/src/peft/tuners/lycoris_utils.py +++ b/src/peft/tuners/lycoris_utils.py @@ -13,12 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re import warnings from abc import abstractmethod from dataclasses import dataclass, field -from itertools import chain -from typing import Dict, List, Optional, Set, Type, Union +from typing import Any, Dict, List, Optional, Set, Type, Union import torch import torch.nn as nn @@ -58,14 +56,15 @@ class LycorisConfig(PeftConfig): ) -class LycorisLayer(BaseTunerLayer, nn.Module): +class LycorisLayer(BaseTunerLayer): r""" A base layer for LyCORIS like adapters """ # adapter_layer_names needs to be defined on the child class other_param_names = ("r", "alpha", "scaling", "rank_dropout", "module_dropout") - def __init__(self): + def __init__(self, base_layer: nn.Module) -> None: + self.base_layer = base_layer self.r = {} self.alpha = {} self.scaling = {} @@ -93,48 +92,20 @@ def _init_empty_weights(self, cls, *args, **kwargs) -> None: cls.__init__(self, *args, device="meta", **kwargs) self.to_empty(device=final_device) - def _op(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: - raise NotImplementedError - @abstractmethod def create_adapter_parameters(self, adapter_name: str, r: int, **kwargs): ... - def forward(self, x: torch.Tensor) -> torch.Tensor: - previous_dtype = x.dtype - - if self.disable_adapters: - if self.merged: - self.unmerge() - result = self._op(x, self.weight) - elif self.merged: - result = self._op(x, self.weight) - else: - # Get base weights - weight = self.weight.data - - # Execute all the adapters - for active_adapter in self.active_adapters: - if active_adapter not in self._available_adapters: - continue - - module_dropout = self.module_dropout[active_adapter] - - # Modify current execution weights - if (not self.training) or (self.training and torch.rand(1) > module_dropout): - weight = weight + self.get_delta_weight(active_adapter) - - # Perform actual operation - result = self._op(x, weight) - - result = result.to(previous_dtype) - return result + # TODO: refactor LoRA to use the same approach + @abstractmethod + def _get_delta_activations(self, adapter_name: str, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + """Activations added on top of the base layer output (i.e. after the base layer forward pass)""" @abstractmethod def get_delta_weight(self, adapter_name: str) -> torch.Tensor: ... 
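For orientation, the LycorisLayer API above separates get_delta_weight (a dense update that gets merged into the base weight) from _get_delta_activations (the same update applied to activations during forward, as in the LoKr Linear._get_delta_activations earlier in this diff). The following is a minimal sketch of that split, not the PEFT implementation; ToyLycorisLinear and its parameter names are hypothetical.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyLycorisLinear(nn.Module):  # hypothetical class, for illustration only
    def __init__(self, base_layer: nn.Linear, r: int = 4) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Low-rank factors; "up" starts at zero so the initial delta is a no-op.
        self.down = nn.Parameter(torch.randn(r, base_layer.in_features) * 0.01)
        self.up = nn.Parameter(torch.zeros(base_layer.out_features, r))

    def get_delta_weight(self) -> torch.Tensor:
        # Dense update that could be added to base_layer.weight when merging.
        return self.up @ self.down

    def _get_delta_activations(self, x: torch.Tensor) -> torch.Tensor:
        # Same update applied to activations; the bias is already part of the base output.
        return F.linear(x, self.get_delta_weight())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base_layer(x) + self._get_delta_activations(x)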
- def merge(self, adapter_names: Optional[List[str]] = None) -> None: + def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: if self.merged: warnings.warn( f"Already following adapters were merged {','.join(self.merged_adapters)}. " @@ -145,7 +116,20 @@ def merge(self, adapter_names: Optional[List[str]] = None) -> None: for active_adapter in adapter_names: if active_adapter in self._available_adapters: - self.weight.data += self.get_delta_weight(active_adapter) + base_layer = self.get_base_layer() + + if safe_merge: + orig_weights = base_layer.weight.data + orig_weights += self.get_delta_weight(active_adapter) + + if not torch.isfinite(orig_weights).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + base_layer.weight.data = orig_weights + else: + base_layer.weight.data += self.get_delta_weight(active_adapter) self.merged_adapters.append(active_adapter) @abstractmethod @@ -175,7 +159,7 @@ def unmerge(self) -> None: while len(self.merged_adapters) > 0: active_adapter = self.merged_adapters.pop() if active_adapter in self._available_adapters: - self.weight.data -= self.get_delta_weight(active_adapter) + self.base_layer.weight.data -= self.get_delta_weight(active_adapter) def unscale_layer(self, scale=None) -> None: for active_adapter in self.active_adapters: @@ -214,6 +198,7 @@ def __getattr__(self, name: str): def _check_target_module_exists(config, key): return check_target_module_exists(config, key) + @abstractmethod def _create_and_replace( self, config: LycorisConfig, @@ -224,68 +209,47 @@ def _create_and_replace( current_key, **optional_kwargs, ): - """ - A private method to create and replace the target module with the adapter module. - """ - - # Regexp matching - Find key which matches current target_name in patterns provided - pattern_keys = list(chain(config.rank_pattern.keys(), config.alpha_pattern.keys())) - target_name_key = next(filter(lambda key: re.match(f"(.*\.)?{key}$", current_key), pattern_keys), target_name) - - kwargs = config.to_dict() - kwargs["r"] = config.rank_pattern.get(target_name_key, config.r) - kwargs["alpha"] = config.alpha_pattern.get(target_name_key, config.alpha) - - if isinstance(target, LycorisLayer): - target.update_layer(adapter_name, **kwargs) - else: - new_module = self._create_new_module(config, adapter_name, target, **kwargs) - self._replace_module(parent, target_name, new_module, target) + ... 
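The _create_and_replace implementation removed here now lives in each LyCORIS tuner (see LoHaModel._create_and_replace and LoKrModel._create_and_replace earlier in this diff). A standalone sketch of the rank_pattern / alpha_pattern resolution it performs follows; the config values are made up for illustration.

import re
from itertools import chain

# Hypothetical per-module overrides and defaults, for illustration only.
rank_pattern = {"q_proj": 16}
alpha_pattern = {"model.layers.0.mlp.down_proj": 64}
default_r, default_alpha = 8, 32


def resolve(current_key: str, target_name: str):
    pattern_keys = list(chain(rank_pattern.keys(), alpha_pattern.keys()))
    # A pattern key matches when current_key is exactly that key or ends with ".<key>".
    matched = next((k for k in pattern_keys if re.match(rf"(.*\.)?{k}$", current_key)), target_name)
    return rank_pattern.get(matched, default_r), alpha_pattern.get(matched, default_alpha)


print(resolve("model.layers.0.self_attn.q_proj", "q_proj"))   # (16, 32)
print(resolve("model.layers.0.mlp.down_proj", "down_proj"))   # (8, 64)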
@classmethod def _create_new_module(cls, config: LycorisConfig, adapter_name: str, target: nn.Module, **kwargs) -> LycorisLayer: # Find corresponding subtype of provided target module new_module_cls = None for subtype, target_cls in cls.layers_mapping.items(): - if isinstance(target, subtype): + if ( + hasattr(target, "base_layer") + and isinstance(target.get_base_layer(), subtype) + and isinstance(target, BaseTunerLayer) + ): + # nested tuner layers are allowed + new_module_cls = target_cls + break + elif isinstance(target, subtype): new_module_cls = target_cls break # We didn't find corresponding type, so adapter for this layer is not supported if new_module_cls is None: + supported_modules = ", ".join(layer.__name__ for layer in cls.layers_mapping.keys()) raise ValueError( - f"Target module not found, currently only adapters for {', '.join([x.__name__ for x in cls.modules_mapping.keys()])} are supported" + f"Target module of type {type(target)} not supported, " + f"currently only adapters for {supported_modules} are supported" ) - if isinstance(target, torch.nn.Conv2d): - new_module = new_module_cls( - target.in_channels, - target.out_channels, - target.weight.size()[2:], - stride=target.stride, - padding=target.padding, - dilation=target.dilation, - groups=target.groups, - bias=target.bias is not None, - padding_mode=target.padding_mode, - device=target.weight.device, - dtype=target.weight.dtype, - adapter_name=adapter_name, - **kwargs, - ) - elif isinstance(target, torch.nn.Linear): - new_module = new_module_cls( - target.in_features, - target.out_features, - bias=target.bias is not None, - device=target.weight.device, - dtype=target.weight.dtype, - adapter_name=adapter_name, - **kwargs, - ) + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() else: + target_base_layer = target + + if isinstance(target_base_layer, torch.nn.Conv2d): + new_module = new_module_cls(target, adapter_name=adapter_name, **kwargs) + elif isinstance(target_base_layer, torch.nn.Linear): + new_module = new_module_cls(target, adapter_name=adapter_name, **kwargs) + else: + supported_modules = ", ".join(layer.__name__ for layer in cls.layers_mapping.keys()) raise ValueError( - "Target module not found, currently only adapters for nn.Linear and nn.Conv2d are supported" + f"Target module of type {type(target)} not supported, " + f"currently only adapters for {supported_modules} are supported" ) return new_module @@ -305,12 +269,17 @@ def _replace_module(self, parent, child_name, new_module, child): setattr(parent, child_name, new_module) # It's not necessary to set requires_grad here, as that is handled by # _mark_only_adapters_as_trainable - new_module.weight = child.weight - if hasattr(child, "bias"): - new_module.bias = child.bias + + if not hasattr(new_module, "base_layer"): + new_module.weight = child.weight + if hasattr(child, "bias"): + new_module.bias = child.bias if getattr(child, "state", None) is not None: - new_module.state = child.state + if hasattr(new_module, "base_layer"): + new_module.base_layer.state = child.state + else: + new_module.state = child.state new_module.to(child.weight.device) # dispatch to correct device @@ -324,47 +293,30 @@ def _set_adapter_layers(self, enabled=True): module.enable_adapters(enabled) def _unload_and_optionally_merge( - self, merge=True, progressbar: bool = False, adapter_names: Optional[List[str]] = None + self, + merge: bool = True, + progressbar: bool = False, + safe_merge: bool = False, + adapter_names: Optional[List[str]] = None, ): 
if merge: if getattr(self.model, "quantization_method", None) == "gptq": raise ValueError("Cannot merge LOHA layers when the model is gptq quantized") - key_list = [key for key, _ in self.model.named_modules() if "hada" not in key] + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] desc = "Unloading " + ("and merging " if merge else "") + "model" for key in tqdm(key_list, disable=not progressbar, desc=desc): try: parent, target, target_name = _get_submodules(self.model, key) except AttributeError: continue - if isinstance(target, LycorisLayer): - if isinstance(target, nn.Conv2d): - new_module = torch.nn.Conv2d( - target.in_channels, - target.out_channels, - kernel_size=target.kernel_size, - stride=target.stride, - padding=target.padding, - dilation=target.dilation, - ) - elif isinstance(target, nn.Linear): - bias = target.bias is not None - new_module = torch.nn.Linear( - target.in_features, - target.out_features, - bias=bias, - device=target.weight.device, - ) - else: - raise ValueError( - "Cannot convert current module to torch module, currently only adapters for nn.Linear and nn.Conv2d are supported" - ) - if merge: - target.merge(adapter_names=adapter_names) - self._replace_module(parent, target_name, new_module, target) - # save any additional trainable modules part of `modules_to_save` - if isinstance(target, ModulesToSaveWrapper): + if hasattr(target, "base_layer"): + if merge: + target.merge(safe_merge=safe_merge, adapter_names=adapter_names) + self._replace_module(parent, target_name, target.get_base_layer(), target) + elif isinstance(target, ModulesToSaveWrapper): + # save any additional trainable modules part of `modules_to_save` setattr(parent, target_name, target.modules_to_save[target.active_adapter]) return self.model @@ -375,8 +327,34 @@ def enable_adapter_layers(self): def disable_adapter_layers(self): self._set_adapter_layers(enabled=False) - def merge_and_unload(self, progressbar: bool = False, adapter_names: Optional[List[str]] = None): - return self._unload_and_optionally_merge(progressbar=progressbar, adapter_names=adapter_names) + def merge_and_unload( + self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[List[str]] = None + ): + r""" + This method merges the adapter layers into the base model. This is needed if someone wants to use the base + model as a standalone model. + + Args: + progressbar (`bool`): + whether to show a progressbar indicating the unload and merge process + safe_merge (`bool`): + whether to activate the safe merging check to check if there is any potential Nan in the adapter + weights + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + + """ + return self._unload_and_optionally_merge( + progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names + ) + + def unload(self): + """ + Gets back the base model by removing all the lora modules without merging. This gives back the original base + model. 
+ """ + return self._unload_and_optionally_merge(merge=False) def set_adapter(self, adapter_name): for module in self.model.modules(): diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index 004352237f..d9616d29d6 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -20,6 +20,7 @@ from abc import ABC, abstractmethod from typing import Any, Union +import torch from torch import nn from peft.utils import COMMON_LAYERS_PATTERN @@ -312,6 +313,34 @@ class BaseTunerLayer(ABC): # List all merged adapters merged_adapters: list[str] = [] + def get_base_layer(self) -> nn.Module: + """ + (Recursively) get the base_layer. + + This is necessary for the case that the tuner layer wraps another tuner layer. + + """ + base_layer = self + while hasattr(base_layer, "base_layer"): + base_layer = base_layer.base_layer + return base_layer + + @property + def weight(self) -> torch.Tensor: + # This is required for some transformers code, e.g. for T5, weight is accessed as: + # self.wo.weight + # where "wo" is the adapter layer. + # https://github.com/huggingface/transformers/blob/78f6ed6c70b29c1560780e3869a7ad4c6b3d2710/src/transformers + # /models/t5/modeling_t5.py#L292 + base_layer = self.get_base_layer() + if hasattr(base_layer, "qweight"): + # QuantLinear + weight = base_layer.qweight + else: + # Other layers + weight = base_layer.weight + return weight + def merge(self, *args) -> None: raise NotImplementedError diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index 4f64fa4487..50f22a5523 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -277,8 +277,22 @@ def _set_trainable(model, adapter_name): def _set_adapter(model, adapter_name): + def check_adapter_name(adapter_name): + if isinstance(adapter_name, str): + return adapter_name + + # adapter_name is a list of str + if len(adapter_name) > 1: + raise ValueError("Only one adapter can be set at a time for modules_to_save") + elif len(adapter_name) == 0: + raise ValueError("Please specify at least one adapter to set") + adapter_name = adapter_name[0] + return adapter_name + for module in model.modules(): if isinstance(module, ModulesToSaveWrapper): + # only check the adapter_name if we actually encounter a ModulesToSaveWrapper, otherwise we don't care + adapter_name = check_adapter_name(adapter_name) module.set_adapter(adapter_name) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 14ae59b05c..347df218b2 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -466,6 +466,20 @@ def test_inference_safetensors(self, test_name, model_id, config_cls, config_kwa def test_peft_model_device_map(self, test_name, model_id, config_cls, config_kwargs): self._test_peft_model_device_map(model_id, config_cls, config_kwargs) + @parameterized.expand(TEST_CASES) + def test_forward_output_finite(self, test_name, model_id, config_cls, config_kwargs): + X = self.prepare_inputs_for_testing() + model = self.transformers_class.from_pretrained(model_id).to(self.torch_device) + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + model = get_peft_model(model, config) + model.eval() + with torch.no_grad(): + output = model(**X) + self.assertTrue(torch.isfinite(output).all()) + @parameterized.expand(TEST_CASES) def test_only_params_are_updated(self, test_name, model_id, config_cls, config_kwargs): # An explicit test that when using LoRA on a custom model, only the LoRA parameters are updated during training @@ -546,7 
+560,9 @@ def test_parameters_after_loading_model(self, test_name, model_id, config_cls, c
     @parameterized.expand(TEST_CASES)
     def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
         X = self.prepare_inputs_for_testing()
-        model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
+        model = self.transformers_class.from_pretrained(model_id).to(self.torch_device).eval()
+        outputs_base = model(**X)
+
         config = config_cls(
             base_model_name_or_path=model_id,
             **config_kwargs,
@@ -555,6 +571,8 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
         model.eval()
         outputs_before = model(**X)
+        self.assertTrue(torch.allclose(outputs_base, outputs_before))
+
         model.train()
         # EmbConv1D is slow to learn for some reason
         lr = 0.01 if model_id != "EmbConv1D" else 1.0
@@ -732,6 +750,67 @@ def test_non_existing_model_card(self):
         # rough check that the model card is pre-filled
         self.assertGreater(len(model_card), 1000)
 
+    @parameterized.expand(
+        [
+            LoraConfig(target_modules=["lin0"], init_lora_weights=False),
+            LoKrConfig(target_modules=["lin0"], init_weights=False),
+            LoHaConfig(target_modules=["lin0"], init_weights=False),
+            AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False),
+            IA3Config(target_modules=["lin0"], feedforward_modules=["lin0"], init_ia3_weights=False),
+        ]
+    )
+    def test_adapter_name_makes_no_difference(self, config0):
+        # It should not matter whether we use the default adapter name or a custom one
+        model_cls = MLP
+        input = torch.arange(90).reshape(9, 10).to(self.torch_device)
+
+        # base model
+        torch.manual_seed(0)
+        base_model = model_cls().eval().to(self.torch_device)
+        output_base = base_model(input)
+
+        # default name
+        torch.manual_seed(0)
+        base_model = model_cls().eval().to(self.torch_device)
+        torch.manual_seed(0)
+        peft_model_default = get_peft_model(base_model, config0, adapter_name="default").eval().to(self.torch_device)
+        output_default = peft_model_default(input)
+        sd_default = peft_model_default.state_dict()
+
+        # custom name 1
+        torch.manual_seed(0)
+        base_model = model_cls().eval().to(self.torch_device)
+        torch.manual_seed(0)
+        peft_model_custom1 = get_peft_model(base_model, config0, adapter_name="adapter").eval().to(self.torch_device)
+        output_custom1 = peft_model_custom1(input)
+        sd_custom1 = peft_model_custom1.state_dict()
+
+        # custom name 2
+        torch.manual_seed(0)
+        base_model = model_cls().eval().to(self.torch_device)
+        torch.manual_seed(0)
+        peft_model_custom2 = (
+            get_peft_model(base_model, config0, adapter_name="other-name").eval().to(self.torch_device)
+        )
+        output_custom2 = peft_model_custom2(input)
+        sd_custom2 = peft_model_custom2.state_dict()
+
+        assert len(sd_default) == len(sd_custom1) == len(sd_custom2)
+        for key in sd_default:
+            key1 = key.replace("default", "adapter")
+            key2 = key.replace("default", "other-name")
+            assert key1 in sd_custom1
+            assert key2 in sd_custom2
+        for k0, k1, k2 in zip(sd_default, sd_custom1, sd_custom2):
+            assert torch.allclose(sd_default[k0], sd_custom1[k1])
+            assert torch.allclose(sd_default[k0], sd_custom2[k2])
+
+        self.assertFalse(torch.allclose(output_base, output_default))
+        self.assertFalse(torch.allclose(output_base, output_custom1))
+        self.assertFalse(torch.allclose(output_base, output_custom2))
+        self.assertTrue(torch.allclose(output_custom1, output_custom2))
+        self.assertTrue(torch.allclose(output_default, output_custom1))
+
 
 class TestMultiRankAdapter(unittest.TestCase):
     """Tests related to multirank LoRA adapters"""
@@ -808,8 +887,9 @@ def test_repr_lora_linear(self):
         config = LoraConfig(target_modules=["lin0"])
         model = get_peft_model(MLP(), config)
         print_output = repr(model.model.lin0)
-        self.assertTrue(print_output.startswith("Linear"))
-        self.assertTrue("in_features=10, out_features=20" in print_output)
+        self.assertTrue(print_output.startswith("lora.Linear"))
+        self.assertTrue("in_features=10" in print_output)
+        self.assertTrue("out_features=20" in print_output)
         self.assertTrue("lora_A" in print_output)
         self.assertTrue("lora_B" in print_output)
         self.assertTrue("default" in print_output)
@@ -818,7 +898,7 @@ def test_repr_lora_embedding(self):
         config = LoraConfig(target_modules=["emb"])
         model = get_peft_model(ModelEmbConv1D(), config)
         print_output = repr(model.model.emb)
-        self.assertTrue(print_output.startswith("Embedding"))
+        self.assertTrue(print_output.startswith("lora.Embedding"))
         self.assertTrue("100, 5" in print_output)
         self.assertTrue("lora_embedding_A" in print_output)
         self.assertTrue("lora_embedding_B" in print_output)
@@ -828,8 +908,9 @@ def test_repr_lora_conv1d(self):
         config = LoraConfig(target_modules=["conv1d"])
         model = get_peft_model(ModelEmbConv1D(), config)
         print_output = repr(model.model.conv1d)
-        self.assertTrue(print_output.startswith("Linear"))
-        self.assertTrue("in_features=5, out_features=1" in print_output)
+        self.assertTrue(print_output.startswith("lora.Linear"))
+        self.assertTrue("in_features=5" in print_output)
+        self.assertTrue("out_features=1" in print_output)
         self.assertTrue("lora_A" in print_output)
         self.assertTrue("lora_B" in print_output)
         self.assertTrue("default" in print_output)
@@ -838,7 +919,7 @@ def test_repr_lora_conv2d(self):
         config = LoraConfig(target_modules=["conv2d"])
         model = get_peft_model(ModelConv2D(), config)
         print_output = repr(model.model.conv2d)
-        self.assertTrue(print_output.startswith("Conv2d"))
+        self.assertTrue(print_output.startswith("lora.Conv2d"))
         self.assertTrue("5, 10" in print_output)
         self.assertTrue("kernel_size=(3, 3)" in print_output)
         self.assertTrue("stride=(1, 1)" in print_output)
diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py
index a6b3d16d4d..ab49c3eea5 100644
--- a/tests/test_decoder_models.py
+++ b/tests/test_decoder_models.py
@@ -245,6 +245,7 @@ def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, co
                 "model_ids": PEFT_DECODER_MODELS_TO_TEST,
                 "lora_kwargs": {"init_lora_weights": [False]},
                 "adalora_kwargs": {"init_lora_weights": [False]},
+                "ia3_kwargs": {"init_ia3_weights": [False]},
                 "task_type": "CAUSAL_LM",
             },
             filter_params_func=skip_adalora_and_gpt2,
diff --git a/tests/test_feature_extraction_models.py b/tests/test_feature_extraction_models.py
index ce09fc6247..2b4331ae20 100644
--- a/tests/test_feature_extraction_models.py
+++ b/tests/test_feature_extraction_models.py
@@ -156,6 +156,7 @@ def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_k
                 "model_ids": PEFT_FEATURE_EXTRACTION_MODELS_TO_TEST,
                 "lora_kwargs": {"init_lora_weights": [False]},
                 "adalora_kwargs": {"init_lora_weights": [False]},
+                "ia3_kwargs": {"init_ia3_weights": [False]},
                 "task_type": "FEATURE_EXTRACTION",
             },
         )
diff --git a/tests/testing_common.py b/tests/testing_common.py
index 2c4a4f5b2b..e3a7040e1e 100644
--- a/tests/testing_common.py
+++ b/tests/testing_common.py
@@ -968,12 +968,12 @@ def _test_unload_adapter(self, model_id, config_cls, config_kwargs):
         model = get_peft_model(model, config)
         model = model.to(self.torch_device)
 
-        if config.peft_type not in ("LORA", "ADALORA"):
+        if config.peft_type not in ("LORA", "ADALORA", "IA3"):
             with self.assertRaises(AttributeError):
                 model = model.unload()
         else:
             dummy_input = self.prepare_inputs_for_testing()
-            logits_with_lora = model(**dummy_input)[0]
+            logits_with_adapter = model(**dummy_input)[0]
 
             transformers_model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
             logits_transformers = transformers_model(**dummy_input)[0]
@@ -982,7 +982,7 @@ def _test_unload_adapter(self, model_id, config_cls, config_kwargs):
             model = model.unload()
             logits_unload = model(**dummy_input)[0]
 
-            self.assertFalse(torch.allclose(logits_with_lora, logits_unload, atol=1e-10, rtol=1e-10))
+            self.assertFalse(torch.allclose(logits_with_adapter, logits_unload, atol=1e-10, rtol=1e-10))
             self.assertTrue(torch.allclose(logits_transformers, logits_unload, atol=1e-4, rtol=1e-4))
 
     def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kwargs):
@@ -992,13 +992,14 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw
         adapter_list = ["adapter1", "adapter_2", "adapter_3"]
         weight_list = [0.5, 1.5, 1.5]
-        model = self.transformers_class.from_pretrained(model_id)
         config = config_cls(
             base_model_name_or_path=model_id,
             **config_kwargs,
         )
         if not isinstance(config, (LoraConfig)):
            return
+
+        model = self.transformers_class.from_pretrained(model_id)
         model = get_peft_model(model, config, adapter_list[0])
         model.add_adapter(adapter_list[1], config)
         model.add_adapter(adapter_list[2], replace(config, r=20))
@@ -1113,7 +1114,7 @@ def get_output(model):
         # must be False
         if isinstance(peft_model, StableDiffusionPipeline):
             # for SD, check that most pixels have different values
-            self.assertTrue((output_before != output_peft).float().mean() > 0.9)
+            self.assertTrue((output_before != output_peft).float().mean() > 0.8)
         else:
             self.assertFalse(torch.allclose(output_before, output_peft))