 from huggingface_hub.utils import insecure_hashlib
 from packaging import version
 from peft import LoraConfig
-from peft.utils import get_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 from torch.utils.data import Dataset
@@ -54,7 +54,13 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module

@@ -892,10 +898,33 @@ def load_model_hook(models, input_dir):
                 raise ValueError(f"unexpected save model: {model.__class__}")

         lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
-        )
+
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
+
+        if args.train_text_encoder:
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_)
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.append(text_encoder_)
+
+            # only upcast trainable parameters (LoRA) into fp32
+            cast_training_params(models, dtype=torch.float32)

     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -910,6 +939,15 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )

+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.append(text_encoder)
+
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(models, dtype=torch.float32)
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
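Note on the repeated fp16 block: both hunks call `cast_training_params` so that only the trainable LoRA parameters are upcast to float32, while the frozen base weights stay in the half-precision `weight_dtype`; keeping the trained weights in fp32 avoids precision problems during mixed-precision optimization. A minimal sketch of that idea follows; the helper name `upcast_trainable_params` is hypothetical and this is not the actual `diffusers.training_utils.cast_training_params` implementation.

import torch


def upcast_trainable_params(models, dtype=torch.float32):
    # Sketch: walk all parameters and cast only those that will receive
    # gradients (the LoRA adapters) to `dtype`. Frozen base weights are left
    # in fp16/bf16, so memory use barely changes.
    if not isinstance(models, list):
        models = [models]
    for model in models:
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.data.to(dtype)

Usage would mirror the hunks above, e.g. `upcast_trainable_params([unet_, text_encoder_])` right after the LoRA state dict has been restored in the load hook.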