Merged
src/diffusers/models/attention_processor.py (1 addition, 0 deletions)

@@ -78,6 +78,7 @@ def __init__(
        self.upcast_softmax = upcast_softmax
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection
        self.dropout = dropout
Contributor Author
Added so we can re-create the dropout module when converting back to the new weight format.


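This stored probability is what lets the PR rebuild the dropout layer when converting back to the new weight format (see `_undo_temp_convert_self_to_deprecated_attention_blocks` further down). A minimal sketch, with `query_dim=32` as an assumed toy value:

from torch import nn

from diffusers.models.attention_processor import Attention

attn = Attention(query_dim=32, dropout=0.1)  # query_dim is a toy value
# Because __init__ now records the raw probability, an equivalent module can be
# rebuilt later, e.g. nn.ModuleList([proj_attn, nn.Dropout(attn.dropout)]).
rebuilt = nn.Dropout(attn.dropout)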
        # we make use of this private variable to know whether this class is loaded
        # with a deprecated state dict so that we can convert it on the fly
src/diffusers/models/modeling_utils.py (92 additions, 10 deletions)

@@ -22,7 +22,7 @@
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
-from torch import Tensor, device
from torch import Tensor, device, nn

from .. import __version__
from ..utils import (

@@ -646,15 +646,47 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
            else:  # else let accelerate handle loading and dispatching.
                # Load weights and dispatch according to the device_map
                # by default the device_map is None and the weights are loaded on the CPU
-               accelerate.load_checkpoint_and_dispatch(
-                   model,
-                   model_file,
-                   device_map,
-                   max_memory=max_memory,
-                   offload_folder=offload_folder,
-                   offload_state_dict=offload_state_dict,
-                   dtype=torch_dtype,
-               )
                try:
                    accelerate.load_checkpoint_and_dispatch(
                        model,
                        model_file,
                        device_map,
                        max_memory=max_memory,
                        offload_folder=offload_folder,
                        offload_state_dict=offload_state_dict,
                        dtype=torch_dtype,
                    )
                except AttributeError as e:
                    # When using accelerate loading, we do not have the ability to load the state
                    # dict and rename the weight names manually. Additionally, accelerate skips
                    # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
                    # (which look like they should be private variables?), so we can't use the standard hooks
                    # to rename parameters on load. We need to mimic the original weight names so the correct
                    # attributes are available. After we have loaded the weights, we convert the deprecated
                    # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
                    # the weights so we don't have to do this again.

Member
Do we have guidance available for users on how they should perform the conversion?

Member
Ah, never mind. I guess you meant: once we load the old attention block weight names and run the conversion internally, we suggest that users save the pipeline. Right?

Contributor Author
Yeah, just this blurb here:

    f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
    " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
    " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
    " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
    " please also re-upload it or open a PR on the original repository."

if "'Attention' object has no attribute" in str(e):
Contributor
pretty hacky, but OK! Let's leave it for now :-)

Contributor Author
yeah I cringed while writing this 🙃

                        logger.warn(
                            f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
                            " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
                            " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
                            " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
                            " please also re-upload it or open a PR on the original repository."
                        )
                        model._temp_convert_self_to_deprecated_attention_blocks()
                        accelerate.load_checkpoint_and_dispatch(
                            model,
                            model_file,
                            device_map,
                            max_memory=max_memory,
                            offload_folder=offload_folder,
                            offload_state_dict=offload_state_dict,
                            dtype=torch_dtype,
                        )
                        model._undo_temp_convert_self_to_deprecated_attention_blocks()
                    else:
                        raise e

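For context on the "standard hooks" the comment above rules out: a pre-hook registered with torch's (private) `_register_load_state_dict_pre_hook` can rename deprecated keys when weights arrive via `load_state_dict`, but `accelerate.load_checkpoint_and_dispatch` assigns tensors directly into `module._parameters` / `module._buffers` and never calls it. A minimal, hypothetical sketch with a toy module (not part of the PR):

import torch
from torch import nn


class ToyAttention(nn.Module):
    # Toy stand-in: `query` is the deprecated name, `to_q` the new one.
    def __init__(self):
        super().__init__()
        self.to_q = nn.Linear(4, 4)
        self._register_load_state_dict_pre_hook(self._rename_deprecated_keys)

    def _rename_deprecated_keys(self, state_dict, prefix, *args):
        # Runs only inside load_state_dict; renames keys before they are matched.
        for suffix in ("weight", "bias"):
            old_key = f"{prefix}query.{suffix}"
            if old_key in state_dict:
                state_dict[f"{prefix}to_q.{suffix}"] = state_dict.pop(old_key)


toy = ToyAttention()
toy.load_state_dict({"query.weight": torch.randn(4, 4), "query.bias": torch.randn(4)})
# accelerate's loader bypasses load_state_dict, so this hook never fires there.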
        loading_info = {
            "missing_keys": [],

@@ -889,3 +921,53 @@ def recursive_find_attn_block(name, module):
state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
if f"{path}.proj_attn.bias" in state_dict:
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
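Only the tail of the state-dict conversion is visible in this hunk; the `query`/`key`/`value` renames sit just above it in the same helper. For reference, a hypothetical standalone sketch of the full deprecated-to-new mapping (mirroring the attribute swaps in the methods below):

DEPRECATED_TO_NEW = {
    "query": "to_q",
    "key": "to_k",
    "value": "to_v",
    "proj_attn": "to_out.0",
}


def rename_deprecated_attention_keys(state_dict: dict, path: str) -> None:
    # Hypothetical helper, not part of the PR; renames one attention block's keys in place.
    for old, new in DEPRECATED_TO_NEW.items():
        for suffix in ("weight", "bias"):
            old_key = f"{path}.{old}.{suffix}"
            if old_key in state_dict:
                state_dict[f"{path}.{new}.{suffix}"] = state_dict.pop(old_key)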

    def _temp_convert_self_to_deprecated_attention_blocks(self):
        deprecated_attention_block_modules = []

        def recursive_find_attn_block(module):
            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
                deprecated_attention_block_modules.append(module)

            for sub_module in module.children():
                recursive_find_attn_block(sub_module)

        recursive_find_attn_block(self)

        for module in deprecated_attention_block_modules:
            module.query = module.to_q
            module.key = module.to_k
            module.value = module.to_v
            module.proj_attn = module.to_out[0]

            # We don't _have_ to delete the old attributes, but it's helpful to ensure
            # that _all_ the weights are loaded into the new attributes and we're not
            # making an incorrect assumption that this model should be converted when
            # it really shouldn't be.
            del module.to_q
            del module.to_k
            del module.to_v
            del module.to_out

    def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
        deprecated_attention_block_modules = []

        def recursive_find_attn_block(module):
            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
                deprecated_attention_block_modules.append(module)

            for sub_module in module.children():
                recursive_find_attn_block(sub_module)

        recursive_find_attn_block(self)

        for module in deprecated_attention_block_modules:
            module.to_q = module.query
            module.to_k = module.key
            module.to_v = module.value
            module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])

            del module.query
            del module.key
            del module.value
            del module.proj_attn
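Taken together with the warning above, the requested user action is just a load-then-save round trip; a sketch with placeholder paths (the test below exercises the same flow against a hub checkpoint):

from diffusers import DiffusionPipeline

# "path/to/old-checkpoint" and "path/to/converted-checkpoint" are placeholders.
pipe = DiffusionPipeline.from_pretrained("path/to/old-checkpoint", device_map="sequential")
# Loading with a device_map routes through accelerate and triggers the on-the-fly
# rename; saving persists the new weight names so the fallback is never needed again.
pipe.save_pretrained("path/to/converted-checkpoint")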
tests/models/test_attention_processor.py (44 additions, 0 deletions)

@@ -1,7 +1,10 @@
import tempfile
import unittest

import numpy as np
import torch

from diffusers import DiffusionPipeline
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor


@@ -73,3 +76,44 @@ def test_only_cross_attention(self):
        only_cross_attn_out = attn(**forward_args)

        self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())


class DeprecatedAttentionBlockTests(unittest.TestCase):
    def test_conversion_when_using_device_map(self):
        pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None)

        pre_conversion = pipe(
            "foo",
            num_inference_steps=2,
            generator=torch.Generator("cpu").manual_seed(0),
            output_type="np",
        ).images

Member (on the "foo" prompt)
Killer prompt.

        # the initial conversion succeeds
        pipe = DiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-pipe", device_map="sequential", safety_checker=None
        )

        conversion = pipe(
            "foo",
            num_inference_steps=2,
            generator=torch.Generator("cpu").manual_seed(0),
            output_type="np",
        ).images

        with tempfile.TemporaryDirectory() as tmpdir:
            # save the converted model
            pipe.save_pretrained(tmpdir)

            # can also load the converted weights
            pipe = DiffusionPipeline.from_pretrained(tmpdir, device_map="sequential", safety_checker=None)

            after_conversion = pipe(
                "foo",
                num_inference_steps=2,
                generator=torch.Generator("cpu").manual_seed(0),
                output_type="np",
            ).images

        self.assertTrue(np.allclose(pre_conversion, conversion))
        self.assertTrue(np.allclose(conversion, after_conversion))