diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index e0404a83cc9a..0bc7886c2653 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -78,6 +78,7 @@ def __init__(
         self.upcast_softmax = upcast_softmax
         self.rescale_output_factor = rescale_output_factor
         self.residual_connection = residual_connection
+        self.dropout = dropout
 
         # we make use of this private variable to know whether this class is loaded
         # with a deprecated state dict so that we can convert it on the fly
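Storing `self.dropout` here is what makes the undo path in modeling_utils.py (below) possible: `_undo_temp_convert_self_to_deprecated_attention_blocks` rebuilds `to_out` as `nn.ModuleList([proj_attn, nn.Dropout(module.dropout)])`, and before this change the probability only lived inside the already-constructed `nn.Dropout`. A minimal sketch of the invariant, assuming the `Attention` constructor accepts `query_dim` and a `dropout` keyword as in this file:

    from diffusers.models.attention_processor import Attention

    attn = Attention(query_dim=32, dropout=0.1)
    # `to_out` is `ModuleList([Linear, Dropout])`; with this change the probability
    # is also kept on the module itself, so the temp-convert round trip can rebuild it.
    assert attn.to_out[1].p == attn.dropout == 0.1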
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index f6d6bc5711cd..9e9b5cde4675 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -22,7 +22,7 @@
 from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
-from torch import Tensor, device
+from torch import Tensor, device, nn
 
 from .. import __version__
 from ..utils import (
@@ -646,15 +646,47 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             else:  # else let accelerate handle loading and dispatching.
                 # Load weights and dispatch according to the device_map
                 # by default the device_map is None and the weights are loaded on the CPU
-                accelerate.load_checkpoint_and_dispatch(
-                    model,
-                    model_file,
-                    device_map,
-                    max_memory=max_memory,
-                    offload_folder=offload_folder,
-                    offload_state_dict=offload_state_dict,
-                    dtype=torch_dtype,
-                )
+                try:
+                    accelerate.load_checkpoint_and_dispatch(
+                        model,
+                        model_file,
+                        device_map,
+                        max_memory=max_memory,
+                        offload_folder=offload_folder,
+                        offload_state_dict=offload_state_dict,
+                        dtype=torch_dtype,
+                    )
+                except AttributeError as e:
+                    # When loading with accelerate, we cannot intercept the state dict and rename
+                    # the weight names manually. Additionally, accelerate skips torch loading
+                    # conventions and writes directly into `module.{_buffers, _parameters}`
+                    # (even though these look like private attributes), so we can't use the standard
+                    # load hooks to rename parameters. Instead we mimic the deprecated weight names
+                    # so the expected attributes exist, load the weights, and then convert the
+                    # deprecated names back to the new non-deprecated names. Then we _greatly
+                    # encourage_ the user to convert the weights so we don't have to do this again.
+
+                    if "'Attention' object has no attribute" in str(e):
+                        logger.warn(
+                            f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
+                            " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
+                            " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
+                            " so we don't have to do the on-the-fly renaming in the future. If the model is from a hub checkpoint,"
+                            " please also re-upload it or open a PR on the original repository."
+                        )
+                        model._temp_convert_self_to_deprecated_attention_blocks()
+                        accelerate.load_checkpoint_and_dispatch(
+                            model,
+                            model_file,
+                            device_map,
+                            max_memory=max_memory,
+                            offload_folder=offload_folder,
+                            offload_state_dict=offload_state_dict,
+                            dtype=torch_dtype,
+                        )
+                        model._undo_temp_convert_self_to_deprecated_attention_blocks()
+                    else:
+                        raise e
 
         loading_info = {
             "missing_keys": [],
@@ -889,3 +921,53 @@ def recursive_find_attn_block(name, module):
                 state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
             if f"{path}.proj_attn.bias" in state_dict:
                 state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
+
+    def _temp_convert_self_to_deprecated_attention_blocks(self):
+        deprecated_attention_block_modules = []
+
+        def recursive_find_attn_block(module):
+            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
+                deprecated_attention_block_modules.append(module)
+
+            for sub_module in module.children():
+                recursive_find_attn_block(sub_module)
+
+        recursive_find_attn_block(self)
+
+        for module in deprecated_attention_block_modules:
+            module.query = module.to_q
+            module.key = module.to_k
+            module.value = module.to_v
+            module.proj_attn = module.to_out[0]
+
+            # We don't _have_ to delete the `to_q`/`to_k`/`to_v`/`to_out` attributes, but
+            # deleting them ensures that _all_ the weights load through the deprecated
+            # aliases and that we haven't incorrectly assumed this model should be
+            # converted when it really shouldn't be.
+            del module.to_q
+            del module.to_k
+            del module.to_v
+            del module.to_out
+
+    def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
+        deprecated_attention_block_modules = []
+
+        def recursive_find_attn_block(module):
+            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
+                deprecated_attention_block_modules.append(module)
+
+            for sub_module in module.children():
+                recursive_find_attn_block(sub_module)
+
+        recursive_find_attn_block(self)
+
+        for module in deprecated_attention_block_modules:
+            module.to_q = module.query
+            module.to_k = module.key
+            module.to_v = module.value
+            module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
+
+            del module.query
+            del module.key
+            del module.value
+            del module.proj_attn
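Both helpers above lean on `nn.Module` attribute semantics: assigning a submodule to a new attribute name registers its parameters under that name, so a checkpoint keyed with the deprecated `query`/`key`/`value`/`proj_attn` names loads cleanly even though accelerate writes straight into `module._parameters`. A toy sketch of the aliasing pattern (illustrative only, not diffusers code):

    import torch
    from torch import nn

    class ToyBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.to_q = nn.Linear(4, 4)

        def temp_convert(self):
            self.query = self.to_q  # deprecated key "query.*" becomes loadable
            del self.to_q           # a load targeting "to_q.*" now fails loudly

        def undo(self):
            self.to_q = self.query
            del self.query

    block = ToyBlock()
    block.temp_convert()
    block.load_state_dict({"query.weight": torch.zeros(4, 4), "query.bias": torch.zeros(4)})
    block.undo()
    assert "to_q.weight" in block.state_dict()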
diff --git a/tests/models/test_attention_processor.py b/tests/models/test_attention_processor.py
index 172d6d4d91fc..f9b5924ca5e0 100644
--- a/tests/models/test_attention_processor.py
+++ b/tests/models/test_attention_processor.py
@@ -1,7 +1,10 @@
+import tempfile
 import unittest
 
+import numpy as np
 import torch
 
+from diffusers import DiffusionPipeline
 from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor
 
 
@@ -73,3 +76,44 @@ def test_only_cross_attention(self):
         only_cross_attn_out = attn(**forward_args)
 
         self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())
+
+
+class DeprecatedAttentionBlockTests(unittest.TestCase):
+    def test_conversion_when_using_device_map(self):
+        pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None)
+
+        pre_conversion = pipe(
+            "foo",
+            num_inference_steps=2,
+            generator=torch.Generator("cpu").manual_seed(0),
+            output_type="np",
+        ).images
+
+        # loading with a device_map goes through accelerate and triggers the on-the-fly conversion
+        pipe = DiffusionPipeline.from_pretrained(
+            "hf-internal-testing/tiny-stable-diffusion-pipe", device_map="sequential", safety_checker=None
+        )
+
+        conversion = pipe(
+            "foo",
+            num_inference_steps=2,
+            generator=torch.Generator("cpu").manual_seed(0),
+            output_type="np",
+        ).images
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # save the converted model
+            pipe.save_pretrained(tmpdir)
+
+            # can also load the converted weights
+            pipe = DiffusionPipeline.from_pretrained(tmpdir, device_map="sequential", safety_checker=None)
+
+            after_conversion = pipe(
+                "foo",
+                num_inference_steps=2,
+                generator=torch.Generator("cpu").manual_seed(0),
+                output_type="np",
+            ).images
+
+        self.assertTrue(np.allclose(pre_conversion, conversion))
+        self.assertTrue(np.allclose(conversion, after_conversion))
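The test covers the full round trip; outside the test suite, the re-save flow that the warning in modeling_utils.py asks users to perform is just a load followed by a save. A sketch, with illustrative paths:

    from diffusers import DiffusionPipeline

    # a device_map routes loading through accelerate, triggering the on-the-fly rename
    pipe = DiffusionPipeline.from_pretrained("path/to/old-checkpoint", device_map="sequential")

    # saving persists the new to_q/to_k/to_v/to_out names, so future loads skip the retry
    pipe.save_pretrained("path/to/converted-checkpoint")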