Merged
25 commits
2e54b93
configurable layers
linoytsaban Oct 24, 2024
8fb8fc0
Merge branch 'main' into sd-3-5-explorations
linoytsaban Oct 25, 2024
df919b8
configurable layers
linoytsaban Oct 25, 2024
dfd8897
update README
linoytsaban Oct 25, 2024
75c12a9
Merge branch 'main' into sd-3-5-explorations
linoytsaban Oct 25, 2024
62e152a
style
linoytsaban Oct 25, 2024
e285d69
add test
linoytsaban Oct 25, 2024
f886565
Merge remote-tracking branch 'origin/sd-3-5-explorations' into sd-3-5…
linoytsaban Oct 25, 2024
f073014
style
linoytsaban Oct 25, 2024
701bd35
add layer test, update readme, add nargs
linoytsaban Oct 25, 2024
90550a8
readme
linoytsaban Oct 25, 2024
2cba0c9
test style
linoytsaban Oct 25, 2024
128826b
remove print, change nargs
linoytsaban Oct 25, 2024
01995ed
Merge branch 'main' into sd-3-5-explorations
linoytsaban Oct 25, 2024
6f8e392
Merge remote-tracking branch 'origin/sd-3-5-explorations' into sd-3-5…
linoytsaban Oct 25, 2024
0c7fa8b
test arg change
linoytsaban Oct 28, 2024
d96db50
Merge branch 'main' into sd-3-5-explorations
linoytsaban Oct 28, 2024
10a2659
style
linoytsaban Oct 28, 2024
7b087d2
Merge branch 'main' into sd-3-5-explorations
linoytsaban Oct 28, 2024
ad6c2f3
revert nargs 2/2
linoytsaban Oct 28, 2024
aebfa03
Merge remote-tracking branch 'origin/sd-3-5-explorations' into sd-3-5…
linoytsaban Oct 28, 2024
3623a6b
Merge branch 'main' into sd-3-5-explorations
linoytsaban Oct 28, 2024
df018dd
address sayaks comments
linoytsaban Oct 28, 2024
65dd59d
style
linoytsaban Oct 28, 2024
559b0bc
address sayaks comments
linoytsaban Oct 28, 2024
34 changes: 34 additions & 0 deletions examples/dreambooth/README_sd3.md
@@ -147,6 +147,40 @@ accelerate launch train_dreambooth_lora_sd3.py \
--push_to_hub
```

### Targeting Specific Blocks & Layers
As image generation models get bigger & more powerful, more fine-tuners are finding that training only part of the
transformer blocks (sometimes as few as two) can be enough to get great results.
In some cases, it can be even better to keep some of the blocks/layers frozen.

For **SD3.5-Large** specifically, you may find this information useful (taken from the [Stable Diffusion 3.5 Large Fine-tuning Tutorial](https://stabilityai.notion.site/Stable-Diffusion-3-5-Large-Fine-tuning-Tutorial-11a61cdcd1968027a15bdbd7c40be8c6)):
> [!NOTE]
> A commonly believed heuristic that we verified once again during the construction of the SD3.5 family of models is that later/higher layers (i.e. `30 - 37`)* impact tertiary details more heavily. Conversely, earlier layers (i.e. `12 - 24` )* influence the overall composition/primary form more.
> So, freezing other layers/targeting specific layers is a viable approach.
> `*`These suggested layers are speculative and not 100% guaranteed. The tips here are more or less a general idea for next steps.
> **Photorealism**
> In preliminary testing, we observed that freezing the last few layers of the architecture significantly improved model training when using a photorealistic dataset, preventing the detail degradation that a small dataset can introduce.
> **Anatomy preservation**
> To dampen any possible degradation of anatomy, training only the attention layers and **not** the adaptive linear layers could help. For reference, below is one of the transformer blocks.
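
To see which sub-modules one of these blocks actually contains, here is a minimal inspection sketch (it assumes the tiny `hf-internal-testing/tiny-sd3-pipe` checkpoint that the tests in this PR use; swap in your real SD3/SD3.5 checkpoint). The `attn.*` entries are the attention projections, while the linear layers inside the adaptive norm sub-modules are the "adaptive linear layers" referred to above:

```python
from diffusers import SD3Transformer2DModel

# Tiny test checkpoint (assumed here for speed); replace with your SD3/SD3.5 checkpoint.
transformer = SD3Transformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-sd3-pipe", subfolder="transformer"
)

# List the leaf sub-modules of the first transformer block, e.g. the attention
# projections (attn.to_q, attn.to_k, ...) and the adaptive norm linear layers.
block = transformer.transformer_blocks[0]
for name, module in block.named_modules():
    if name and not list(module.children()):
        print(name, type(module).__name__)
```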
Member

Makes total sense to me!



We've added `--lora_layers` and `--lora_blocks` to make the LoRA training modules configurable.
- With `--lora_blocks` you can specify which transformer blocks to train. For example, passing
```diff
--lora_blocks "12,13,14,15,16,17,18,19,20,21,22,23,24,30,31,32,33,34,35,36,37"
```
will trigger LoRA training of transformer blocks 12-24 and 30-37. By default, all blocks are trained.
- With `--lora_layers` you can specify which types of layers you wish to train.
By default, the trained layers are
`attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,attn.to_k,attn.to_out.0,attn.to_q,attn.to_v`
If you wish to have a leaner LoRA, or would rather spend capacity on training more blocks than more layer types per block, you could pass
Member

Leaner LoRA targeting what aspect? From what I understand, this heuristic is for targeting a specific quality, right?

Collaborator Author

Indeed, the feature was added to allow experimentation with which layers produce the best quality. But since the default (once the PR is merged) will be
`attn.add_k_proj attn.add_q_proj attn.add_v_proj attn.to_add_out attn.to_k attn.to_out.0 attn.to_q attn.to_v`,
which makes every trained block chunkier than the previous default, I wanted to also give as an example the previous setting we had, which is
`--lora_layers attn.to_k attn.to_q attn.to_v attn.to_out.0`,
which will result in a smaller LoRA.
But if it's confusing/unclear I can remove that.

```diff
+ --lora_layers attn.to_k,attn.to_q,attn.to_v,attn.to_out.0
```
This will reduce LoRA size by roughly 50% for the same rank compared to the default.
However, if you're after compact LoRAs, it's our impression that maintaining the default setting for `--lora_layers` and
freezing some of the early & later blocks is usually better.
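
After training, you can verify which modules your LoRA ended up targeting by inspecting the saved weights. A minimal sketch, assuming the file sits in the `--output_dir` you trained into:

```python
import safetensors.torch

# Adjust the path to the --output_dir used during training.
lora_state_dict = safetensors.torch.load_file("output_dir/pytorch_lora_weights.safetensors")

# Keys typically look like "transformer.transformer_blocks.12.attn.to_k.lora_A.weight";
# dropping the LoRA suffix recovers the targeted module names.
targeted_modules = sorted({key.split(".lora")[0] for key in lora_state_dict})
for name in targeted_modules:
    print(name)
```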


### Text Encoder Training
Alongside the transformer, LoRA fine-tuning of the CLIP text encoders is now also supported.
To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:
71 changes: 71 additions & 0 deletions examples/dreambooth/test_dreambooth_lora_sd3.py
@@ -38,6 +38,9 @@ class DreamBoothLoRASD3(ExamplesTestsAccelerate):
pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
script_path = "examples/dreambooth/train_dreambooth_lora_sd3.py"

transformer_block_idx = 0
layer_type = "attn.to_k"

def test_dreambooth_lora_sd3(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
@@ -136,6 +139,74 @@ def test_dreambooth_lora_latent_caching(self):
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)

def test_dreambooth_lora_block(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--lora_blocks {self.transformer_block_idx}
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)

# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
# In this test, only params of transformer block 0 should be in the state dict
starts_with_transformer = all(
key.startswith("transformer.transformer_blocks.0") for key in lora_state_dict.keys()
)
Comment on lines +170 to +175
Member

Super!

self.assertTrue(starts_with_transformer)

def test_dreambooth_lora_layer(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--lora_layers {self.layer_type}
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)

# In this test, only transformer params of attention layers `attn.to_k` should be in the state dict
starts_with_transformer = all("attn.to_k" in key for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)

def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
39 changes: 38 additions & 1 deletion examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -571,6 +571,25 @@ def parse_args(input_args=None):
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
)

parser.add_argument(
"--lora_layers",
type=str,
default=None,
help=(
"The transformer block layers to apply LoRA training on. Please specify the layers in a comma seperated string."
"For examples refer to https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_SD3.md"
),
)
parser.add_argument(
"--lora_blocks",
type=str,
default=None,
help=(
"The transformer blocks to apply LoRA training on. Please specify the block numbers in a comma seperated manner."
'E.g. - "--lora_blocks 12,30" will result in lora training of transformer blocks 12 and 30. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_SD3.md'
),
)

parser.add_argument(
"--adam_epsilon",
type=float,
@@ -1222,13 +1241,31 @@ def main(args):
if args.train_text_encoder:
text_encoder_one.gradient_checkpointing_enable()
text_encoder_two.gradient_checkpointing_enable()
if args.lora_layers is not None:
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
else:
target_modules = [
"attn.add_k_proj",
"attn.add_q_proj",
"attn.add_v_proj",
"attn.to_add_out",
"attn.to_k",
"attn.to_out.0",
"attn.to_q",
"attn.to_v",
]
if args.lora_blocks is not None:
target_blocks = [int(block.strip()) for block in args.lora_blocks.split(",")]
target_modules = [
f"transformer_blocks.{block}.{module}" for block in target_blocks for module in target_modules
]

# now we will add new LoRA weights to the attention layers
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
target_modules=target_modules,
)
transformer.add_adapter(transformer_lora_config)
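
For illustration, the snippet below mirrors the expansion above with hypothetical flag values (not part of this PR), just to show the names that end up in `target_modules`; PEFT then matches each entry against the end of the fully qualified module name, which is why prefixing the block index restricts training to that block:

```python
# Hypothetical flag values, e.g. --lora_blocks "0,1" --lora_layers "attn.to_k,attn.to_q"
lora_blocks = "0,1"
lora_layers = "attn.to_k,attn.to_q"

target_modules = [layer.strip() for layer in lora_layers.split(",")]
target_blocks = [int(block.strip()) for block in lora_blocks.split(",")]
target_modules = [
    f"transformer_blocks.{block}.{module}" for block in target_blocks for module in target_modules
]
print(target_modules)
# ['transformer_blocks.0.attn.to_k', 'transformer_blocks.0.attn.to_q',
#  'transformer_blocks.1.attn.to_k', 'transformer_blocks.1.attn.to_q']
```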
