@@ -912,10 +912,10 @@ def pack_weights(layers, prefix):
912912 )
913913
914914 if unet_lora_layers :
915- state_dict .update (pack_weights (unet_lora_layers , "unet" ))
915+ state_dict .update (pack_weights (unet_lora_layers , cls . unet_name ))
916916
917917 if text_encoder_lora_layers :
918- state_dict .update (pack_weights (text_encoder_lora_layers , "text_encoder" ))
918+ state_dict .update (pack_weights (text_encoder_lora_layers , cls . text_encoder_name ))
919919
920920 if transformer_lora_layers :
921921 state_dict .update (pack_weights (transformer_lora_layers , "transformer" ))
@@ -975,20 +975,22 @@ def unload_lora_weights(self):
975975 >>> ...
976976 ```
977977 """
978+ unet = getattr (self , self .unet_name ) if not hasattr (self , "unet" ) else self .unet
979+
978980 if not USE_PEFT_BACKEND :
979981 if version .parse (__version__ ) > version .parse ("0.23" ):
980982 logger .warn (
981983 "You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,"
982984 "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
983985 )
984986
985- for _ , module in self . unet .named_modules ():
987+ for _ , module in unet .named_modules ():
986988 if hasattr (module , "set_lora_layer" ):
987989 module .set_lora_layer (None )
988990 else :
989- recurse_remove_peft_layers (self . unet )
990- if hasattr (self . unet , "peft_config" ):
991- del self . unet .peft_config
991+ recurse_remove_peft_layers (unet )
992+ if hasattr (unet , "peft_config" ):
993+ del unet .peft_config
992994
993995 # Safe to call the following regardless of LoRA.
994996 self ._remove_text_encoder_monkey_patch ()
@@ -1027,7 +1029,8 @@ def fuse_lora(
10271029 )
10281030
10291031 if fuse_unet :
1030- self .unet .fuse_lora (lora_scale , safe_fusing = safe_fusing )
1032+ unet = getattr (self , self .unet_name ) if not hasattr (self , "unet" ) else self .unet
1033+ unet .fuse_lora (lora_scale , safe_fusing = safe_fusing )
10311034
10321035 if USE_PEFT_BACKEND :
10331036 from peft .tuners .tuners_utils import BaseTunerLayer
@@ -1080,13 +1083,14 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True
10801083 Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
10811084 LoRA parameters then it won't have any effect.
10821085 """
1086+ unet = getattr (self , self .unet_name ) if not hasattr (self , "unet" ) else self .unet
10831087 if unfuse_unet :
10841088 if not USE_PEFT_BACKEND :
1085- self . unet .unfuse_lora ()
1089+ unet .unfuse_lora ()
10861090 else :
10871091 from peft .tuners .tuners_utils import BaseTunerLayer
10881092
1089- for module in self . unet .modules ():
1093+ for module in unet .modules ():
10901094 if isinstance (module , BaseTunerLayer ):
10911095 module .unmerge ()
10921096
@@ -1202,8 +1206,9 @@ def set_adapters(
12021206 adapter_names : Union [List [str ], str ],
12031207 adapter_weights : Optional [List [float ]] = None ,
12041208 ):
1209+ unet = getattr (self , self .unet_name ) if not hasattr (self , "unet" ) else self .unet
12051210 # Handle the UNET
1206- self . unet .set_adapters (adapter_names , adapter_weights )
1211+ unet .set_adapters (adapter_names , adapter_weights )
12071212
12081213 # Handle the Text Encoder
12091214 if hasattr (self , "text_encoder" ):
@@ -1216,7 +1221,8 @@ def disable_lora(self):
12161221 raise ValueError ("PEFT backend is required for this method." )
12171222
12181223 # Disable unet adapters
1219- self .unet .disable_lora ()
1224+ unet = getattr (self , self .unet_name ) if not hasattr (self , "unet" ) else self .unet
1225+ unet .disable_lora ()
12201226
12211227 # Disable text encoder adapters
12221228 if hasattr (self , "text_encoder" ):
@@ -1229,7 +1235,8 @@ def enable_lora(self):
12291235 raise ValueError ("PEFT backend is required for this method." )
12301236
12311237 # Enable unet adapters
1232- self .unet .enable_lora ()
1238+ unet = getattr (self , self .unet_name ) if not hasattr (self , "unet" ) else self .unet
1239+ unet .enable_lora ()
12331240
12341241 # Enable text encoder adapters
12351242 if hasattr (self , "text_encoder" ):
@@ -1251,7 +1258,8 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
12511258 adapter_names = [adapter_names ]
12521259
12531260 # Delete unet adapters
1254- self .unet .delete_adapters (adapter_names )
1261+ unet = getattr (self , self .unet_name ) if not hasattr (self , "unet" ) else self .unet
1262+ unet .delete_adapters (adapter_names )
12551263
12561264 for adapter_name in adapter_names :
12571265 # Delete text encoder adapters
@@ -1284,8 +1292,8 @@ def get_active_adapters(self) -> List[str]:
12841292 from peft .tuners .tuners_utils import BaseTunerLayer
12851293
12861294 active_adapters = []
1287-
1288- for module in self . unet .modules ():
1295+ unet = getattr ( self , self . unet_name ) if not hasattr ( self , "unet" ) else self . unet
1296+ for module in unet .modules ():
12891297 if isinstance (module , BaseTunerLayer ):
12901298 active_adapters = module .active_adapters
12911299 break
@@ -1309,8 +1317,9 @@ def get_list_adapters(self) -> Dict[str, List[str]]:
13091317 if hasattr (self , "text_encoder_2" ) and hasattr (self .text_encoder_2 , "peft_config" ):
13101318 set_adapters ["text_encoder_2" ] = list (self .text_encoder_2 .peft_config .keys ())
13111319
1312- if hasattr (self , "unet" ) and hasattr (self .unet , "peft_config" ):
1313- set_adapters ["unet" ] = list (self .unet .peft_config .keys ())
1320+ unet = getattr (self , self .unet_name ) if not hasattr (self , "unet" ) else self .unet
1321+ if hasattr (self , self .unet_name ) and hasattr (unet , "peft_config" ):
1322+ set_adapters [self .unet_name ] = list (self .unet .peft_config .keys ())
13141323
13151324 return set_adapters
13161325
@@ -1331,7 +1340,8 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
13311340 from peft .tuners .tuners_utils import BaseTunerLayer
13321341
13331342 # Handle the UNET
1334- for unet_module in self .unet .modules ():
1343+ unet = getattr (self , self .unet_name ) if not hasattr (self , "unet" ) else self .unet
1344+ for unet_module in unet .modules ():
13351345 if isinstance (unet_module , BaseTunerLayer ):
13361346 for adapter_name in adapter_names :
13371347 unet_module .lora_A [adapter_name ].to (device )
0 commit comments