diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 89bb498a3acd..e124b6eeacf3 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -532,13 +532,19 @@ def set_adapters( ) list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]} - all_adapters = { - adapter for adapters in list_adapters.values() for adapter in adapters - } # eg ["adapter1", "adapter2"] + # eg ["adapter1", "adapter2"] + all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters} + missing_adapters = set(adapter_names) - all_adapters + if len(missing_adapters) > 0: + raise ValueError( + f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}." + ) + + # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]} invert_list_adapters = { adapter: [part for part, adapters in list_adapters.items() if adapter in adapters] for adapter in all_adapters - } # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]} + } # Decompose weights into weights for denoiser and text encoders. _component_adapter_weights = {} diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 939b749c286a..43c45daaa322 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -929,12 +929,24 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self): pipe.set_adapters("adapter-1") output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertFalse( + np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3), + "Adapter outputs should be different.", + ) pipe.set_adapters("adapter-2") output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertFalse( + np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3), + "Adapter outputs should be different.", + ) pipe.set_adapters(["adapter-1", "adapter-2"]) output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertFalse( + np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3), + "Adapter outputs should be different.", + ) # Fuse and unfuse should lead to the same results self.assertFalse( @@ -960,6 +972,38 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self): "output with no lora and output with lora disabled should give same results", ) + def test_wrong_adapter_name_raises_error(self): + scheduler_cls = self.scheduler_classes[0] + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + with self.assertRaises(ValueError) as err_context: + pipe.set_adapters("test") + + self.assertTrue("not in the list of present adapters" in str(err_context.exception)) + + # test this works. + pipe.set_adapters("adapter-1") + _ = pipe(**inputs, generator=torch.manual_seed(0))[0] + def test_simple_inference_with_text_denoiser_block_scale(self): """ Tests a simple inference with lora attached to text encoder and unet, attaches