69 changes: 65 additions & 4 deletions src/diffusers/loaders.py
@@ -13,6 +13,7 @@
# limitations under the License.
import os
import re
import copy
from collections import defaultdict
from contextlib import nullcontext
from io import BytesIO
@@ -44,6 +45,7 @@
is_transformers_available,
logging,
recurse_remove_peft_layers,
find_adapter_config_file,
scale_lora_layers,
set_adapter_layers,
set_weights_and_activate_adapters,
@@ -1196,8 +1198,18 @@ def load_lora_weights(
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
# let's copy the kwargs so that we can also pass them to `_load_peft_config`
peft_kwargs = copy.deepcopy(kwargs)

# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
**kwargs,
)
sayakpaul (Member) commented on Nov 9, 2023:
If we remove the last `,`, this becomes one line.

peft_config = None
if USE_PEFT_BACKEND and not isinstance(pretrained_model_name_or_path_or_dict, dict):
peft_config = self._load_peft_config(pretrained_model_name_or_path_or_dict, **peft_kwargs)

is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
@@ -1211,6 +1223,7 @@ def load_lora_weights(
unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
low_cpu_mem_usage=low_cpu_mem_usage,
adapter_name=adapter_name,
peft_config=peft_config,
_pipeline=self,
)
self.load_lora_into_text_encoder(
@@ -1325,6 +1338,7 @@ def lora_state_dict(
weight_name = cls._best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".safetensors"
)
print("weight_name", weight_name)
Reviewer (Member):
Remove.

model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
@@ -1387,6 +1401,32 @@

return state_dict, network_alphas

@classmethod
def _load_peft_config(cls, pretrained_model_name_or_path: Union[str, Dict[str, torch.Tensor]], **kwargs):
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)

user_agent = {"library": "diffusers-peft"}
peft_config = find_adapter_config_file(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
return peft_config

@classmethod
def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors"):
targeted_files = []
@@ -1411,6 +1451,15 @@ def _best_guess_weight_name(pretrained_model_name_or_path_or_dict, file_ext
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
)

if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
targeted_files = list(
filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files)
)
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
targeted_files = list(
filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files)
)
Reviewer (Member):
Nice!

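The two new branches prefer an exact match on the canonical LoRA file names before the ambiguity check below. A standalone sketch of that preference logic, assuming the usual values of the diffusers constants (`pytorch_lora_weights.bin` and `pytorch_lora_weights.safetensors`); the real classmethod also applies the extension and substring filters shown earlier:

```python
# Sketch of the new preference logic in _best_guess_weight_name, not the library code itself.
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"               # assumed constant value
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"  # assumed constant value

def pick_weight_name(targeted_files):
    # Keep only the canonical .bin name if present, otherwise the canonical .safetensors name.
    if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
        targeted_files = [f for f in targeted_files if f.endswith(LORA_WEIGHT_NAME)]
    elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
        targeted_files = [f for f in targeted_files if f.endswith(LORA_WEIGHT_NAME_SAFE)]
    if len(targeted_files) > 1:
        raise ValueError("More than one candidate weights file; pass `weight_name` explicitly.")
    return targeted_files[0]

print(pick_weight_name(["unet_lora.safetensors", "pytorch_lora_weights.safetensors"]))
# pytorch_lora_weights.safetensors
```
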
if len(targeted_files) > 1:
raise ValueError(
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
@@ -1554,7 +1603,7 @@ def _optionally_disable_offloading(cls, _pipeline):

@classmethod
def load_lora_into_unet(
cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, peft_config=None, _pipeline=None
):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -1622,7 +1671,11 @@ def load_lora_into_unet(
if "lora_B" in key:
rank[key] = val.shape[1]

lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
if peft_config is not None:
lora_config_kwargs = LoraConfig.from_json_file(peft_config)
else:
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)

Reviewer (Member):
Shouldn't this come under an `if USE_PEFT_BACKEND`?

Author (Contributor):
Good point!

lora_config = LoraConfig(**lora_config_kwargs)

# adapter_name
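Picking up the thread above, one possible shape of the guarded branch, shown purely as a sketch in the context of `load_lora_into_unet` (it is not part of this diff and not standalone code), reusing the module-level `USE_PEFT_BACKEND` flag that `load_lora_weights` already checks:

```python
# Sketch only: gate the adapter_config.json path on the PEFT backend being active,
# falling back to the rank/alpha inference used otherwise.
if USE_PEFT_BACKEND and peft_config is not None:
    lora_config_kwargs = LoraConfig.from_json_file(peft_config)
else:
    lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)

lora_config = LoraConfig(**lora_config_kwargs)
```
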
@@ -3211,6 +3264,9 @@ def load_lora_weights(
kwargs (`dict`, *optional*):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
"""
# let's copy the kwargs so that we can also pass them to `_load_peft_config`
peft_kwargs = copy.deepcopy(kwargs)

# We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL
# pipeline.
@@ -3221,12 +3277,17 @@
unet_config=self.unet.config,
**kwargs,
)

peft_config = None
if USE_PEFT_BACKEND and not isinstance(pretrained_model_name_or_path_or_dict, dict):
peft_config = self._load_peft_config(pretrained_model_name_or_path_or_dict, **peft_kwargs)

is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")

self.load_lora_into_unet(
state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, peft_config=peft_config, _pipeline=self
)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
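Taken together, the `loaders.py` changes mean that when a checkpoint ships an `adapter_config.json`, `load_lora_weights` reads the LoRA configuration from that file instead of reverse-engineering ranks and alphas from the state dict. A rough usage sketch: the LoRA repo id is purely illustrative, and the new path only triggers when the PEFT backend is active (i.e. a recent `peft` is installed):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Hypothetical repo containing pytorch_lora_weights.safetensors plus adapter_config.json.
# With the PEFT backend enabled, load_lora_weights() resolves the config via
# _load_peft_config() -> find_adapter_config_file() and forwards it to load_lora_into_unet().
pipe.load_lora_weights("some-user/sd15-lora-with-adapter-config", adapter_name="style")

image = pipe("a watercolor lighthouse at dusk", num_inference_steps=30).images[0]
```
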
@@ -550,6 +550,36 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
latents = latents * self.scheduler.init_noise_sigma
return latents

# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
"""
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

Args:
timesteps (`torch.Tensor`):
generate embedding vectors at these timesteps
embedding_dim (`int`, *optional*, defaults to 512):
dimension of the embeddings to generate
dtype:
data type of the generated embeddings

Returns:
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
"""
assert len(w.shape) == 1
w = w * 1000.0

half_dim = embedding_dim // 2
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
emb = w.to(dtype)[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1))
assert emb.shape == (w.shape[0], embedding_dim)
return emb


def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.

@@ -593,7 +623,7 @@ def clip_skip(self):
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

@property
def cross_attention_kwargs(self):
@@ -790,6 +820,14 @@ def __call__(
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

# 6.5 Optionally get Guidance Scale Embedding
timestep_cond = None
if self.unet.config.time_cond_proj_dim is not None:
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
timestep_cond = self.get_guidance_scale_embedding(guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim).to(
device=device, dtype=latents.dtype
)

# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
@@ -804,6 +842,7 @@
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
return_dict=False,
)[0]
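For context on the pipeline changes above: `get_guidance_scale_embedding` (copied from the latent consistency pipeline) turns the scalar guidance scale into a sinusoidal vector that is passed to the UNet as `timestep_cond` whenever `unet.config.time_cond_proj_dim` is set, which is also why classifier-free guidance is disabled in that case. A quick standalone shape check mirroring the call in `__call__`; this is an illustrative sketch, not a test from the PR:

```python
import torch

def get_guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    # Same sinusoidal embedding as in the diff above: scale w by 1000, then sin/cos features.
    assert len(w.shape) == 1
    w = w * 1000.0
    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad for odd dims
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb

# Mirrors the pipeline: guidance_scale - 1, repeated once per generated image.
guidance_scale, batch_size, num_images_per_prompt = 7.5, 2, 1
w = torch.tensor(guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
timestep_cond = get_guidance_scale_embedding(w, embedding_dim=256)
print(timestep_cond.shape)  # torch.Size([2, 256])
```
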
@@ -631,6 +631,35 @@ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
raise ValueError("The pipeline must have `unet` for using FreeU.")
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)

# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
"""
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

Args:
timesteps (`torch.Tensor`):
generate embedding vectors at these timesteps
embedding_dim (`int`, *optional*, defaults to 512):
dimension of the embeddings to generate
dtype:
data type of the generated embeddings

Returns:
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
"""
assert len(w.shape) == 1
w = w * 1000.0

half_dim = embedding_dim // 2
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
emb = w.to(dtype)[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1))
assert emb.shape == (w.shape[0], embedding_dim)
return emb

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
@@ -653,7 +682,7 @@ def clip_skip(self):
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

@property
def cross_attention_kwargs(self):
@@ -989,6 +1018,14 @@ def __call__(
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]

# 9. Optionally get Guidance Scale Embedding
timestep_cond = None
if self.unet.config.time_cond_proj_dim is not None:
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
timestep_cond = self.get_guidance_scale_embedding(guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim).to(
device=device, dtype=latents.dtype
)

self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -1003,6 +1040,7 @@
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
@@ -96,6 +96,7 @@
set_adapter_layers,
set_weights_and_activate_adapters,
unscale_lora_layers,
find_adapter_config_file,
)
from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
from .state_dict_utils import (
82 changes: 81 additions & 1 deletion src/diffusers/utils/peft_utils.py
@@ -16,13 +16,18 @@
"""
import collections
import importlib
from typing import Optional
import os
from typing import Optional, Dict, Union
from huggingface_hub import hf_hub_download

from packaging import version

from .import_utils import is_peft_available, is_torch_available


ADAPTER_CONFIG_NAME = "adapter_config.json"


def recurse_remove_peft_layers(model):
if is_torch_available():
import torch
@@ -204,6 +209,81 @@ def set_weights_and_activate_adapters(model, adapter_names, weights):
module.active_adapter = adapter_names


def find_adapter_config_file(
Reviewer (Member):
Very useful!

model_id: str,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
resume_download: bool = False,
proxies: Optional[Dict[str, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
user_agent: Optional[Dict[str, str]] = None,
local_files_only: bool = False,
subfolder: str = "",
) -> Optional[str]:
r"""
Simply checks if the model stored on the Hub or locally is an adapter model or not, return the path of the adapter
config file if it is, None otherwise.

Args:
model_id (`str`):
The identifier of the model to look for, can be either a local path or an id to the repository on the Hub.
cache_dir (`str` or `os.PathLike`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force to (re-)download the configuration files and override the cached versions if they
exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.

<Tip>

To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>".

</Tip>

local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
"""
adapter_cached_filename = None
if model_id is None:
return None
elif os.path.isdir(model_id):
list_remote_files = os.listdir(model_id)
if ADAPTER_CONFIG_NAME in list_remote_files:
adapter_cached_filename = os.path.join(model_id, ADAPTER_CONFIG_NAME)
else:
adapter_cached_filename = hf_hub_download(
model_id,
ADAPTER_CONFIG_NAME,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
subfolder=subfolder,
revision=revision,
user_agent=user_agent,
)

return adapter_cached_filename


def check_peft_version(min_version: str) -> None:
r"""
Checks if the version of PEFT is compatible.
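Finally, a small usage sketch for the new `find_adapter_config_file` utility. The temporary directory stands in for a locally saved adapter, and the commented-out Hub call uses an illustrative repo id:

```python
import json
import os
import tempfile

from diffusers.utils import find_adapter_config_file  # exported by this PR

# Local case: a folder containing adapter_config.json is recognized as a PEFT adapter.
with tempfile.TemporaryDirectory() as adapter_dir:
    with open(os.path.join(adapter_dir, "adapter_config.json"), "w") as f:
        json.dump({"peft_type": "LORA", "r": 8, "lora_alpha": 8}, f)
    print(find_adapter_config_file(adapter_dir))  # .../adapter_config.json

# Hub case (illustrative repo id): downloads adapter_config.json and returns its cached path.
# find_adapter_config_file("some-user/sd15-lora-with-adapter-config")
```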