device map legacy attention block weight conversion #3804
The first file in the diff updates the model loading utilities:
@@ -22,7 +22,7 @@
 from typing import Any, Callable, List, Optional, Tuple, Union

 import torch
-from torch import Tensor, device
+from torch import Tensor, device, nn

 from .. import __version__
 from ..utils import (
@@ -646,15 +646,47 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             else:  # else let accelerate handle loading and dispatching.
                 # Load weights and dispatch according to the device_map
                 # by default the device_map is None and the weights are loaded on the CPU
-                accelerate.load_checkpoint_and_dispatch(
-                    model,
-                    model_file,
-                    device_map,
-                    max_memory=max_memory,
-                    offload_folder=offload_folder,
-                    offload_state_dict=offload_state_dict,
-                    dtype=torch_dtype,
-                )
+                try:
+                    accelerate.load_checkpoint_and_dispatch(
+                        model,
+                        model_file,
+                        device_map,
+                        max_memory=max_memory,
+                        offload_folder=offload_folder,
+                        offload_state_dict=offload_state_dict,
+                        dtype=torch_dtype,
+                    )
+                except AttributeError as e:
+                    # When using accelerate loading, we do not have the ability to load the state
+                    # dict and rename the weight names manually. Additionally, accelerate skips
+                    # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
+                    # (which look like they should be private variables?), so we can't use the standard hooks
+                    # to rename parameters on load. We need to mimic the original weight names so the correct
+                    # attributes are available. After we have loaded the weights, we convert the deprecated
+                    # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
Member
Do we have guidance available for the users on how they should perform the conversion?

Member
Ah nevermind. I guess you meant once we load the old attention block weight names and run conversion internally, and suggest users save the pipeline. Right?

Contributor (Author)
Yeah, just this blurb here.
+                    # the weights so we don't have to do this again.
+
+                    if "'Attention' object has no attribute" in str(e):
Contributor
pretty hacky, but OK! Let's leave it for now :-)

Contributor (Author)
yeah I cringed while writing this 🙃
+                        logger.warn(
+                            f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
+                            " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
+                            " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
+                            " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
+                            " please also re-upload it or open a PR on the original repository."
+                        )
+                        model._temp_convert_self_to_deprecated_attention_blocks()
+                        accelerate.load_checkpoint_and_dispatch(
+                            model,
+                            model_file,
+                            device_map,
+                            max_memory=max_memory,
+                            offload_folder=offload_folder,
+                            offload_state_dict=offload_state_dict,
+                            dtype=torch_dtype,
+                        )
+                        model._undo_temp_convert_self_to_deprecated_attention_blocks()
+                    else:
+                        raise e

             loading_info = {
                 "missing_keys": [],
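As the review thread above notes, the intended user-facing fix is simply to re-save the pipeline after it has been loaded (and converted on the fly). A rough sketch of that workflow, with a hypothetical repository id and output directory:

```python
from diffusers import DiffusionPipeline

# Loading with a device_map goes through accelerate, so a checkpoint that still
# uses the old attention block weight names triggers the on-the-fly rename above.
# The repo id below is a placeholder.
pipe = DiffusionPipeline.from_pretrained(
    "your-org/checkpoint-with-old-attention-weights",
    device_map="sequential",
)

# Saving writes the weights back out under the new attention block names, so
# future loads no longer hit the AttributeError fallback (or its warning).
pipe.save_pretrained("./converted-checkpoint")
```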
@@ -889,3 +921,53 @@ def recursive_find_attn_block(name, module):
                 state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
             if f"{path}.proj_attn.bias" in state_dict:
                 state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
+
+    def _temp_convert_self_to_deprecated_attention_blocks(self):
+        deprecated_attention_block_modules = []
+
+        def recursive_find_attn_block(module):
+            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
+                deprecated_attention_block_modules.append(module)
+
+            for sub_module in module.children():
+                recursive_find_attn_block(sub_module)
+
+        recursive_find_attn_block(self)
+
+        for module in deprecated_attention_block_modules:
+            module.query = module.to_q
+            module.key = module.to_k
+            module.value = module.to_v
+            module.proj_attn = module.to_out[0]
+
+            # We don't _have_ to delete the old attributes, but it's helpful to ensure
+            # that _all_ the weights are loaded into the new attributes and we're not
+            # making an incorrect assumption that this model should be converted when
+            # it really shouldn't be.
+            del module.to_q
+            del module.to_k
+            del module.to_v
+            del module.to_out
+
+    def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
+        deprecated_attention_block_modules = []
+
+        def recursive_find_attn_block(module):
+            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
+                deprecated_attention_block_modules.append(module)
+
+            for sub_module in module.children():
+                recursive_find_attn_block(sub_module)
+
+        recursive_find_attn_block(self)
+
+        for module in deprecated_attention_block_modules:
+            module.to_q = module.query
+            module.to_k = module.key
+            module.to_v = module.value
+            module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
+
+            del module.query
+            del module.key
+            del module.value
+            del module.proj_attn
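To see why the temporary rename is necessary at all, here is a small, self-contained sketch. It is not the diffusers code path, and it uses plain `load_state_dict` instead of accelerate's direct writes into `module._parameters`, but the failure mode is the same: weights are matched to modules purely by attribute path, so a checkpoint that still stores `query`/`key`/`value`/`proj_attn` keys only loads if attributes with those names exist on the block. `TinyAttnBlock` and its sizes are made up for illustration.

```python
import torch
from torch import nn


class TinyAttnBlock(nn.Module):
    # Marker mirroring the `_from_deprecated_attn_block` flag checked above.
    _from_deprecated_attn_block = True

    def __init__(self, dim: int = 4):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        self.to_out = nn.ModuleList([nn.Linear(dim, dim), nn.Dropout(0.0)])


block = TinyAttnBlock()

# A checkpoint that was saved with the deprecated attribute names.
old_state_dict = {
    f"{name}.{kind}": torch.randn(4, 4) if kind == "weight" else torch.randn(4)
    for name in ("query", "key", "value", "proj_attn")
    for kind in ("weight", "bias")
}

# Without the rename, none of the old keys resolve to an attribute on the block.
print(block.load_state_dict(old_state_dict, strict=False).unexpected_keys)

# Temporarily expose the deprecated names (the same idea as
# _temp_convert_self_to_deprecated_attention_blocks), load, then undo.
block.query, block.key, block.value, block.proj_attn = (
    block.to_q,
    block.to_k,
    block.to_v,
    block.to_out[0],
)
result = block.load_state_dict(old_state_dict, strict=False)
assert not result.unexpected_keys  # every deprecated key now finds a parameter
del block.query, block.key, block.value, block.proj_attn
```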
The second file in the diff adds a test for the conversion:
@@ -1,7 +1,10 @@
+import tempfile
 import unittest

+import numpy as np
 import torch

+from diffusers import DiffusionPipeline
 from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor
@@ -73,3 +76,44 @@ def test_only_cross_attention(self):
         only_cross_attn_out = attn(**forward_args)

         self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())
+
+
+class DeprecatedAttentionBlockTests(unittest.TestCase):
+    def test_conversion_when_using_device_map(self):
+        pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None)
+
+        pre_conversion = pipe(
+            "foo",
Member
Killer prompt.
+            num_inference_steps=2,
+            generator=torch.Generator("cpu").manual_seed(0),
+            output_type="np",
+        ).images
+
+        # the initial conversion succeeds
+        pipe = DiffusionPipeline.from_pretrained(
+            "hf-internal-testing/tiny-stable-diffusion-pipe", device_map="sequential", safety_checker=None
+        )
+
+        conversion = pipe(
+            "foo",
+            num_inference_steps=2,
+            generator=torch.Generator("cpu").manual_seed(0),
+            output_type="np",
+        ).images
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # save the converted model
+            pipe.save_pretrained(tmpdir)
+
+            # can also load the converted weights
+            pipe = DiffusionPipeline.from_pretrained(tmpdir, device_map="sequential", safety_checker=None)
+
+            after_conversion = pipe(
+                "foo",
+                num_inference_steps=2,
+                generator=torch.Generator("cpu").manual_seed(0),
+                output_type="np",
+            ).images
+
+        self.assertTrue(np.allclose(pre_conversion, conversion))
+        self.assertTrue(np.allclose(conversion, after_conversion))
Added so we can re-create the dropout module when converting back to the new weight format.
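That comment refers to the `module.dropout` value used by `_undo_temp_convert_self_to_deprecated_attention_blocks` above to rebuild `to_out[1]` as `nn.Dropout(module.dropout)`, which only works if the attention block keeps the dropout probability around as a plain attribute. A minimal sketch of that dependency (this class and its signature are illustrative, not the actual `Attention` definition):

```python
from torch import nn


class AttentionSketch(nn.Module):
    def __init__(self, query_dim: int, dropout: float = 0.0):
        super().__init__()
        # Stored as a plain float so nn.Dropout(self.dropout) can be re-created
        # after the temporary conversion deletes the original to_out ModuleList.
        self.dropout = dropout
        self.to_out = nn.ModuleList([nn.Linear(query_dim, query_dim), nn.Dropout(dropout)])
```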