[LoRA] pop the LoRA scale so that it doesn't get propagated to the weeds #7338
Merged
Changes from all commits (7 commits):
- 460773a — pop scale from the top-level unet instead of getting it. (sayakpaul)
- 663be37 — improve readability. (sayakpaul)
- efc03ab — Merge branch 'main' into pop-scale (sayakpaul)
- 52a60e0 — Merge branch 'main' into pop-scale (sayakpaul)
- 86f56c1 — Merge branch 'main' into pop-scale (sayakpaul)
- 1259947 — Apply suggestions from code review (sayakpaul)
- d4daf17 — fix a little bit. (sayakpaul)
```diff
@@ -1081,25 +1081,15 @@ def forward(
                 A tuple of tensors that if specified are added to the residuals of down unet blocks.
             mid_block_additional_residual: (`torch.Tensor`, *optional*):
                 A tensor that if specified is added to the residual of the middle unet block.
             down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
                 additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
             encoder_attention_mask (`torch.Tensor`):
                 A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
                 `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
                 which adds large negative values to the attention scores corresponding to "discard" tokens.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                 tuple.
-            cross_attention_kwargs (`dict`, *optional*):
-                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
-            added_cond_kwargs: (`dict`, *optional*):
-                A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
-                are passed along to the UNet blocks.
-            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
-                additional residuals to be added to UNet long skip connections from down blocks to up blocks for
-                example from ControlNet side model(s)
-            mid_block_additional_residual (`torch.Tensor`, *optional*):
-                additional residual to be added to UNet mid block output, for example from ControlNet side model
-            down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
-                additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)

         Returns:
             [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
@@ -1185,7 +1175,13 @@ def forward(
             cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

         # 3. down
-        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+        # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
+        # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
+        if cross_attention_kwargs is not None:
+            lora_scale = cross_attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
         if USE_PEFT_BACKEND:
             # weight the lora layers by setting `lora_scale` for each PEFT layer
             scale_lora_layers(self, lora_scale)
```

Inline review exchange on the `lora_scale = cross_attention_kwargs.pop("scale", 1.0)` line: "If …" — "Yeah that is correct."
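The rationale stated in the two new comment lines is easy to demonstrate outside diffusers. Below is a minimal, self-contained sketch of the idea rather than the library's code: `toy_forward` and `toy_block` are made-up names standing in for the top-level UNet `forward` and one of its inner blocks. With `dict.get`, the `scale` key would still be inside the kwargs handed to the inner block and would trigger its deprecation warning; with `dict.pop`, the parent consumes the key and the warning never fires.

```python
import warnings


def toy_block(hidden, cross_attention_kwargs=None):
    """Stand-in for an inner attention block: it warns if it still receives `scale`."""
    kwargs = cross_attention_kwargs or {}
    if "scale" in kwargs:
        warnings.warn("passing `scale` to inner blocks is deprecated", FutureWarning)
    return hidden + 1


def toy_forward(hidden, cross_attention_kwargs=None):
    """Stand-in for the top-level forward(): it pops `scale` before forwarding the kwargs."""
    if cross_attention_kwargs is not None:
        # pop() removes the key, so the dict forwarded below no longer contains `scale`
        lora_scale = cross_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0
    # in this toy version the parent is the only consumer of the scale
    hidden = hidden * lora_scale
    return toy_block(hidden, cross_attention_kwargs)


with warnings.catch_warnings():
    warnings.simplefilter("error")           # turn any warning into an exception
    print(toy_forward(1.0, {"scale": 0.5}))  # runs cleanly: `scale` was popped
```

One side effect of `dict.pop` worth keeping in mind: it mutates the caller's dictionary, so a caller that reuses the same `cross_attention_kwargs` object across calls would see `scale` disappear after the first call; copying the dict before popping is the usual way to avoid that.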
Review comment on the removed docstring entries: "Duplicates."
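For context on where the popped `scale` comes from in the first place, this is the documented user-level pattern that exercises this code path; the checkpoint and LoRA repo ids below are placeholders, not ones referenced by this PR.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("some-user/some-lora")  # placeholder LoRA repo id

# `scale` rides along inside cross_attention_kwargs; with this PR the UNet pops it
# at the top level instead of reading it and forwarding it to every inner block.
image = pipe(
    "a photo of an astronaut riding a horse",
    cross_attention_kwargs={"scale": 0.7},
).images[0]
```

With the PEFT backend enabled, the popped value then feeds `scale_lora_layers(self, lora_scale)` as shown in the diff, so the LoRA weighting itself is unchanged; only the kwargs seen by the inner blocks differ.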