
Commit 7a2f46e

patrickvonplaten and andres authored and committed
[Lora] Seperate logic (huggingface#5809)
* [Lora] Seperate logic
* [Lora] Seperate logic
* [Lora] Seperate logic
* add comments to explain the code better
* add comments to explain the code better
1 parent 3b56e1d commit 7a2f46e

File tree

7 files changed: +219 -54 lines changed

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 34 additions & 1 deletion
@@ -57,7 +57,7 @@
     AttnAddedKVProcessor2_0,
     SlicedAttnAddedKVProcessor,
 )
-from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
+from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import unet_lora_state_dict
 from diffusers.utils import check_min_version, is_wandb_available
@@ -70,6 +70,39 @@
 logger = get_logger(__name__)


+# TODO: This function should be removed once training scripts are rewritten in PEFT
+def text_encoder_lora_state_dict(text_encoder):
+    state_dict = {}
+
+    def text_encoder_attn_modules(text_encoder):
+        from transformers import CLIPTextModel, CLIPTextModelWithProjection
+
+        attn_modules = []
+
+        if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+            for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+                name = f"text_model.encoder.layers.{i}.self_attn"
+                mod = layer.self_attn
+                attn_modules.append((name, mod))
+
+        return attn_modules
+
+    for name, module in text_encoder_attn_modules(text_encoder):
+        for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+
+    return state_dict
+
+
 def save_model_card(
     repo_id: str,
     images=None,
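
Note: the training script passes the dict returned by this vendored helper straight to LoraLoaderMixin.save_lora_weights. A minimal sketch of that save step, assuming unet, text_encoder, and output_dir are already set up as elsewhere in train_dreambooth_lora.py:

# Sketch only: `unet`, `text_encoder`, and `output_dir` are assumed to exist as in the training script.
from diffusers.loaders import LoraLoaderMixin
from diffusers.training_utils import unet_lora_state_dict

unet_lora_layers = unet_lora_state_dict(unet)  # LoRA weights attached to the UNet
text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)  # the vendored helper above

LoraLoaderMixin.save_lora_weights(
    save_directory=output_dir,
    unet_lora_layers=unet_lora_layers,
    text_encoder_lora_layers=text_encoder_lora_layers,
)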

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 34 additions & 1 deletion
@@ -50,7 +50,7 @@
     UNet2DConditionModel,
 )
 from diffusers.loaders import LoraLoaderMixin
-from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
+from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr, unet_lora_state_dict
 from diffusers.utils import check_min_version, is_wandb_available
@@ -63,6 +63,39 @@
 logger = get_logger(__name__)


+# TODO: This function should be removed once training scripts are rewritten in PEFT
+def text_encoder_lora_state_dict(text_encoder):
+    state_dict = {}
+
+    def text_encoder_attn_modules(text_encoder):
+        from transformers import CLIPTextModel, CLIPTextModelWithProjection
+
+        attn_modules = []
+
+        if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+            for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+                name = f"text_model.encoder.layers.{i}.self_attn"
+                mod = layer.self_attn
+                attn_modules.append((name, mod))
+
+        return attn_modules
+
+    for name, module in text_encoder_attn_modules(text_encoder):
+        for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+
+    return state_dict
+
+
 def save_model_card(
     repo_id: str,
     images=None,
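
Note: the SDXL script has two text encoders, so the vendored helper is called once per encoder. A hedged sketch of the corresponding save step, assuming unet, text_encoder_one, text_encoder_two, and output_dir are defined as in train_dreambooth_lora_sdxl.py:

# Sketch only: the surrounding objects come from the SDXL training script.
from diffusers import StableDiffusionXLPipeline
from diffusers.training_utils import unet_lora_state_dict

StableDiffusionXLPipeline.save_lora_weights(
    save_directory=output_dir,
    unet_lora_layers=unet_lora_state_dict(unet),
    text_encoder_lora_layers=text_encoder_lora_state_dict(text_encoder_one),
    text_encoder_2_lora_layers=text_encoder_lora_state_dict(text_encoder_two),
)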

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 73 additions & 25 deletions
@@ -40,8 +40,7 @@

 import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
-from diffusers.loaders import AttnProcsLayers
-from diffusers.models.attention_processor import LoRAAttnProcessor
+from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
@@ -54,6 +53,39 @@
 logger = get_logger(__name__, log_level="INFO")


+# TODO: This function should be removed once training scripts are rewritten in PEFT
+def text_encoder_lora_state_dict(text_encoder):
+    state_dict = {}
+
+    def text_encoder_attn_modules(text_encoder):
+        from transformers import CLIPTextModel, CLIPTextModelWithProjection
+
+        attn_modules = []
+
+        if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+            for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+                name = f"text_model.encoder.layers.{i}.self_attn"
+                mod = layer.self_attn
+                attn_modules.append((name, mod))
+
+        return attn_modules
+
+    for name, module in text_encoder_attn_modules(text_encoder):
+        for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+
+    return state_dict
+
+
 def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
     img_str = ""
     for i, image in enumerate(images):
@@ -458,25 +490,43 @@ def main():
     # => 32 layers

     # Set correct lora layers
-    lora_attn_procs = {}
-    for name in unet.attn_processors.keys():
-        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-        if name.startswith("mid_block"):
-            hidden_size = unet.config.block_out_channels[-1]
-        elif name.startswith("up_blocks"):
-            block_id = int(name[len("up_blocks.")])
-            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
-        elif name.startswith("down_blocks"):
-            block_id = int(name[len("down_blocks.")])
-            hidden_size = unet.config.block_out_channels[block_id]
-
-        lora_attn_procs[name] = LoRAAttnProcessor(
-            hidden_size=hidden_size,
-            cross_attention_dim=cross_attention_dim,
-            rank=args.rank,
+    unet_lora_parameters = []
+    for attn_processor_name, attn_processor in unet.attn_processors.items():
+        # Parse the attention module.
+        attn_module = unet
+        for n in attn_processor_name.split(".")[:-1]:
+            attn_module = getattr(attn_module, n)
+
+        # Set the `lora_layer` attribute of the attention-related matrices.
+        attn_module.to_q.set_lora_layer(
+            LoRALinearLayer(
+                in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
+            )
+        )
+        attn_module.to_k.set_lora_layer(
+            LoRALinearLayer(
+                in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
+            )
+        )
+
+        attn_module.to_v.set_lora_layer(
+            LoRALinearLayer(
+                in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
+            )
+        )
+        attn_module.to_out[0].set_lora_layer(
+            LoRALinearLayer(
+                in_features=attn_module.to_out[0].in_features,
+                out_features=attn_module.to_out[0].out_features,
+                rank=args.rank,
+            )
         )

-    unet.set_attn_processor(lora_attn_procs)
+        # Accumulate the LoRA params to optimize.
+        unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
+        unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
+        unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
+        unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())

     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
@@ -491,8 +541,6 @@ def main():
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")

-    lora_layers = AttnProcsLayers(unet.attn_processors)
-
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if args.allow_tf32:
@@ -517,7 +565,7 @@
         optimizer_cls = torch.optim.AdamW

     optimizer = optimizer_cls(
-        lora_layers.parameters(),
+        unet_lora_parameters,
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
@@ -644,8 +692,8 @@ def collate_fn(examples):
     )

     # Prepare everything with our `accelerator`.
-    lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        lora_layers, optimizer, train_dataloader, lr_scheduler
+    unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
     )

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -777,7 +825,7 @@ def collate_fn(examples):
                 # Backpropagate
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    params_to_clip = lora_layers.parameters()
+                    params_to_clip = unet_lora_parameters
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
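
Note: the change above is the core of the commit for this script: instead of building LoRAAttnProcessor objects, LoRA weights are attached directly to the UNet's attention projections. A minimal standalone sketch of that pattern, with a placeholder checkpoint and rank (the script itself uses args.rank):

import torch
from diffusers import UNet2DConditionModel
from diffusers.models.lora import LoRALinearLayer

# Placeholder checkpoint and rank; not part of the commit.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
rank = 4

unet_lora_parameters = []
for attn_processor_name, _ in unet.attn_processors.items():
    # Walk from the UNet root down to the attention module that owns this processor.
    attn_module = unet
    for n in attn_processor_name.split(".")[:-1]:
        attn_module = getattr(attn_module, n)

    # Attach a LoRA layer to each projection and collect its parameters for the optimizer.
    for proj in (attn_module.to_q, attn_module.to_k, attn_module.to_v, attn_module.to_out[0]):
        proj.set_lora_layer(LoRALinearLayer(proj.in_features, proj.out_features, rank=rank))
        unet_lora_parameters.extend(proj.lora_layer.parameters())

optimizer = torch.optim.AdamW(unet_lora_parameters, lr=1e-4)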

examples/text_to_image/train_text_to_image_lora_sdxl.py

Lines changed: 34 additions & 1 deletion
@@ -50,7 +50,7 @@
     UNet2DConditionModel,
 )
 from diffusers.loaders import LoraLoaderMixin
-from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
+from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
@@ -63,6 +63,39 @@
 logger = get_logger(__name__)


+# TODO: This function should be removed once training scripts are rewritten in PEFT
+def text_encoder_lora_state_dict(text_encoder):
+    state_dict = {}
+
+    def text_encoder_attn_modules(text_encoder):
+        from transformers import CLIPTextModel, CLIPTextModelWithProjection
+
+        attn_modules = []
+
+        if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+            for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+                name = f"text_model.encoder.layers.{i}.self_attn"
+                mod = layer.self_attn
+                attn_modules.append((name, mod))
+
+        return attn_modules
+
+    for name, module in text_encoder_attn_modules(text_encoder):
+        for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+
+    return state_dict
+
+
 def save_model_card(
     repo_id: str,
     images=None,

src/diffusers/loaders/__init__.py

Lines changed: 2 additions & 3 deletions
@@ -8,7 +8,7 @@ def text_encoder_lora_state_dict(text_encoder):
     deprecate(
         "text_encoder_load_state_dict in `models`",
         "0.27.0",
-        "`text_encoder_lora_state_dict` has been moved to `diffusers.models.lora`. Please make sure to import it via `from diffusers.models.lora import text_encoder_lora_state_dict`.",
+        "`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
     )
     state_dict = {}

@@ -34,7 +34,7 @@ def text_encoder_attn_modules(text_encoder):
     deprecate(
         "text_encoder_attn_modules in `models`",
         "0.27.0",
-        "`text_encoder_lora_state_dict` has been moved to `diffusers.models.lora`. Please make sure to import it via `from diffusers.models.lora import text_encoder_lora_state_dict`.",
+        "`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
     )
     from transformers import CLIPTextModel, CLIPTextModelWithProjection

@@ -67,7 +67,6 @@ def text_encoder_attn_modules(text_encoder):

 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
-        from ..models.lora import text_encoder_lora_state_dict
         from .single_file import FromOriginalControlnetMixin, FromOriginalVAEMixin
         from .unet import UNet2DConditionLoadersMixin
         from .utils import AttnProcsLayers
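
Note: the updated deprecation message points to PEFT as the replacement for the removed re-export. A rough sketch of that PEFT path, assuming the peft package is installed; the checkpoint and LoRA hyperparameters below are illustrative, not prescribed by this commit:

from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
from transformers import CLIPTextModel

# Placeholder checkpoint; any CLIP text encoder works the same way.
text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")

# Target the same projections the deprecated helper used to collect.
lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"])
text_encoder = get_peft_model(text_encoder, lora_config)

# PEFT analogue of `text_encoder_lora_state_dict`: only the LoRA weights.
lora_state_dict = get_peft_model_state_dict(text_encoder)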

src/diffusers/loaders/lora.py

Lines changed: 33 additions & 4 deletions
@@ -47,9 +47,10 @@


 if is_transformers_available():
-    from transformers import PreTrainedModel
+    from transformers import CLIPTextModel, CLIPTextModelWithProjection

-    from ..models.lora import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
+    # To be deprecated soon
+    from ..models.lora import PatchedLoraProjection

 if is_accelerate_available():
     from accelerate import init_empty_weights
@@ -66,6 +67,34 @@
 LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future."


+def text_encoder_attn_modules(text_encoder):
+    attn_modules = []
+
+    if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+        for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+            name = f"text_model.encoder.layers.{i}.self_attn"
+            mod = layer.self_attn
+            attn_modules.append((name, mod))
+    else:
+        raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")
+
+    return attn_modules
+
+
+def text_encoder_mlp_modules(text_encoder):
+    mlp_modules = []
+
+    if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+        for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+            mlp_mod = layer.mlp
+            name = f"text_model.encoder.layers.{i}.mlp"
+            mlp_modules.append((name, mlp_mod))
+    else:
+        raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")
+
+    return mlp_modules
+
+
 class LoraLoaderMixin:
     r"""
     Load LoRA layers into [`UNet2DConditionModel`] and [`~transformers.CLIPTextModel`].
@@ -1415,7 +1444,7 @@ def process_weights(adapter_names, weights):
             )
         set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)

-    def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
+    def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):  # noqa: F821
         """
         Disable the text encoder's LoRA layers.

@@ -1445,7 +1474,7 @@ def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"
             raise ValueError("Text Encoder not found.")
         set_adapter_layers(text_encoder, enabled=False)

-    def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
+    def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):  # noqa: F821
         """
         Enables the text encoder's LoRA layers.

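Note: a small sketch of how the two helpers added to lora.py can be exercised directly; the checkpoint is a placeholder and the prints are purely illustrative:

from transformers import CLIPTextModel

from diffusers.loaders.lora import text_encoder_attn_modules, text_encoder_mlp_modules

# Placeholder checkpoint; the helpers accept CLIPTextModel and CLIPTextModelWithProjection.
text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")

# Each helper returns (qualified_name, module) pairs that LoraLoaderMixin later patches with LoRA projections.
for name, attn in text_encoder_attn_modules(text_encoder):
    print(name, type(attn).__name__)  # e.g. text_model.encoder.layers.0.self_attn CLIPAttention

for name, mlp in text_encoder_mlp_modules(text_encoder):
    print(name, type(mlp).__name__)  # e.g. text_model.encoder.layers.0.mlp CLIPMLP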