Skip to content
Merged
Show file tree
Hide file tree
Changes from 83 commits
Commits
Show all changes
94 commits
Select commit Hold shift + click to select a range
0105fc4
update
DN6 Dec 19, 2023
2afad15
Merge branch 'main' into refactor-single-file
DN6 Dec 22, 2023
2686fdd
update
DN6 Dec 22, 2023
ef656d7
update
DN6 Dec 22, 2023
daf4d05
update
DN6 Dec 25, 2023
8b7eecd
update
DN6 Dec 26, 2023
0cd1be4
update
DN6 Dec 26, 2023
16a80d3
update
DN6 Dec 26, 2023
0b24f88
Merge branch 'main' into refactor-single-file
DN6 Dec 27, 2023
7289be1
update
DN6 Dec 28, 2023
0012dd2
update
DN6 Dec 28, 2023
2616e03
update
DN6 Dec 28, 2023
7db4f50
update'
DN6 Dec 28, 2023
872aa6c
update
DN6 Dec 28, 2023
83c5b8e
update
DN6 Dec 29, 2023
5a8e10e
update
DN6 Dec 29, 2023
7a8c722
update
DN6 Dec 29, 2023
ccf8d62
update
DN6 Dec 29, 2023
da9c9d5
update
DN6 Dec 29, 2023
b791a71
up
DN6 Dec 29, 2023
c6c8fc7
update
DN6 Dec 29, 2023
6ba7a50
update
DN6 Dec 29, 2023
b44d2b4
update
DN6 Dec 30, 2023
41e97e0
update
DN6 Dec 30, 2023
658d80f
update
DN6 Dec 30, 2023
5daf61a
update
DN6 Dec 30, 2023
af6cd36
update
DN6 Dec 30, 2023
6d743ef
update
DN6 Dec 30, 2023
b7732a0
update
DN6 Dec 30, 2023
9d10d2d
update
DN6 Dec 30, 2023
820313b
update
DN6 Dec 30, 2023
efc6380
update
DN6 Dec 30, 2023
9453626
up
DN6 Dec 30, 2023
afa62e6
update
DN6 Dec 30, 2023
e033f9f
update
DN6 Dec 30, 2023
c0d62ac
update
DN6 Dec 30, 2023
9605db5
update
DN6 Dec 30, 2023
e945e18
update'
DN6 Dec 30, 2023
fa3a0d6
update
DN6 Jan 2, 2024
bbc60be
update
DN6 Jan 2, 2024
b69cddb
update
DN6 Jan 2, 2024
3ae0b83
update
DN6 Jan 2, 2024
6c19f0a
update
DN6 Jan 2, 2024
ba704fd
update
DN6 Jan 2, 2024
f304528
update
DN6 Jan 2, 2024
3c806be
update
DN6 Jan 2, 2024
f86ba55
update
DN6 Jan 2, 2024
cf2fe1e
Merge branch 'main' into refactor-single-file
DN6 Jan 12, 2024
cf560a7
update
DN6 Jan 15, 2024
0ec1ed7
update
DN6 Jan 16, 2024
4bb4ed4
update
DN6 Jan 16, 2024
68a49b1
update
DN6 Jan 16, 2024
e37abaf
update
DN6 Jan 17, 2024
1bd8ba3
update
DN6 Jan 17, 2024
1cce591
update
DN6 Jan 17, 2024
df4a8ea
update
DN6 Jan 17, 2024
249f78e
update
DN6 Jan 17, 2024
8a24733
update
DN6 Jan 17, 2024
de77ff6
update
DN6 Jan 18, 2024
0939565
update
DN6 Jan 18, 2024
c22c2aa
update
DN6 Jan 18, 2024
eb71c80
update
DN6 Jan 18, 2024
32349c5
update
DN6 Jan 18, 2024
a076513
update
DN6 Jan 18, 2024
db3eb06
update
DN6 Jan 19, 2024
9b42fbf
update
DN6 Jan 19, 2024
1ca79f7
update
DN6 Jan 19, 2024
ffde123
update
DN6 Jan 19, 2024
fd2ec36
update
DN6 Jan 19, 2024
aee8b5f
update
DN6 Jan 19, 2024
2fb9baf
update
DN6 Jan 19, 2024
bb8d317
clean
DN6 Jan 19, 2024
480a4b4
update
DN6 Jan 19, 2024
2483d51
update
DN6 Jan 19, 2024
dab7f01
clean up
DN6 Jan 19, 2024
68ddb25
clean up
DN6 Jan 19, 2024
7395283
update
DN6 Jan 19, 2024
153e746
clean
DN6 Jan 19, 2024
a371c3b
clean
DN6 Jan 19, 2024
ba66fb8
update
DN6 Jan 19, 2024
b658618
updaet
DN6 Jan 19, 2024
3620357
clean up
DN6 Jan 19, 2024
dae09d0
fix docs
DN6 Jan 19, 2024
0746cf9
update
DN6 Jan 22, 2024
dbfb8f1
update
DN6 Jan 22, 2024
82ce94e
Revert "update"
DN6 Jan 22, 2024
6f8446a
update
DN6 Jan 22, 2024
e1d82e2
Merge branch 'main' into refactor-single-file
DN6 Jan 22, 2024
b2c9561
update
DN6 Jan 22, 2024
e297ac8
update
DN6 Jan 23, 2024
d1e3466
update
DN6 Jan 23, 2024
650a632
fix controlnet
DN6 Jan 23, 2024
99fdba9
fix scheduler
DN6 Jan 23, 2024
8c9af6c
fix controlnet tests
DN6 Jan 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/en/api/loaders/single_file.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ To learn more about how to load single file weights, see the [Load different Sta

## FromOriginalVAEMixin

[[autodoc]] loaders.single_file.FromOriginalVAEMixin
[[autodoc]] loaders.autoencoder.FromOriginalVAEMixin

## FromOriginalControlnetMixin

[[autodoc]] loaders.single_file.FromOriginalControlnetMixin
[[autodoc]] loaders.controlnet.FromOriginalControlNetMixin
10 changes: 6 additions & 4 deletions src/diffusers/loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,13 @@ def text_encoder_attn_modules(text_encoder):
_import_structure = {}

if is_torch_available():
_import_structure["single_file"] = ["FromOriginalControlnetMixin", "FromOriginalVAEMixin"]
_import_structure["autoencoder"] = ["FromOriginalVAEMixin"]

_import_structure["controlnet"] = ["FromOriginalControlNetMixin"]
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
_import_structure["utils"] = ["AttnProcsLayers"]

if is_transformers_available():
_import_structure["single_file"].extend(["FromSingleFileMixin"])
_import_structure["single_file"] = ["FromSingleFileMixin"]
_import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
Expand All @@ -69,7 +70,8 @@ def text_encoder_attn_modules(text_encoder):

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .single_file import FromOriginalControlnetMixin, FromOriginalVAEMixin
from .autoencoder import FromOriginalVAEMixin
from .controlnet import FromOriginalControlNetMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers

Expand Down
126 changes: 126 additions & 0 deletions src/diffusers/loaders/autoencoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from huggingface_hub.utils import validate_hf_hub_args

from .single_file_utils import (
create_diffusers_vae_model_from_ldm,
fetch_ldm_config_and_checkpoint,
)


class FromOriginalVAEMixin:
"""
Load pretrained AutoencoderKL weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`].
"""

@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.

Parameters:
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A link to the `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors"`) on the Hub.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is .ckpt still used? If we know that to be "yes" I think it would make sense to add a comment below that ".ckpt" files are also supported.

- A path to a *file* containing all pipeline weights.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We allow passing "auto" here?

dtype is automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
image_size (`int`, *optional*, defaults to 512):
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
use_safetensors (`bool`, *optional*, defaults to `None`):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the pipeline components of the
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
method. See example below for more information.

<Tip warning={true}>

Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading
a VAE from SDXL or a Stable Diffusion v2 model or higher.

</Tip>

Examples:

```py
from diffusers import AutoencoderKL

url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
model = AutoencoderKL.from_single_file(url)
```
"""

original_config_file = kwargs.pop("original_config_file", None)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
cache_dir = kwargs.pop("cache_dir", None)
local_files_only = kwargs.pop("local_files_only", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
use_safetensors = kwargs.pop("use_safetensors", True)

class_name = cls.__name__
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
pretrained_model_link_or_path=pretrained_model_link_or_path,
class_name=class_name,
original_config_file=original_config_file,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
use_safetensors=use_safetensors,
cache_dir=cache_dir,
)

image_size = kwargs.pop("image_size", None)
component = create_diffusers_vae_model_from_ldm(class_name, original_config, checkpoint, image_size=image_size)
vae = component["vae"]
if torch_dtype is not None:
vae = vae.to(torch_dtype)

return vae
127 changes: 127 additions & 0 deletions src/diffusers/loaders/controlnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from huggingface_hub.utils import validate_hf_hub_args

from .single_file_utils import (
create_diffusers_controlnet_model_from_ldm,
fetch_ldm_config_and_checkpoint,
)


class FromOriginalControlNetMixin:
"""
Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`].
"""

@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.

Parameters:
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A link to the `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- A path to a *file* containing all pipeline weights.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
use_safetensors (`bool`, *optional*, defaults to `None`):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
image_size (`int`, *optional*, defaults to 512):
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
upcast_attention (`bool`, *optional*, defaults to `None`):
Whether the attention computation should always be upcasted.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the pipeline components of the
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
method. See example below for more information.

Examples:

```py
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel

url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
model = ControlNetModel.from_single_file(url)

url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
Comment on lines +81 to +88
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For housekeeping, let's make sure we have these examples (working examples, of course) for all the single file loader mixins.

```
"""
original_config_file = kwargs.pop("original_config_file", None)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
cache_dir = kwargs.pop("cache_dir", None)
local_files_only = kwargs.pop("local_files_only", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
use_safetensors = kwargs.pop("use_safetensors", True)

class_name = cls.__name__
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
pretrained_model_link_or_path=pretrained_model_link_or_path,
class_name=class_name,
original_config_file=original_config_file,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
use_safetensors=use_safetensors,
cache_dir=cache_dir,
)

upcast_attention = kwargs.pop("upcast_attention", False)
image_size = kwargs.pop("image_size", None)

component = create_diffusers_controlnet_model_from_ldm(
class_name, original_config, checkpoint, upcast_attention=upcast_attention, image_size=image_size
)
controlnet = component["controlnet"]
if torch_dtype is not None:
controlnet = controlnet.to(torch_dtype)

return controlnet
Loading