diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py index bc612edbc20e..ca2a1521d3d5 100644 --- a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py +++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py @@ -1470,7 +1470,15 @@ def __call__( height, width = self._default_height_width(height, width, adapter_image) device = self._execution_device - adapter_input = _preprocess_adapter_image(adapter_image, height, width).to(device) + if isinstance(adapter, MultiAdapter): + adapter_input = [] + for one_image in adapter_image: + one_image = _preprocess_adapter_image(one_image, height, width) + one_image = one_image.to(device=device, dtype=adapter.dtype) + adapter_input.append(one_image) + else: + adapter_input = _preprocess_adapter_image(adapter_image, height, width) + adapter_input = adapter_input.to(device=device, dtype=adapter.dtype) original_size = original_size or (height, width) target_size = target_size or (height, width) @@ -1643,10 +1651,14 @@ def denoising_value_valid(dnv): extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 10. Prepare added time ids & embeddings & adapter features - adapter_input = adapter_input.type(latents.dtype) - adapter_state = adapter(adapter_input) - for k, v in enumerate(adapter_state): - adapter_state[k] = v * adapter_conditioning_scale + if isinstance(adapter, MultiAdapter): + adapter_state = adapter(adapter_input, adapter_conditioning_scale) + for k, v in enumerate(adapter_state): + adapter_state[k] = v + else: + adapter_state = adapter(adapter_input) + for k, v in enumerate(adapter_state): + adapter_state[k] = v * adapter_conditioning_scale if num_images_per_prompt > 1: for k, v in enumerate(adapter_state): adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)