[SD 3.5 Dreambooth LoRA] support configurable training block & layers #9762
@@ -147,6 +147,40 @@ accelerate launch train_dreambooth_lora_sd3.py \
  --push_to_hub
### Targeting Specific Blocks & Layers
As image generation models get bigger & more powerful, more fine-tuners find that training only part of the
transformer blocks (sometimes as few as two) can be enough to get great results.
In some cases, it can be even better to keep some of the blocks/layers frozen.

For **SD3.5-Large** specifically, you may find this information useful (taken from the [Stable Diffusion 3.5 Large Fine-tuning Tutorial](https://stabilityai.notion.site/Stable-Diffusion-3-5-Large-Fine-tuning-Tutorial-11a61cdcd1968027a15bdbd7c40be8c6#12461cdcd19680788a23c650dab26b93)):
> [!NOTE]
> A commonly believed heuristic that we verified once again during the construction of the SD3.5 family of models is that later/higher layers (i.e. `30 - 37`)* impact tertiary details more heavily. Conversely, earlier layers (i.e. `12 - 24`)* influence the overall composition/primary form more.
> So, freezing other layers/targeting specific layers is a viable approach.
> `*`These suggested layers are speculative and not 100% guaranteed. The tips here are more or less a general idea for next steps.
>
> **Photorealism**
> In preliminary testing, we observed that freezing the last few layers of the architecture significantly improved model training when using a photorealistic dataset, preventing the detail degradation introduced by small datasets.
>
> **Anatomy preservation**
> To dampen any possible degradation of anatomy, training only the attention layers and **not** the adaptive linear layers could help. For reference, below is one of the transformer blocks.
**Review comment:** Makes total sense to me!
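To see concretely which layer names inside a block are attention projections and which are the other ("adaptive") linear layers, a minimal sketch like the one below can help. It assumes you have access to an SD3/SD3.5 checkpoint on the Hub (the gated `stabilityai/stable-diffusion-3.5-large` is used purely as an example); the `transformer_blocks` / `attn.*` module names it relies on are the same ones that appear in the LoRA state dicts checked by the tests further down.

```py
# Hedged sketch: list the linear layers of a single SD3.5 transformer block so you know
# which names can be passed to `--lora_layers`. Everything under `attn.` is an attention
# projection; the remaining nn.Linear modules (adaptive norm / feed-forward projections)
# are the ones the note above suggests you may want to leave frozen.
import torch
from diffusers import SD3Transformer2DModel

transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",  # example checkpoint; swap in one you have access to
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

block = transformer.transformer_blocks[0]
attn_layers, other_layers = [], []
for name, module in block.named_modules():
    if isinstance(module, torch.nn.Linear):
        (attn_layers if name.startswith("attn") else other_layers).append(name)

print("attention projections:", attn_layers)
print("other linear layers:", other_layers)
```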
We've added `--lora_layers` and `--lora_blocks` to make the LoRA training modules configurable.
- with `--lora_blocks` you can specify the block numbers for training. E.g. passing
```diff
--lora_blocks "12,13,14,15,16,17,18,19,20,21,22,23,24,30,31,32,33,34,35,36,37"
```
will trigger LoRA training of transformer blocks 12-24 and 30-37. By default, all blocks are trained.
- with `--lora_layers` you can specify the types of layers you wish to train.
By default, the trained layers are
`attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,attn.to_k,attn.to_out.0,attn.to_q,attn.to_v`.
If you wish to have a leaner LoRA / train more blocks over layers, you could pass
**Review comment:** Leaner LoRA targeting what aspect? From what I understand, this heuristic is for targeting a specific quality, right?

**Reply:** Indeed, the feature was added to allow experimentation with which layers produce the best quality, but since the default (once the PR is merged) will be …
```diff
--lora_layers attn.to_k,attn.to_q,attn.to_v,attn.to_out.0
```
This will reduce the LoRA size by roughly 50% for the same rank compared to the default.
However, if you're after compact LoRAs, it's our impression that maintaining the default setting for `--lora_layers` and
freezing some of the early & later blocks is usually better.
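Once a run finishes, it can be handy to confirm which blocks/layers actually received LoRA parameters. Below is a small, hedged sketch of such a check; the `trained-sd3-lora` directory is a placeholder for your `--output_dir`, while the `pytorch_lora_weights.safetensors` file name and the `transformer.transformer_blocks.<idx>....` key layout are the same ones the new tests below assert on.

```py
# Inspect a saved SD3 DreamBooth LoRA and report which transformer blocks were trained.
import os
from collections import Counter

import safetensors.torch

output_dir = "trained-sd3-lora"  # placeholder: whatever you passed as --output_dir
state_dict = safetensors.torch.load_file(os.path.join(output_dir, "pytorch_lora_weights.safetensors"))

# Keys look like: transformer.transformer_blocks.<block_idx>.<layer name>.lora_A.weight
block_counts = Counter(
    key.split("transformer_blocks.")[1].split(".")[0]
    for key in state_dict
    if "transformer_blocks." in key
)
print("trained blocks:", sorted(block_counts, key=int))
print("LoRA tensors per block:", dict(block_counts))
```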
### Text Encoder Training
Alongside the transformer, LoRA fine-tuning of the CLIP text encoders is now also supported.
To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:
@@ -136,6 +136,76 @@ def test_dreambooth_lora_latent_caching(self):
```py
        starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
        self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_block(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --lora_blocks 0
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            # In this test, only params of transformer block 0 should be in the state dict
            starts_with_transformer = all(
                key.startswith("transformer.transformer_blocks.0") for key in lora_state_dict.keys()
            )
            self.assertTrue(starts_with_transformer)
```

**Review comment (on lines +170 to +175):** Super!
```py
    def test_dreambooth_lora_layer(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --lora_layers attn.to_k
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            # In this test, only the `attn.to_k` layers should be in the state dict
            starts_with_transformer = all("attn.to_k" in key for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
```