Commit 93f1a14

affromero and andres authored
ControlNet+Adapter pipeline, and ControlNet+Adapter+Inpaint pipeline (#5869)
* ControlNet+Adapter pipeline, and +Inpaint pipeline --------- Co-authored-by: andres <[email protected]>
1 parent 13d73d9 commit 93f1a14

File tree

3 files changed: +3497 −0 lines changed

examples/community/README.md

Lines changed: 138 additions & 0 deletions
@@ -2343,3 +2343,141 @@ images = pipe(

### ControlNet + T2I Adapter Pipeline

This pipeline combines both ControlNet and T2IAdapter into a single pipeline, where the forward pass is executed once.
It receives `control_image` and `adapter_image`, as well as `controlnet_conditioning_scale` and `adapter_conditioning_scale`, for the ControlNet and Adapter modules, respectively. Whenever `adapter_conditioning_scale = 0` or `controlnet_conditioning_scale = 0`, it acts as a full ControlNet module or as a full T2IAdapter module, respectively.

```py
import torch
from controlnet_aux.midas import MidasDetector

from diffusers import AutoencoderKL, ControlNetModel, T2IAdapter
from diffusers.utils import load_image
from examples.community.pipeline_stable_diffusion_xl_controlnet_adapter import (
    StableDiffusionXLControlNetAdapterPipeline,
)

# Depth ControlNet and depth T2I-Adapter, both trained for SDXL.
controlnet_depth = ControlNetModel.from_pretrained(
    "diffusers/controlnet-depth-sdxl-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
adapter_depth = T2IAdapter.from_pretrained(
    "TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
# fp16-safe VAE to avoid numerical overflow in the stock SDXL VAE.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)

pipe = StableDiffusionXLControlNetAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet_depth,
    adapter=adapter_depth,
    vae=vae,
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention()
# pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)

# MiDaS depth estimator that produces the conditioning image.
midas_depth = MidasDetector.from_pretrained(
    "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
).to("cuda")

prompt = "a tiger sitting on a park bench"
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"

image = load_image(img_url).resize((1024, 1024))

depth_image = midas_depth(
    image, detect_resolution=512, image_resolution=1024
)

strength = 0.5

# The same depth map conditions both modules; the two scales weight their influence.
images = pipe(
    prompt,
    control_image=depth_image,
    adapter_image=depth_image,
    num_inference_steps=30,
    controlnet_conditioning_scale=strength,
    adapter_conditioning_scale=strength,
).images
images[0].save("controlnet_and_adapter.png")
```
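
As noted above, zeroing one of the two scales reduces the combined pipeline to its counterpart. A minimal sketch reusing the objects from the example above (the output filename is illustrative):

```py
# With the adapter zeroed out, the combined pipeline behaves as a plain
# ControlNet pipeline (and symmetrically with controlnet_conditioning_scale=0).
controlnet_only = pipe(
    prompt,
    control_image=depth_image,
    adapter_image=depth_image,
    num_inference_steps=30,
    controlnet_conditioning_scale=1.0,
    adapter_conditioning_scale=0.0,
).images[0]
controlnet_only.save("controlnet_only.png")
```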
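
diffusers also provides the `MultiControlNetModel` and `MultiAdapter` wrappers for stacking several conditionings; whether this community pipeline accepts them with list-valued images and scales, as the stock SDXL pipelines do, is an assumption here worth verifying against the pipeline source. A hedged sketch with hypothetical canny stand-ins (`controlnet_canny`, `adapter_canny`, `canny_image`):

```py
from diffusers import MultiAdapter
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel

# Hypothetical second conditioning: controlnet_canny, adapter_canny, and
# canny_image are placeholders, not models shipped with this commit.
multi_controlnet = MultiControlNetModel([controlnet_depth, controlnet_canny])
multi_adapter = MultiAdapter([adapter_depth, adapter_canny])

pipe = StableDiffusionXLControlNetAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=multi_controlnet,
    adapter=multi_adapter,
    vae=vae,
    torch_dtype=torch.float16,
).to("cuda")

images = pipe(
    prompt,
    control_image=[depth_image, canny_image],    # one image per ControlNet (assumed)
    adapter_image=[depth_image, canny_image],    # one image per adapter (assumed)
    num_inference_steps=30,
    controlnet_conditioning_scale=[0.5, 0.5],    # per-model weights (assumed)
    adapter_conditioning_scale=[0.5, 0.5],
).images
```
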
### ControlNet + T2I Adapter + Inpainting Pipeline

This variant additionally takes `image` and `mask_image`, plus an inpainting `strength` argument that controls how much of the masked region is re-noised; the ControlNet and Adapter inputs behave as in the pipeline above.

```py
import torch
from controlnet_aux.midas import MidasDetector

from diffusers import AutoencoderKL, ControlNetModel, T2IAdapter
from diffusers.utils import load_image
from examples.community.pipeline_stable_diffusion_xl_controlnet_adapter_inpaint import (
    StableDiffusionXLControlNetAdapterInpaintPipeline,
)

controlnet_depth = ControlNetModel.from_pretrained(
    "diffusers/controlnet-depth-sdxl-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
adapter_depth = T2IAdapter.from_pretrained(
    "TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)

# Note the SDXL inpainting checkpoint rather than the text-to-image base model.
pipe = StableDiffusionXLControlNetAdapterInpaintPipeline.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    controlnet=controlnet_depth,
    adapter=adapter_depth,
    vae=vae,
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention()
# pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)

midas_depth = MidasDetector.from_pretrained(
    "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
).to("cuda")

prompt = "a tiger sitting on a park bench"
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

image = load_image(img_url).resize((1024, 1024))
mask_image = load_image(mask_url).resize((1024, 1024))

depth_image = midas_depth(
    image, detect_resolution=512, image_resolution=1024
)

strength = 0.4

images = pipe(
    prompt,
    image=image,
    mask_image=mask_image,
    control_image=depth_image,
    adapter_image=depth_image,
    num_inference_steps=30,
    controlnet_conditioning_scale=strength,
    adapter_conditioning_scale=strength,
    strength=0.7,  # inpainting strength, distinct from the conditioning scales
).images
images[0].save("controlnet_and_adapter_inpaint.png")
```
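
Two different strengths appear in this example: the shared conditioning scale (0.4) weights the ControlNet and Adapter guidance, while the inpainting `strength` (0.7) controls how much of the masked region is re-noised. A hedged variant with the same call signature, using illustrative values to repaint the masked area from scratch under stronger depth guidance:

```py
# Illustrative values only; the call signature matches the example above.
images = pipe(
    prompt,
    image=image,
    mask_image=mask_image,
    control_image=depth_image,
    adapter_image=depth_image,
    num_inference_steps=30,
    controlnet_conditioning_scale=0.8,  # stronger depth guidance
    adapter_conditioning_scale=0.8,
    strength=1.0,  # fully re-noise the masked region
).images
images[0].save("controlnet_and_adapter_inpaint_strong.png")
```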
