@@ -3254,6 +3254,314 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
32543254 super ().unfuse_lora (components = components )
32553255
32563256
3257+ class LTXVideoLoraLoaderMixin (LoraBaseMixin ):
3258+ r"""
3259+ Load LoRA layers into [`LTXVideoTransformer3DModel`]. Specific to [`LTXPipeline`].
3260+ """
3261+
3262+ _lora_loadable_modules = ["transformer" ]
3263+ transformer_name = TRANSFORMER_NAME
3264+
3265+ @classmethod
3266+ @validate_hf_hub_args
3267+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
3268+ def lora_state_dict (
3269+ cls ,
3270+ pretrained_model_name_or_path_or_dict : Union [str , Dict [str , torch .Tensor ]],
3271+ ** kwargs ,
3272+ ):
3273+ r"""
3274+ Return state dict for lora weights and the network alphas.
3275+
3276+ <Tip warning={true}>
3277+
3278+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
3279+
3280+ This function is experimental and might change in the future.
3281+
3282+ </Tip>
3283+
3284+ Parameters:
3285+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3286+ Can be either:
3287+
3288+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
3289+ the Hub.
3290+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
3291+ with [`ModelMixin.save_pretrained`].
3292+ - A [torch state
3293+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
3294+
3295+ cache_dir (`Union[str, os.PathLike]`, *optional*):
3296+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
3297+ is not used.
3298+ force_download (`bool`, *optional*, defaults to `False`):
3299+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
3300+ cached versions if they exist.
3301+
3302+ proxies (`Dict[str, str]`, *optional*):
3303+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
3304+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
3305+ local_files_only (`bool`, *optional*, defaults to `False`):
3306+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
3307+ won't be downloaded from the Hub.
3308+ token (`str` or *bool*, *optional*):
3309+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
3310+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
3311+ revision (`str`, *optional*, defaults to `"main"`):
3312+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
3313+ allowed by Git.
3314+ subfolder (`str`, *optional*, defaults to `""`):
3315+ The subfolder location of a model file within a larger model repository on the Hub or locally.
3316+
3317+ """
3318+ # Load the main state dict first which has the LoRA layers for either of
3319+ # transformer and text encoder or both.
3320+ cache_dir = kwargs .pop ("cache_dir" , None )
3321+ force_download = kwargs .pop ("force_download" , False )
3322+ proxies = kwargs .pop ("proxies" , None )
3323+ local_files_only = kwargs .pop ("local_files_only" , None )
3324+ token = kwargs .pop ("token" , None )
3325+ revision = kwargs .pop ("revision" , None )
3326+ subfolder = kwargs .pop ("subfolder" , None )
3327+ weight_name = kwargs .pop ("weight_name" , None )
3328+ use_safetensors = kwargs .pop ("use_safetensors" , None )
3329+
3330+ allow_pickle = False
3331+ if use_safetensors is None :
3332+ use_safetensors = True
3333+ allow_pickle = True
3334+
3335+ user_agent = {
3336+ "file_type" : "attn_procs_weights" ,
3337+ "framework" : "pytorch" ,
3338+ }
3339+
3340+ state_dict = _fetch_state_dict (
3341+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict ,
3342+ weight_name = weight_name ,
3343+ use_safetensors = use_safetensors ,
3344+ local_files_only = local_files_only ,
3345+ cache_dir = cache_dir ,
3346+ force_download = force_download ,
3347+ proxies = proxies ,
3348+ token = token ,
3349+ revision = revision ,
3350+ subfolder = subfolder ,
3351+ user_agent = user_agent ,
3352+ allow_pickle = allow_pickle ,
3353+ )
3354+
3355+ is_dora_scale_present = any ("dora_scale" in k for k in state_dict )
3356+ if is_dora_scale_present :
3357+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
3358+ logger .warning (warn_msg )
3359+ state_dict = {k : v for k , v in state_dict .items () if "dora_scale" not in k }
3360+
3361+ return state_dict
3362+
3363+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
3364+ def load_lora_weights (
3365+ self , pretrained_model_name_or_path_or_dict : Union [str , Dict [str , torch .Tensor ]], adapter_name = None , ** kwargs
3366+ ):
3367+ """
3368+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
3369+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
3370+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
3371+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
3372+ dict is loaded into `self.transformer`.
3373+
3374+ Parameters:
3375+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3376+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
3377+ adapter_name (`str`, *optional*):
3378+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3379+ `default_{i}` where i is the total number of adapters being loaded.
3380+ low_cpu_mem_usage (`bool`, *optional*):
3381+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3382+ weights.
3383+ kwargs (`dict`, *optional*):
3384+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
3385+ """
3386+ if not USE_PEFT_BACKEND :
3387+ raise ValueError ("PEFT backend is required for this method." )
3388+
3389+ low_cpu_mem_usage = kwargs .pop ("low_cpu_mem_usage" , _LOW_CPU_MEM_USAGE_DEFAULT_LORA )
3390+ if low_cpu_mem_usage and is_peft_version ("<" , "0.13.0" ):
3391+ raise ValueError (
3392+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
3393+ )
3394+
3395+ # if a dict is passed, copy it instead of modifying it inplace
3396+ if isinstance (pretrained_model_name_or_path_or_dict , dict ):
3397+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict .copy ()
3398+
3399+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
3400+ state_dict = self .lora_state_dict (pretrained_model_name_or_path_or_dict , ** kwargs )
3401+
3402+ is_correct_format = all ("lora" in key for key in state_dict .keys ())
3403+ if not is_correct_format :
3404+ raise ValueError ("Invalid LoRA checkpoint." )
3405+
3406+ self .load_lora_into_transformer (
3407+ state_dict ,
3408+ transformer = getattr (self , self .transformer_name ) if not hasattr (self , "transformer" ) else self .transformer ,
3409+ adapter_name = adapter_name ,
3410+ _pipeline = self ,
3411+ low_cpu_mem_usage = low_cpu_mem_usage ,
3412+ )
3413+
3414+ @classmethod
3415+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel
3416+ def load_lora_into_transformer (
3417+ cls , state_dict , transformer , adapter_name = None , _pipeline = None , low_cpu_mem_usage = False
3418+ ):
3419+ """
3420+ This will load the LoRA layers specified in `state_dict` into `transformer`.
3421+
3422+ Parameters:
3423+ state_dict (`dict`):
3424+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
3425+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
3426+ encoder lora layers.
3427+ transformer (`LTXVideoTransformer3DModel`):
3428+ The Transformer model to load the LoRA layers into.
3429+ adapter_name (`str`, *optional*):
3430+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3431+ `default_{i}` where i is the total number of adapters being loaded.
3432+ low_cpu_mem_usage (`bool`, *optional*):
3433+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3434+ weights.
3435+ """
3436+ if low_cpu_mem_usage and is_peft_version ("<" , "0.13.0" ):
3437+ raise ValueError (
3438+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
3439+ )
3440+
3441+ # Load the layers corresponding to transformer.
3442+ logger .info (f"Loading { cls .transformer_name } ." )
3443+ transformer .load_lora_adapter (
3444+ state_dict ,
3445+ network_alphas = None ,
3446+ adapter_name = adapter_name ,
3447+ _pipeline = _pipeline ,
3448+ low_cpu_mem_usage = low_cpu_mem_usage ,
3449+ )
3450+
3451+ @classmethod
3452+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
3453+ def save_lora_weights (
3454+ cls ,
3455+ save_directory : Union [str , os .PathLike ],
3456+ transformer_lora_layers : Dict [str , Union [torch .nn .Module , torch .Tensor ]] = None ,
3457+ is_main_process : bool = True ,
3458+ weight_name : str = None ,
3459+ save_function : Callable = None ,
3460+ safe_serialization : bool = True ,
3461+ ):
3462+ r"""
3463+ Save the LoRA parameters corresponding to the UNet and text encoder.
3464+
3465+ Arguments:
3466+ save_directory (`str` or `os.PathLike`):
3467+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
3468+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
3469+ State dict of the LoRA layers corresponding to the `transformer`.
3470+ is_main_process (`bool`, *optional*, defaults to `True`):
3471+ Whether the process calling this is the main process or not. Useful during distributed training and you
3472+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
3473+ process to avoid race conditions.
3474+ save_function (`Callable`):
3475+ The function to use to save the state dictionary. Useful during distributed training when you need to
3476+ replace `torch.save` with another method. Can be configured with the environment variable
3477+ `DIFFUSERS_SAVE_MODE`.
3478+ safe_serialization (`bool`, *optional*, defaults to `True`):
3479+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
3480+ """
3481+ state_dict = {}
3482+
3483+ if not transformer_lora_layers :
3484+ raise ValueError ("You must pass `transformer_lora_layers`." )
3485+
3486+ if transformer_lora_layers :
3487+ state_dict .update (cls .pack_weights (transformer_lora_layers , cls .transformer_name ))
3488+
3489+ # Save the model
3490+ cls .write_lora_layers (
3491+ state_dict = state_dict ,
3492+ save_directory = save_directory ,
3493+ is_main_process = is_main_process ,
3494+ weight_name = weight_name ,
3495+ save_function = save_function ,
3496+ safe_serialization = safe_serialization ,
3497+ )
3498+
3499+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
3500+ def fuse_lora (
3501+ self ,
3502+ components : List [str ] = ["transformer" , "text_encoder" ],
3503+ lora_scale : float = 1.0 ,
3504+ safe_fusing : bool = False ,
3505+ adapter_names : Optional [List [str ]] = None ,
3506+ ** kwargs ,
3507+ ):
3508+ r"""
3509+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
3510+
3511+ <Tip warning={true}>
3512+
3513+ This is an experimental API.
3514+
3515+ </Tip>
3516+
3517+ Args:
3518+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
3519+ lora_scale (`float`, defaults to 1.0):
3520+ Controls how much to influence the outputs with the LoRA parameters.
3521+ safe_fusing (`bool`, defaults to `False`):
3522+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
3523+ adapter_names (`List[str]`, *optional*):
3524+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
3525+
3526+ Example:
3527+
3528+ ```py
3529+ from diffusers import DiffusionPipeline
3530+ import torch
3531+
3532+ pipeline = DiffusionPipeline.from_pretrained(
3533+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
3534+ ).to("cuda")
3535+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
3536+ pipeline.fuse_lora(lora_scale=0.7)
3537+ ```
3538+ """
3539+ super ().fuse_lora (
3540+ components = components , lora_scale = lora_scale , safe_fusing = safe_fusing , adapter_names = adapter_names
3541+ )
3542+
3543+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
3544+ def unfuse_lora (self , components : List [str ] = ["transformer" , "text_encoder" ], ** kwargs ):
3545+ r"""
3546+ Reverses the effect of
3547+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
3548+
3549+ <Tip warning={true}>
3550+
3551+ This is an experimental API.
3552+
3553+ </Tip>
3554+
3555+ Args:
3556+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
3557+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
3558+ unfuse_text_encoder (`bool`, defaults to `True`):
3559+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
3560+ LoRA parameters then it won't have any effect.
3561+ """
3562+ super ().unfuse_lora (components = components )
3563+
3564+
32573565class LoraLoaderMixin (StableDiffusionLoraLoaderMixin ):
32583566 def __init__ (self , * args , ** kwargs ):
32593567 deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
0 commit comments