LCM Add Tests #5707
@@ -13,6 +13,7 @@ | |
| # limitations under the License. | ||
| import os | ||
| import re | ||
| import copy | ||
| from collections import defaultdict | ||
| from contextlib import nullcontext | ||
| from io import BytesIO | ||
|
|
@@ -44,6 +45,7 @@ | |
| is_transformers_available, | ||
| logging, | ||
| recurse_remove_peft_layers, | ||
| find_adapter_config_file, | ||
| scale_lora_layers, | ||
| set_adapter_layers, | ||
| set_weights_and_activate_adapters, | ||
|
|
@@ -1196,8 +1198,18 @@ def load_lora_weights( | |
| Adapter name to be used for referencing the loaded adapter model. If not specified, it will use | ||
| `default_{i}` where i is the total number of adapters being loaded. | ||
| """ | ||
| # let's copy the kwargs so that we can pass them to `load_lora_into_unet` | ||
| peft_kwargs = copy.deepcopy(kwargs) | ||
|
|
||
| # First, ensure that the checkpoint is a compatible one and can be successfully loaded. | ||
| state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) | ||
| state_dict, network_alphas = self.lora_state_dict( | ||
| pretrained_model_name_or_path_or_dict, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| peft_config = None | ||
| if USE_PEFT_BACKEND and not isinstance(pretrained_model_name_or_path_or_dict, dict): | ||
| peft_config = self._load_peft_config(pretrained_model_name_or_path_or_dict, **peft_kwargs) | ||
|
|
||
| is_correct_format = all("lora" in key for key in state_dict.keys()) | ||
| if not is_correct_format: | ||
|
|
@@ -1211,6 +1223,7 @@ def load_lora_weights( | |
| unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, | ||
| low_cpu_mem_usage=low_cpu_mem_usage, | ||
| adapter_name=adapter_name, | ||
| peft_config=peft_config, | ||
| _pipeline=self, | ||
| ) | ||
| self.load_lora_into_text_encoder( | ||
|
|
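Taken together, these changes mean `load_lora_weights` now also looks for a PEFT `adapter_config.json` alongside the LoRA weights and, when one is found, forwards its path as `peft_config` to `load_lora_into_unet`. A minimal usage sketch, assuming a placeholder Hub repo id:

```python
# Hedged usage sketch: "some-user/some-lora" is a placeholder repo id, not a real checkpoint.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# If the repo ships an adapter_config.json, `_load_peft_config` resolves its path and the
# LoraConfig is built from that file; otherwise the config is inferred from the state dict
# exactly as before.
pipe.load_lora_weights("some-user/some-lora", adapter_name="my_adapter")
```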
@@ -1325,6 +1338,7 @@ def lora_state_dict( | |
| weight_name = cls._best_guess_weight_name( | ||
| pretrained_model_name_or_path_or_dict, file_extension=".safetensors" | ||
| ) | ||
| print("weight_name", weight_name) | ||
|
||
| model_file = _get_model_file( | ||
| pretrained_model_name_or_path_or_dict, | ||
| weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, | ||
|
|
@@ -1387,6 +1401,32 @@ def lora_state_dict( | |
|
|
||
| return state_dict, network_alphas | ||
|
|
||
| @classmethod | ||
| def _load_peft_config(cls, pretrained_model_name_or_path: Union[str, Dict[str, torch.Tensor]], **kwargs): | ||
| cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) | ||
| force_download = kwargs.pop("force_download", False) | ||
| resume_download = kwargs.pop("resume_download", False) | ||
| proxies = kwargs.pop("proxies", None) | ||
| local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) | ||
| use_auth_token = kwargs.pop("use_auth_token", None) | ||
| revision = kwargs.pop("revision", None) | ||
| subfolder = kwargs.pop("subfolder", None) | ||
|
|
||
| user_agent = {"library": "diffusers-peft"} | ||
| peft_config = find_adapter_config_file( | ||
| pretrained_model_name_or_path, | ||
| cache_dir=cache_dir, | ||
| force_download=force_download, | ||
| resume_download=resume_download, | ||
| proxies=proxies, | ||
| local_files_only=local_files_only, | ||
| use_auth_token=use_auth_token, | ||
| revision=revision, | ||
| subfolder=subfolder, | ||
| user_agent=user_agent, | ||
| ) | ||
| return peft_config | ||
|
|
||
| @classmethod | ||
| def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors"): | ||
| targeted_files = [] | ||
|
|
@@ -1411,6 +1451,15 @@ def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_ext | |
| filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files) | ||
| ) | ||
|
|
||
| if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files): | ||
| targeted_files = list( | ||
| filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files) | ||
| ) | ||
| elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files): | ||
| targeted_files = list( | ||
| filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files) | ||
| ) | ||
|
||
|
|
||
| if len(targeted_files) > 1: | ||
| raise ValueError( | ||
| f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}." | ||
|
|
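The new branch narrows the candidate list to the canonical LoRA weight file name when one is present, so a repo that also contains other `.bin`/`.safetensors` files no longer trips the "more than one weights file" error just above. A small self-contained illustration (the two constants are assumed to match diffusers' defaults, they are not taken from this diff):

```python
# Illustration only; the constant values below are assumed, not part of this diff.
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"

targeted_files = ["pytorch_lora_weights.safetensors", "extra_module.safetensors"]

# Same preference order as the hunk above: canonical .bin name first, then .safetensors.
if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
    targeted_files = [f for f in targeted_files if f.endswith(LORA_WEIGHT_NAME)]
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
    targeted_files = [f for f in targeted_files if f.endswith(LORA_WEIGHT_NAME_SAFE)]

print(targeted_files)  # ['pytorch_lora_weights.safetensors']
```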
@@ -1554,7 +1603,7 @@ def _optionally_disable_offloading(cls, _pipeline): | |
|
|
||
| @classmethod | ||
| def load_lora_into_unet( | ||
| cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None | ||
| cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, peft_config=None, _pipeline=None | ||
| ): | ||
| """ | ||
| This will load the LoRA layers specified in `state_dict` into `unet`. | ||
|
|
@@ -1622,7 +1671,11 @@ def load_lora_into_unet( | |
| if "lora_B" in key: | ||
| rank[key] = val.shape[1] | ||
|
|
||
| lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True) | ||
| if peft_config is not None: | ||
| lora_config_kwargs = LoraConfig.from_json_file(peft_config) | ||
| else: | ||
| lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True) | ||
|
|
||
|
||
| lora_config = LoraConfig(**lora_config_kwargs) | ||
|
|
||
| # adapter_name | ||
|
|
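With `peft_config` available, the `LoraConfig` is built directly from the repo's `adapter_config.json` instead of being reverse-engineered from the checkpoint's ranks and alphas by `get_peft_kwargs`. A sketch of what such a config typically carries (field values are purely illustrative):

```python
# Illustrative only: these values are made up; a real adapter_config.json is written by PEFT
# when the adapter is saved, and LoraConfig.from_json_file() loads it back into a plain dict.
from peft import LoraConfig

example_adapter_config = {
    "r": 4,                                       # LoRA rank
    "lora_alpha": 4,                              # scaling numerator
    "lora_dropout": 0.0,
    "target_modules": ["to_q", "to_k", "to_v", "to_out.0"],
}

# Equivalent to the `LoraConfig(**lora_config_kwargs)` call in the diff, just with an
# in-memory dict standing in for the parsed JSON file.
lora_config = LoraConfig(**example_adapter_config)
```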
@@ -3211,6 +3264,9 @@ def load_lora_weights( | |
| kwargs (`dict`, *optional*): | ||
| See [`~loaders.LoraLoaderMixin.lora_state_dict`]. | ||
| """ | ||
| # let's copy the kwargs so that we can pass them to `load_lora_into_unet` | ||
| peft_kwargs = copy.deepcopy(kwargs) | ||
|
|
||
| # We could have accessed the unet config from `lora_state_dict()` too. We pass | ||
| # it here explicitly to be able to tell that it's coming from an SDXL | ||
| # pipeline. | ||
|
|
@@ -3221,12 +3277,17 @@ def load_lora_weights( | |
| unet_config=self.unet.config, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| peft_config = None | ||
| if USE_PEFT_BACKEND and not isinstance(pretrained_model_name_or_path_or_dict, dict): | ||
| peft_config = self._load_peft_config(pretrained_model_name_or_path_or_dict, **peft_kwargs) | ||
|
|
||
| is_correct_format = all("lora" in key for key in state_dict.keys()) | ||
| if not is_correct_format: | ||
| raise ValueError("Invalid LoRA checkpoint.") | ||
|
|
||
| self.load_lora_into_unet( | ||
| state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self | ||
| state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, peft_config=peft_config, _pipeline=self | ||
| ) | ||
| text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} | ||
| if len(text_encoder_state_dict) > 0: | ||
|
|
||
|
|
@@ -16,13 +16,18 @@ | |
| """ | ||
| import collections | ||
| import importlib | ||
| from typing import Optional | ||
| import os | ||
| from typing import Optional, Dict, Union | ||
| from huggingface_hub import hf_hub_download | ||
|
|
||
| from packaging import version | ||
|
|
||
| from .import_utils import is_peft_available, is_torch_available | ||
|
|
||
|
|
||
| ADAPTER_CONFIG_NAME = "adapter_config.json" | ||
|
|
||
|
|
||
| def recurse_remove_peft_layers(model): | ||
| if is_torch_available(): | ||
| import torch | ||
|
|
@@ -204,6 +209,81 @@ def set_weights_and_activate_adapters(model, adapter_names, weights): | |
| module.active_adapter = adapter_names | ||
|
|
||
|
|
||
| def find_adapter_config_file( | ||
|
||
| model_id: str, | ||
| cache_dir: Optional[Union[str, os.PathLike]] = None, | ||
| force_download: bool = False, | ||
| resume_download: bool = False, | ||
| proxies: Optional[Dict[str, str]] = None, | ||
| use_auth_token: Optional[Union[bool, str]] = None, | ||
| revision: Optional[str] = None, | ||
| user_agent: Optional[Dict[str, str]] = None, | ||
| local_files_only: bool = False, | ||
| subfolder: str = "", | ||
| ) -> Optional[str]: | ||
| r""" | ||
| Checks whether the model stored on the Hub or locally is an adapter model and, if so, returns the path of the | ||
| adapter config file; returns `None` otherwise. | ||
|
|
||
| Args: | ||
| model_id (`str`): | ||
| The identifier of the model to look for, can be either a local path or an id to the repository on the Hub. | ||
| cache_dir (`str` or `os.PathLike`, *optional*): | ||
| Path to a directory in which a downloaded pretrained model configuration should be cached if the standard | ||
| cache should not be used. | ||
| force_download (`bool`, *optional*, defaults to `False`): | ||
| Whether or not to force to (re-)download the configuration files and override the cached versions if they | ||
| exist. | ||
| resume_download (`bool`, *optional*, defaults to `False`): | ||
| Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. | ||
| proxies (`Dict[str, str]`, *optional*): | ||
| A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', | ||
| 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. | ||
| use_auth_token (`str` or *bool*, *optional*): | ||
| The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated | ||
| when running `huggingface-cli login` (stored in `~/.huggingface`). | ||
| revision (`str`, *optional*, defaults to `"main"`): | ||
| The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a | ||
| git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any | ||
| identifier allowed by git. | ||
|
|
||
| <Tip> | ||
|
|
||
| To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`. | ||
|
|
||
| </Tip> | ||
|
|
||
| local_files_only (`bool`, *optional*, defaults to `False`): | ||
| If `True`, will only try to load the adapter configuration from local files. | ||
| subfolder (`str`, *optional*, defaults to `""`): | ||
| In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can | ||
| specify the folder name here. | ||
| """ | ||
| adapter_cached_filename = None | ||
| if model_id is None: | ||
| return None | ||
| elif os.path.isdir(model_id): | ||
| list_remote_files = os.listdir(model_id) | ||
| if ADAPTER_CONFIG_NAME in list_remote_files: | ||
| adapter_cached_filename = os.path.join(model_id, ADAPTER_CONFIG_NAME) | ||
| else: | ||
| adapter_cached_filename = hf_hub_download( | ||
| model_id, | ||
| ADAPTER_CONFIG_NAME, | ||
| cache_dir=cache_dir, | ||
| force_download=force_download, | ||
| proxies=proxies, | ||
| resume_download=resume_download, | ||
| local_files_only=local_files_only, | ||
| use_auth_token=use_auth_token, | ||
| subfolder=subfolder, | ||
| revision=revision, | ||
| user_agent=user_agent, | ||
| ) | ||
|
|
||
| return adapter_cached_filename | ||
|
|
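The helper returns the path to `adapter_config.json` when the local directory or Hub repo contains one, and `None` otherwise. A quick usage sketch (the repo id is a placeholder, and the import path assumes the helper is exposed from `diffusers.utils.peft_utils`):

```python
# Hedged usage sketch; "some-user/some-lora" is a placeholder repo id and the import path
# is an assumption about where this new helper ends up being exposed.
from diffusers.utils.peft_utils import find_adapter_config_file

config_path = find_adapter_config_file("some-user/some-lora")
if config_path is not None:
    print(f"PEFT adapter config found at: {config_path}")
else:
    print("No adapter_config.json found; the LoRA config will be inferred from the state dict.")
```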
||
|
|
||
| def check_peft_version(min_version: str) -> None: | ||
| r""" | ||
| Checks if the version of PEFT is compatible. | ||
|
|
||
If we remove the last `,`, this becomes a one-liner.
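That is, with the trailing comma dropped, the call collapses back to the original single line:

```python
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
```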