
Conversation

BenjaminBossan (Member) commented Nov 9, 2023

Refactor all tuners (where it applies) to use the "base layer pattern". This means that the adapter layer will always hold a reference to the original layer that it modifies. This pattern is already partly used (e.g. LoRA bnb, gptq layers), now it is consistently used everywhere it makes sense.

This PR is a companion PR to #1069, where I first added these changes. They are now extracted to a separate PR to make code review easier and to advance more quickly.

Implementation

The main change is that the adapter layer wraps the original layer and calls forward on that layer, instead of doing stuff like this:

F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

which completely circumvents the call to the target layer's forward method. With the base layer pattern, we now call the target layer's forward method. Therefore, if the target layer is another adapter layer (which will be crucial for mixed adapters), we call its forward method correctly. Also, this should allow passing extra arguments, like lora_scale to forward.
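
As a rough illustration (a simplified sketch, not the exact PEFT code), an adapter layer following the base layer pattern looks roughly like this:

import torch
import torch.nn as nn

class LoraLinear(nn.Module):
    def __init__(self, base_layer: nn.Linear, r: int = 8, lora_alpha: int = 8):
        super().__init__()
        self.base_layer = base_layer  # keep a reference to the original layer
        self.lora_A = nn.Linear(base_layer.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base_layer.out_features, bias=False)
        self.scaling = lora_alpha / r

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # delegate to the wrapped layer's own forward instead of re-implementing it;
        # this keeps nested adapter layers and extra forward arguments working
        result = self.base_layer(x, *args, **kwargs)
        return result + self.lora_B(self.lora_A(x)) * self.scaling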

This change has the nice side benefit that we no longer need to use _init_empty_weights -- in fact, we don't initialize any of the target layer's weights anymore, since we have a reference to it. There is thus no risk of having slow but superfluous initialization of layers.

Moreover, I could greatly simplify merge_and_unload by just using the base_layer instead of having to create a completely new layer. I haven't measured it, but this should speed up the operation considerably.
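
Conceptually (again a simplified sketch, not the actual merge_and_unload implementation), unloading now boils down to swapping each adapter layer for the base layer it already holds:

import torch.nn as nn

def unload(model: nn.Module, merge: bool = True) -> nn.Module:
    # walk the model and replace every adapter layer with its wrapped base layer
    for name, module in list(model.named_modules()):
        if not hasattr(module, "base_layer"):
            continue  # not an adapter layer
        if merge:
            module.merge()  # fold the adapter weights into the base layer first
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, module.base_layer)  # reuse the existing layer
    return model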

Note that, same as for the bnb layers, this change should be backwards compatible, since the adapter weights and their state_dicts are not affected by it.

Somewhat unrelated changes

  1. During debugging, I got very annoyed with the fact that the reprs of adapter layers and normal PyTorch layers are hard to distinguish, e.g. the type is just "Linear". Now, for adapter layers, it is prefixed by the adapter type, e.g. "lora.Linear".
  2. For LoHa and LoKr, I had to change the init of weights when using init_weights=False. This is because of what is discussed in #1058 (Numerical instabilities with LoHa).

TODOs

Before merging this, I would like to do some regression tests to ensure that everything still works as expected. This is especially important to ensure that old adapter checkpoints can still be loaded. In theory, this should work for the same reason as discussed in #994, but better safe than sorry.

This is a POC to show how we could achieve mixing adapter types such as
LoRA and LoKr.

Description

The very general idea is that we can already mix multiple adapters of
the same type, e.g. add two LoRA adapters, but right now we fail when
trying to mix different types. This restriction has been lifted by
adding a new class PeftMixedModel which deals with different adapter
types.

The usage looks something like this:

    base_model = ...
    config0 = LoraConfig(...)
    # set mixed=True
    peft_model = get_peft_model(base_model, config0, mixed=True)
    config1 = LoHaConfig(...)
    peft_model.add_adapter(config1, "other")
    peft_model.set_adapter(["default", "other"])

At this point, both adapters are active at the same time.

Existing code should not be affected by this change, since users need to
opt into this behavior by setting mixed=True.

Also interesting is that this method can be used for a single adapter
type but with very different configs. Right now, we have limited support
for that (e.g. for LoRA, different r values by using rank_pattern), but
with this, we don't need to special case the differing arguments anymore.
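
For illustration, mixing two LoRA adapters with different ranks would then
look like this (a sketch mirroring the snippet above; the configs are just
placeholders):

    base_model = ...
    config0 = LoraConfig(r=8, target_modules=["q_proj"])
    peft_model = get_peft_model(base_model, config0, mixed=True)
    # a second LoRA adapter with a different rank and different targets,
    # no rank_pattern special-casing required
    config1 = LoraConfig(r=32, target_modules=["k_proj", "v_proj"])
    peft_model.add_adapter(config1, "other")
    peft_model.set_adapter(["default", "other"])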

Implementation

Apart from adding the new PeftMixedModel class to replace PeftModel, I
added a new class LycorisModel which replaces LoraModel, LoHaModel etc.
This class checks the config type and then uses the corresponding
LoraModel, LoHaModel etc. to create the adapter.
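
As a hypothetical sketch of that dispatch (names and signatures are
illustrative, not the actual implementation):

    from peft import LoHaConfig, LoraConfig
    from peft.tuners.loha import LoHaModel
    from peft.tuners.lora import LoraModel

    CONFIG_TO_MODEL = {
        LoraConfig: LoraModel,
        LoHaConfig: LoHaModel,
    }

    def create_adapter(model, config, adapter_name):
        # pick the tuner class matching the config type and let it inject
        # its adapter layers into the model
        for config_cls, model_cls in CONFIG_TO_MODEL.items():
            if isinstance(config, config_cls):
                return model_cls(model, config, adapter_name)
        raise ValueError(f"Unsupported config type: {type(config).__name__}")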

Another crucial change I had to make was to adopt the "base layer
pattern". This is the pattern that was, for instance, used to speed up
initialization in LoRA bnb layers in PR huggingface#994.

The main change is that the adapter layer wraps the original layer and
calls forward on that layer, instead of doing stuff like this:

    F.linear(
        input, transpose(self.weight, self.fan_in_fan_out)
    )

which completely circumvents the call to the target layer's forward
method. With the base layer pattern, we now call the target layer's
forward method. Therefore, if the target layer is another adapter layer,
we call its forward method correctly.

This change has the nice side benefit that we no longer need to use
_init_empty_weights -- in fact, we don't initialize any of the target
layer's weights anymore, since we have a reference to it.

Note that, same as for the bnb layers, this change should be backwards
compatible, since the adapter weights and their state_dicts are not
affected by it.

Somewhat unrelated changes

During debugging, I got very annoyed with the fact that the reprs of
adapter layers and normal PyTorch layers are hard to distinguish, e.g.
the type is just "Linear". Now, for adapter layers, it is prefixed by
the adapter type, e.g. "lora.Linear".

TODOs

- [ ] For now, I only added this capability for LoRA and LoHa as a POC.
  It needs to be added to LoKr and AdaLora too.
- [ ] The unit tests are very rudimentary right now, only a simple model
  is tested in two settings.
- [ ] There is no documentation so far.
- [ ] I'm not yet sure if the same logic can be applied to IA³ or if it
  may fail because IA³ can apply its scaling to the input, not the output.
- [ ] It is currently not possible to represent a mixed adapter model as
  a single config. I think we can come up with a solution but I don't
  think it is necessary for a first version of this.
Docs don't build otherwise...
Some tests are still failing, probably because of mix-ups when the
AdaLora checks are performed: they check for LoRA instances, which can
result in false positives when the tuners are mixed.
Seems to work only on newer Python versions.
In update_layer, not in __init__.
Decreased tolerance, as one test currently fails on Windows, presumably
due to precision.

Also, better test function names by providing a name_func.
As side effects, added the prefix attribute on LoRA for consistency and
added safe merging on LoHa, LoKr.
Add test for deeply nested models
This should be much quicker because we don't create new layers but
instead simply return the existing base layer.
HuggingFaceDocBuilderDev commented Nov 9, 2023

The documentation is not available anymore as the PR was closed or merged.

BenjaminBossan (Member, Author) commented Nov 10, 2023

Update: Based on my regression testing PR (#1115), I created a bunch of regression tests to check this PR against v0.6.1. This means that I created adapters with v0.6.1 and stored them, as well as the model output. Then I switched to this PR, loaded the adapters successfully and also produced the exact same output.

This confirms that the PR does not break existing code or checkpoints, which is what we expected based on how the changes are implemented, and based on the same refactor on bnb layers we merged earlier. This gives me high confidence that this PR is not BC breaking.

Here is the list of settings that I tested:

  • adalora_mlp
  • adalora_opt-350m
  • ia3_conv2d
  • ia3_mlp
  • ia3_no_ff_mlp
  • ia3_opt-350m
  • lora_conv2d
  • lora_emb_conv1d
  • lora_mlp
  • lora_mlp_modules_to_save
  • lora_opt-350m
  • lora_opt-350m_bnb_4bit
  • lora_opt-350m_bnb_8bit

Note that I couldn't test a couple of settings:

  1. LoHa and LoKr: When using init_weights=False, they are initialized with torch.empty, thus results are not reproducible; see #1058 (Numerical instabilities with LoHa). In this very PR, that behavior is changed so that in the future, we get reproducible results with those tuners.
  2. AdaLoRA with 4bit and 8bit. For 8bit, I got TypeError: unsupported operand type(s) for +=: 'dict' and 'Tensor'; for 4bit, I got the issue described in #1113 (AdaLora + bnb not working).

It would have been nice to include those, but I see very little danger that we get BC breakage in those.
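
For reference, each regression check follows roughly this pattern (a simplified sketch; the paths are placeholders, not the actual test code from #1115):

import torch
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM

torch.manual_seed(0)
input_ids = torch.arange(10).view(1, -1)

# step 1, run on v0.6.1: create an adapter and store it together with its output
base = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
model = get_peft_model(base, LoraConfig(r=8, init_lora_weights=False)).eval()
model.save_pretrained("regression/lora_opt-350m")
torch.save(model(input_ids).logits, "regression/lora_opt-350m/output.pt")

# step 2, run on this PR: load the stored adapter and check the output matches
base = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
model = PeftModel.from_pretrained(base, "regression/lora_opt-350m").eval()
expected = torch.load("regression/lora_opt-350m/output.pt")
torch.testing.assert_close(model(input_ids).logits, expected)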

Update2: I did a quick check regarding the speed of unloading the model with this PR. For opt-350m, I got:

  • 0.2966 seconds
  • 0.3061 seconds
  • 0.2943 seconds

on main and

  • 0.0189 seconds
  • 0.0183 seconds
  • 0.0181 seconds

on this PR. This means we get a 15x speedup for unloading. Code:

import time
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m").to("cuda")
config = LoraConfig(
    r=8,
    init_lora_weights=False,
)
torch.manual_seed(42)
model = get_peft_model(model, config)
tic = time.perf_counter()
model = model.merge_and_unload()
toc = time.perf_counter()
print(f"Unloading the model took {toc - tic:0.4f} seconds")

BenjaminBossan marked this pull request as ready for review November 10, 2023 17:08
BenjaminBossan (Member, Author) left a comment

@pacman100 your comments should be addressed now.

Furthermore, added a couple of missing things:

  • AdaLora bnb was finished
  • IA3 was finished, now using base layer pattern throughout
    • IA3 init should now be fast
    • IA3 now also supports unload()
    • IA3 unloading should now be fast
  • A few reprs were added; they should be consistent now
  • GPTQ should now work correctly

BenjaminBossan (Member, Author) left a comment

@younesbelkada Thanks for the review. You brought up some good comments about backwards compatibility in case users use certain classes directly (not through get_peft_model etc.). I think it's basically impossible to make the changes suggested here without breaking this very low level API. I believe there will be very few, if any, users who do this (maybe transformers integration??). Also note that we already adopted this pattern partly (e.g. bnb LoRA layers) and so far there have been no complaints about it. Therefore, my suggestion would be to accept this possible breaking change.

However, what we could do is to add a migration guide somewhere (maybe docs?). There, I could write down what users who actually use these classes directly should be doing to make their code work. Do you think this will be good enough?

from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit
from .multitask_prompt_tuning import MultitaskPromptEmbedding, MultitaskPromptTuningConfig, MultitaskPromptTuningInit

# Mapping of tuners that support direct plugging
BenjaminBossan (Member, Author) commented:

It's not being used anywhere. Removing this is not related to this PR specifically.

if is_bnb_4bit_available():

-    class SVDLinear4bit(bnb.nn.Linear4bit, AdaLoraLayer):
+    class SVDLinear4bit(torch.nn.Module, AdaLoraLayer):
BenjaminBossan (Member, Author) commented:

Technically true, although I'm not sure how much we can consider inheritance structure of our classes to be "public API". I think we can assume that very few users do this type of check (if any), and I hope those few expert users can quickly figure out what is wrong. If you are aware of anyone doing this, let me know and I can prepare something like a migration guide.

adapter_name,
in_features,
out_features,
base_layer: torch.nn.Module,
BenjaminBossan (Member, Author) commented:

Similar comment as above, though here we could change the order of adapter_name and base_layer, which would, however, be inconsistent with the other layers. Nevertheless, even if we change the order, a caller would still get an error here because we now have an additional argument, the base_layer, that they don't pass. So either way, there would be an error.

self.weight = quant_linear_module.qweight
init_lora_weights = kwargs.pop("init_lora_weights", True)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.set_adapter(adapter_name)
BenjaminBossan (Member, Author) commented:

set_adapter is already called inside of update_layer just above, so it was in fact called twice in a row.

Comment on lines +185 to +187
def __repr__(self) -> str:
rep = super().__repr__()
return "adalora." + rep
BenjaminBossan (Member, Author) commented:

I see, yes, I could. The only issue is that we have this method now more than a dozen times, so I would have to add the same comment very often. If you still think it's good to have, or have a better idea, LMK and I'll change it!

There was a bit of unnecessary back and forth transposing of the
weights, instead it is simpler to transpose the IA3 values.
pacman100 (Contributor) left a comment

Thank you @BenjaminBossan for the great work, LGTM! 🚀

younesbelkada (Contributor) left a comment

Impressive and inspiring work @BenjaminBossan !

BenjaminBossan merged commit 5a3a5ac into huggingface:main Nov 16, 2023
BenjaminBossan deleted the refactor-base-layer-pattern branch November 16, 2023 11:45
emmahone pushed a commit to emmahone/peft that referenced this pull request Nov 27, 2023
This is important if we have nested adapter layers. This was an
oversight during the refactoring in huggingface#1106.
BenjaminBossan pushed a commit that referenced this pull request Jan 2, 2024
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Jan 11, 2024
Resolves huggingface#1345

See also huggingface#1294 for a similar (but incomplete) fix.

This commit fixes the setting of the adapter name on a couple of
quantized layers that was accidentally removed in huggingface#1106. This affects
users who use a non-default adapter name when they want to train these
layers.
BenjaminBossan added a commit that referenced this pull request Jan 12, 2024
Resolves #1345

See also #1294 for a similar (but incomplete) fix.

This commit fixes the setting of the adapter name on a couple of
quantized layers that was accidentally removed in #1106. This affects
users who use a non-default adapter name when they want to train these
layers.

---------

Co-authored-by: Sourab Mangrulkar <[email protected]>
Guy-Bilitski pushed a commit to Guy-Bilitski/peft that referenced this pull request May 13, 2025
Description

Refactor all tuners (where it applies, i.e. not prompt tuning) to use
the "base layer pattern". This means that the adapter layer will always
hold a reference to the original layer that it modifies. This pattern is
already partly used (e.g. LoRA bnb, gptq layers), now it is consistently
used everywhere when applicable.

This PR is a companion PR to huggingface#1069, where I first added these changes.
They are now extracted to a separate PR to make code review easier and
to advance more quickly.

Implementation

The main change is that the adapter layer wraps the original layer and
calls forward on that layer, instead of doing stuff like this:

F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

which completely circumvents the call to the target layer's forward
method. With the base layer pattern, we now call the target layer's
forward method. Therefore, if the target layer is another adapter
layer (which will be crucial for mixed adapters), we call its forward
method correctly. Also, this should allow passing extra arguments, like
lora_scale to forward.

This change has the nice side benefit that we no longer need to use
_init_empty_weights -- in fact, we don't initialize any of the target
layer's weights anymore, since we have a reference to it. There is thus
no risk of having slow but superfluous initialization of layers.

Moreover, I could greatly simplify merge_and_unload by just using the
base_layer instead of having to create a completely new layer. For
OPT-350m, this results in a 15x speedup.

Note that, same as for the bnb layers, this change should be backwards
compatible, since the adapter weights and their state_dicts are not
affected by it. I used huggingface#1115 for regression testing.

Somewhat unrelated changes

During debugging, I got very annoyed with the fact that the reprs of
adapter layers and normal PyTorch layers are hard to distinguish, e.g.
the type is just "Linear". Now, for adapter layers, it is prefixed by
the adapter type, e.g. "lora.Linear". This should have no further
implications except for the repr (e.g. state_dict remains unaffected).

For LoHa and LoKr, I had to change the init of weights when using
init_weights=False. This is because of what is discussed in
huggingface#1058 (Numerical instabilities with LoHa).

IA³ now has the unload method too.

LoHa and LoKr now support safe_merge=True when merging layers.

Migration guide

For 99% of users, the code should continue working as usual, because
the API stays the same. Only low-level details have been changed.

Code that relies on isinstance checks on specific PEFT classes may
break. E.g. the LoRA Linear layer no longer inherits from nn.Linear. It
is, however, still a BaseTunerLayer. The same logic applies for other
layer types like Conv2d and for other tuners like IA³.

To retrieve the base layer of an adapter layer, you should now call
module.get_base_layer() if you deal with a BaseTunerLayer. Don't rely on
something like module.weight being present (though it might be).
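
A minimal sketch of that migration (illustrative, not code from the PR):

import torch.nn as nn
from peft.tuners.tuners_utils import BaseTunerLayer

def get_linear_weight(module: nn.Module):
    # before this refactor, lora.Linear inherited from nn.Linear, so a plain
    # isinstance(module, nn.Linear) check was enough; now, unwrap tuner layers first
    if isinstance(module, BaseTunerLayer):
        module = module.get_base_layer()
    if isinstance(module, nn.Linear):
        return module.weight
    return None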
Guy-Bilitski pushed a commit to Guy-Bilitski/peft that referenced this pull request May 13, 2025
This is important if we have nested adapter layers. This was an
oversight during the refactoring in huggingface#1106.
Guy-Bilitski pushed a commit to Guy-Bilitski/peft that referenced this pull request May 13, 2025
Guy-Bilitski pushed a commit to Guy-Bilitski/peft that referenced this pull request May 13, 2025
Resolves huggingface#1345

See also huggingface#1294 for a similar (but incomplete) fix.

This commit fixes the setting of the adapter name on a couple of
quantized layers that was accidentally removed in huggingface#1106. This affects
users who use a non-default adapter name when they want to train these
layers.

---------

Co-authored-by: Sourab Mangrulkar <[email protected]>