-
Notifications
You must be signed in to change notification settings - Fork 6.5k
[Refactor] Update from single file #6428
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 83 commits
0105fc4
2afad15
2686fdd
ef656d7
daf4d05
8b7eecd
0cd1be4
16a80d3
0b24f88
7289be1
0012dd2
2616e03
7db4f50
872aa6c
83c5b8e
5a8e10e
7a8c722
ccf8d62
da9c9d5
b791a71
c6c8fc7
6ba7a50
b44d2b4
41e97e0
658d80f
5daf61a
af6cd36
6d743ef
b7732a0
9d10d2d
820313b
efc6380
9453626
afa62e6
e033f9f
c0d62ac
9605db5
e945e18
fa3a0d6
bbc60be
b69cddb
3ae0b83
6c19f0a
ba704fd
f304528
3c806be
f86ba55
cf2fe1e
cf560a7
0ec1ed7
4bb4ed4
68a49b1
e37abaf
1bd8ba3
1cce591
df4a8ea
249f78e
8a24733
de77ff6
0939565
c22c2aa
eb71c80
32349c5
a076513
db3eb06
9b42fbf
1ca79f7
ffde123
fd2ec36
aee8b5f
2fb9baf
bb8d317
480a4b4
2483d51
dab7f01
68ddb25
7395283
153e746
a371c3b
ba66fb8
b658618
3620357
dae09d0
0746cf9
dbfb8f1
82ce94e
6f8446a
e1d82e2
b2c9561
e297ac8
d1e3466
650a632
99fdba9
8c9af6c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,126 @@ | ||
| # Copyright 2023 The HuggingFace Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from huggingface_hub.utils import validate_hf_hub_args | ||
|
|
||
| from .single_file_utils import ( | ||
| create_diffusers_vae_model_from_ldm, | ||
| fetch_ldm_config_and_checkpoint, | ||
| ) | ||
|
|
||
|
|
||
| class FromOriginalVAEMixin: | ||
| """ | ||
| Load pretrained AutoencoderKL weights saved in the `.ckpt` or `.safetensors` format into a [`AutoencoderKL`]. | ||
| """ | ||
|
|
||
| @classmethod | ||
| @validate_hf_hub_args | ||
| def from_single_file(cls, pretrained_model_link_or_path, **kwargs): | ||
| r""" | ||
| Instantiate a [`AutoencoderKL`] from pretrained VAE weights saved in the original `.ckpt` or | ||
| `.safetensors` format. The model is set in evaluation mode (`model.eval()`) by default. | ||
|
|
||
| Parameters: | ||
| pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): | ||
| Can be either: | ||
| - A link to the `.ckpt` file (for example | ||
| `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub. | ||
| - A path to a *file* containing all pipeline weights. | ||
| torch_dtype (`str` or `torch.dtype`, *optional*): | ||
| Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We allow passing |
||
| dtype is automatically derived from the model's weights. | ||
| force_download (`bool`, *optional*, defaults to `False`): | ||
| Whether or not to force the (re-)download of the model weights and configuration files, overriding the | ||
| cached versions if they exist. | ||
| cache_dir (`Union[str, os.PathLike]`, *optional*): | ||
| Path to a directory where a downloaded pretrained model configuration is cached if the standard cache | ||
| is not used. | ||
| resume_download (`bool`, *optional*, defaults to `False`): | ||
| Whether or not to resume downloading the model weights and configuration files. If set to `False`, any | ||
| incompletely downloaded files are deleted. | ||
| proxies (`Dict[str, str]`, *optional*): | ||
| A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', | ||
| 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. | ||
| local_files_only (`bool`, *optional*, defaults to `False`): | ||
| Whether to only load local model weights and configuration files or not. If set to True, the model | ||
| won't be downloaded from the Hub. | ||
| token (`str` or *bool*, *optional*): | ||
| The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from | ||
| `diffusers-cli login` (stored in `~/.huggingface`) is used. | ||
| revision (`str`, *optional*, defaults to `"main"`): | ||
| The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier | ||
| allowed by Git. | ||
| image_size (`int`, *optional*, defaults to 512): | ||
| The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable | ||
| Diffusion v2 base model. Use 768 for Stable Diffusion v2. | ||
| use_safetensors (`bool`, *optional*, defaults to `None`): | ||
| If set to `None`, the safetensors weights are downloaded if they're available **and** if the | ||
| safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors | ||
| weights. If set to `False`, safetensors weights are not loaded. | ||
| kwargs (remaining dictionary of keyword arguments, *optional*): | ||
| Can be used to overwrite load and saveable variables (for example the pipeline components of the | ||
| specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` | ||
| method. See example below for more information. | ||
|
|
||
| <Tip warning={true}> | ||
|
|
||
| Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading | ||
| a VAE from SDXL or a Stable Diffusion v2 model or higher. | ||
|
|
||
| </Tip> | ||
|
|
||
| Examples: | ||
|
|
||
| ```py | ||
| from diffusers import AutoencoderKL | ||
|
|
||
| url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file | ||
| model = AutoencoderKL.from_single_file(url) | ||
| ``` | ||
| """ | ||
|
|
||
| original_config_file = kwargs.pop("original_config_file", None) | ||
| resume_download = kwargs.pop("resume_download", False) | ||
| force_download = kwargs.pop("force_download", False) | ||
| proxies = kwargs.pop("proxies", None) | ||
| token = kwargs.pop("token", None) | ||
| cache_dir = kwargs.pop("cache_dir", None) | ||
| local_files_only = kwargs.pop("local_files_only", None) | ||
| revision = kwargs.pop("revision", None) | ||
| torch_dtype = kwargs.pop("torch_dtype", None) | ||
| use_safetensors = kwargs.pop("use_safetensors", True) | ||
|
|
||
| class_name = cls.__name__ | ||
| original_config, checkpoint = fetch_ldm_config_and_checkpoint( | ||
| pretrained_model_link_or_path=pretrained_model_link_or_path, | ||
| class_name=class_name, | ||
| original_config_file=original_config_file, | ||
| resume_download=resume_download, | ||
| force_download=force_download, | ||
| proxies=proxies, | ||
| token=token, | ||
| revision=revision, | ||
| local_files_only=local_files_only, | ||
| use_safetensors=use_safetensors, | ||
| cache_dir=cache_dir, | ||
| ) | ||
|
|
||
| image_size = kwargs.pop("image_size", None) | ||
| component = create_diffusers_vae_model_from_ldm(class_name, original_config, checkpoint, image_size=image_size) | ||
| vae = component["vae"] | ||
| if torch_dtype is not None: | ||
| vae = vae.to(torch_dtype) | ||
|
|
||
| return vae | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,127 @@ | ||
| # Copyright 2023 The HuggingFace Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from huggingface_hub.utils import validate_hf_hub_args | ||
|
|
||
| from .single_file_utils import ( | ||
| create_diffusers_controlnet_model_from_ldm, | ||
| fetch_ldm_config_and_checkpoint, | ||
| ) | ||
|
|
||
|
|
||
| class FromOriginalControlNetMixin: | ||
| """ | ||
| Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`]. | ||
| """ | ||
|
|
||
| @classmethod | ||
| @validate_hf_hub_args | ||
| def from_single_file(cls, pretrained_model_link_or_path, **kwargs): | ||
| r""" | ||
| Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or | ||
| `.safetensors` format. The model is set in evaluation mode (`model.eval()`) by default. | ||
|
|
||
| Parameters: | ||
| pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): | ||
| Can be either: | ||
| - A link to the `.ckpt` file (for example | ||
| `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub. | ||
| - A path to a *file* containing all pipeline weights. | ||
| torch_dtype (`str` or `torch.dtype`, *optional*): | ||
| Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the | ||
| dtype is automatically derived from the model's weights. | ||
| force_download (`bool`, *optional*, defaults to `False`): | ||
| Whether or not to force the (re-)download of the model weights and configuration files, overriding the | ||
| cached versions if they exist. | ||
| cache_dir (`Union[str, os.PathLike]`, *optional*): | ||
| Path to a directory where a downloaded pretrained model configuration is cached if the standard cache | ||
| is not used. | ||
| resume_download (`bool`, *optional*, defaults to `False`): | ||
| Whether or not to resume downloading the model weights and configuration files. If set to `False`, any | ||
| incompletely downloaded files are deleted. | ||
| proxies (`Dict[str, str]`, *optional*): | ||
| A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', | ||
| 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. | ||
| local_files_only (`bool`, *optional*, defaults to `False`): | ||
| Whether to only load local model weights and configuration files or not. If set to True, the model | ||
| won't be downloaded from the Hub. | ||
| token (`str` or *bool*, *optional*): | ||
| The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from | ||
| `diffusers-cli login` (stored in `~/.huggingface`) is used. | ||
| revision (`str`, *optional*, defaults to `"main"`): | ||
| The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier | ||
| allowed by Git. | ||
| use_safetensors (`bool`, *optional*, defaults to `None`): | ||
| If set to `None`, the safetensors weights are downloaded if they're available **and** if the | ||
| safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors | ||
| weights. If set to `False`, safetensors weights are not loaded. | ||
| image_size (`int`, *optional*, defaults to 512): | ||
| The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable | ||
| Diffusion v2 base model. Use 768 for Stable Diffusion v2. | ||
| upcast_attention (`bool`, *optional*, defaults to `False`): | ||
| Whether the attention computation should always be upcasted. | ||
| kwargs (remaining dictionary of keyword arguments, *optional*): | ||
| Can be used to overwrite load and saveable variables (for example the pipeline components of the | ||
| specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` | ||
| method. See example below for more information. | ||
|
|
||
| Examples: | ||
|
|
||
| ```py | ||
| from diffusers import StableDiffusionControlNetPipeline, ControlNetModel | ||
|
|
||
| url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path | ||
| controlnet = ControlNetModel.from_single_file(url) | ||
|
|
||
| url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path | ||
| pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet) | ||
|
Comment on lines
+81
to
+88
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For housekeeping, let's make sure we have these examples (working examples, of course) for all the single file loader mixins. |
||
| ``` | ||
| """ | ||
| original_config_file = kwargs.pop("original_config_file", None) | ||
| resume_download = kwargs.pop("resume_download", False) | ||
| force_download = kwargs.pop("force_download", False) | ||
| proxies = kwargs.pop("proxies", None) | ||
| token = kwargs.pop("token", None) | ||
| cache_dir = kwargs.pop("cache_dir", None) | ||
| local_files_only = kwargs.pop("local_files_only", None) | ||
| revision = kwargs.pop("revision", None) | ||
| torch_dtype = kwargs.pop("torch_dtype", None) | ||
| use_safetensors = kwargs.pop("use_safetensors", True) | ||
|
|
||
| class_name = cls.__name__ | ||
| original_config, checkpoint = fetch_ldm_config_and_checkpoint( | ||
| pretrained_model_link_or_path=pretrained_model_link_or_path, | ||
| class_name=class_name, | ||
| original_config_file=original_config_file, | ||
| resume_download=resume_download, | ||
| force_download=force_download, | ||
| proxies=proxies, | ||
| token=token, | ||
| revision=revision, | ||
| local_files_only=local_files_only, | ||
| use_safetensors=use_safetensors, | ||
| cache_dir=cache_dir, | ||
| ) | ||
|
|
||
| upcast_attention = kwargs.pop("upcast_attention", False) | ||
| image_size = kwargs.pop("image_size", None) | ||
|
|
||
| component = create_diffusers_controlnet_model_from_ldm( | ||
| class_name, original_config, checkpoint, upcast_attention=upcast_attention, image_size=image_size | ||
| ) | ||
| controlnet = component["controlnet"] | ||
| if torch_dtype is not None: | ||
| controlnet = controlnet.to(torch_dtype) | ||
|
|
||
| return controlnet | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is `.ckpt` still used? If we know the answer to be "yes", I think it would make sense to add a comment below that `.ckpt` files are also supported.