
Commit 7b71e6a

Authored by takuma104, williamberman, pcuenca, dqueue, and sayakpaul
Add a ControlNet model & pipeline (huggingface#2407)
* add scaffold: copied convert_controlnet_to_diffusers.py from convert_original_stable_diffusion_to_diffusers.py
* Add support to load ControlNet (WIP): this raises a missing-key error on ControlNetModel
* Update to convert ControlNet without an error message: initial impl for StableDiffusionControlNetPipeline and ControlNetModel
* clean up commented-out code
* split create_controlnet_diffusers_config() from create_unet_diffusers_config(); add config: hint_channels
* Add input_hint_block, input_zero_conv and middle_block_out: this raises a missing-key error on model loading
* add unet_2d_blocks_controlnet.py: copied from unet_2d_blocks.py to implement CrossAttnDownBlock2D, DownBlock2D; this raises a missing-key error on model loading
* Add loading for input_hint_block, zero_convs and middle_block_out: model now loads with no error message
* Copy from UNet2DConditionModel except __init__
* Add ultra-primitive test for ControlNetModel inference
* Support ControlNetModel inference without exceptions
* copy forward() from UNet2DConditionModel
* Implement ControlledUNet2DConditionModel inference: test_controlled_unet_inference passed
* Freeze weights & biases for training
* Minimized version of ControlNet/ControlledUnet: test_modules_controllnet.py passed
* make style
* Add model-loading support for the minimized version
* Remove all previous-version files
* from_pretrained and inference test passed
* copied from pipeline_stable_diffusion.py except `__init__()`
* Implement pipeline; pixel-match test (almost) passed
* make style
* make fix-copies
* Fix: import ControlNet blocks for `make fix-copies`
* Remove einops dependency
* Support np.ndarray and PIL.Image for controlnet_hint
* set lllyasviel's as the default config file
* Add support for grayscale (h, w) numpy arrays
* Add and update docstrings
* add control_net.mdx
* add control_net.mdx to toctree
* Update copyright year
* Fix: add PIL.Image RGB->BGR conversion - thanks @Mystfit
* make fix-copies
* add basic fast test for controlnet
* add slow test for controlnet/unet
* Ignore down/up_block length check on ControlNet
* add a copy from test_stable_diffusion.py
* Accept controlnet_hint being None
* merge pipeline_stable_diffusion.py diff
* Update class name to SDControlNetPipeline
* make style
* Baseline fast tests almost passed (with long descriptions)
* Still needs investigation. The following didn't pass, as described in the TODO comment: test_stable_diffusion_long_prompt, test_stable_diffusion_no_safety_checker. The following didn't pass, same as stable_diffusion_pipeline: test_attention_slicing_forward_pass, test_inference_batch_single_identical, test_xformers_attention_forwardGenerator_pass. These seem to come from calculation accuracy.
* Add note comment related to vae_scale_factor
* add test_stable_diffusion_controlnet_ddim
* add assertion for vae_scale_factor != 8
* slow tests of pipeline almost passed. Failed: test_stable_diffusion_pipeline_with_model_offloading - ImportError: `enable_model_offload` requires `accelerate v0.17.0` or higher, but the current latest version == 0.16.0
* test_stable_diffusion_long_prompt passed
* test_stable_diffusion_no_safety_checker passed; due to its model size, moved to slow tests
* remove PoC test files
* fix num_of_image / prompt-length issue and add test
* add support for List[PIL.Image] for controlnet_hint
* wip
* all slow tests passed
* make style
* update for slow tests
* RGB (PIL) -> BGR (controlnet) conversion
* fixes
* remove manual num_images_per_prompt test
* add documentation
* add `image` argument docstring
* make style
* Add line to correct conversion
* add controlnet_conditioning_scale (aka control_scales strength)
* rgb channel ordering by default
* image batching logic
* Add control-image descriptions for each checkpoint
* Only save the controlnet model in the conversion script
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py typo (Co-authored-by: Pedro Cuenca <[email protected]>)
* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx (nine review suggestions, Co-authored-by: Pedro Cuenca <[email protected]>)
* add generated image example
* a depth mask -> a depth map
* rename control_net.mdx to controlnet.mdx
* fix toc title
* add ControlNet abstract and link
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py (Co-authored-by: dqueue <[email protected]>)
* remove controlnet constructor arguments re: @patrickvonplaten
* [integration tests] test canny
* test_canny fixes
* [integration tests] test_depth
* [integration tests] test_hed
* [integration tests] test_mlsd
* add channel-order config to controlnet
* [integration tests] test normal
* [integration tests] test_openpose, test_scribble
* change height and width to default to the conditioning image
* [integration tests] test seg
* style
* test_depth fix
* [integration tests] size fixes
* [integration tests] cpu offloading
* style
* generalize controlnet embedding
* fix conversion script
* Update docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx (four review suggestions, Co-authored-by: Sayak Paul <[email protected]>)
* Style adapted to the pix2pix documentation
* merge main by hand
* style
* [docs] controlling generation doc nits
* correct some things
* add: ControlNetModel to autodoc
* finish docs
* finish
* finish 2
* correct images
* finish controlnet
* Apply suggestions from code review (Co-authored-by: Pedro Cuenca <[email protected]>)
* uP
* upload model
* up
* up

---------

Co-authored-by: William Berman <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: dqueue <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
1 parent 7eaf521 commit 7b71e6a
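As a quick orientation before the diff: the commit log above mentions an `image` argument, `controlnet_conditioning_scale`, and height/width defaulting to the conditioning image. A minimal usage sketch of the new pipeline follows, assuming converted weights are available on the Hub; the model ids and the edge-map path below are placeholder assumptions, not taken from this commit.

import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

# Hypothetical hub id for a converted canny ControlNet checkpoint.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)

pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # assumed base model id
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

# The conditioning image (e.g. a canny edge map). Per the commit log,
# PIL.Image, np.ndarray, and List[PIL.Image] are accepted.
canny_image = Image.open("canny_edges.png")  # placeholder path

# height/width default to the conditioning image size per the commit log.
result = pipe(
    "bird, high quality",
    image=canny_image,
    num_inference_steps=20,
    controlnet_conditioning_scale=1.0,  # strength of the ControlNet residuals
).images[0]
result.save("controlnet_out.png")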

File tree

11 files changed: +1504 additions, −21 deletions

__init__.py

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,7 @@
 else:
     from .models import (
         AutoencoderKL,
+        ControlNetModel,
         ModelMixin,
         PriorTransformer,
         Transformer2DModel,
@@ -113,6 +114,7 @@
         PaintByExamplePipeline,
         SemanticStableDiffusionPipeline,
         StableDiffusionAttendAndExcitePipeline,
+        StableDiffusionControlNetPipeline,
         StableDiffusionDepth2ImgPipeline,
         StableDiffusionImageVariationPipeline,
         StableDiffusionImg2ImgPipeline,
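With both exports in place, the new classes resolve from the package root; a trivial smoke check, assuming this revision of diffusers is installed:

from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

print(ControlNetModel)                     # diffusers.models.controlnet.ControlNetModel
print(StableDiffusionControlNetPipeline)   # diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet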

models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@

 if is_torch_available():
     from .autoencoder_kl import AutoencoderKL
+    from .controlnet import ControlNetModel
     from .dual_transformer_2d import DualTransformer2DModel
     from .modeling_utils import ModelMixin
     from .prior_transformer import PriorTransformer

models/controlnet.py

Lines changed: 506 additions & 0 deletions
Large diffs are not rendered by default.
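The 506 added lines of models/controlnet.py are collapsed here, but the model's role can be read off the UNet and conversion changes below: it embeds the conditioning image through controlnet_cond_embedding, runs a trainable copy of the UNet encoder, and emits one zero-conv residual per down block (controlnet_down_blocks) plus one mid-block residual (controlnet_mid_block). A minimal sketch of one controlled denoising step follows; the ControlNetModel forward signature and tuple return convention are assumptions inferred from the surrounding code, not copied from the collapsed file.

import torch
from diffusers import ControlNetModel, UNet2DConditionModel

def controlled_unet_step(
    unet: UNet2DConditionModel,
    controlnet: ControlNetModel,
    latents: torch.Tensor,
    timestep: torch.Tensor,
    text_emb: torch.Tensor,
    cond_image: torch.Tensor,
) -> torch.Tensor:
    # ControlNet sees the same latents/timestep/text embedding as the UNet,
    # plus the conditioning image, and emits one residual per down block and
    # one mid-block residual (argument names and return convention assumed).
    down_res, mid_res = controlnet(
        latents,
        timestep,
        encoder_hidden_states=text_emb,
        controlnet_cond=cond_image,
        return_dict=False,
    )
    # The UNet consumes the residuals through the two new forward arguments
    # added in the unet_2d_condition.py diff below.
    return unet(
        latents,
        timestep,
        encoder_hidden_states=text_emb,
        down_block_additional_residuals=down_res,
        mid_block_additional_residual=mid_res,
    ).sample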

models/unet_2d_condition.py

Lines changed: 18 additions & 1 deletion
@@ -142,7 +142,7 @@ def __init__(
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
-        time_embedding_type: str = "positional",  # fourier, positional
+        time_embedding_type: str = "positional",
         timestep_post_act: Optional[str] = None,
         time_cond_proj_dim: Optional[int] = None,
         conv_in_kernel: int = 3,
@@ -492,6 +492,8 @@ def forward(
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+        mid_block_additional_residual: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[UNet2DConditionOutput, Tuple]:
         r"""
@@ -589,6 +591,17 @@

             down_block_res_samples += res_samples

+        if down_block_additional_residuals is not None:
+            new_down_block_res_samples = ()
+
+            for down_block_res_sample, down_block_additional_residual in zip(
+                down_block_res_samples, down_block_additional_residuals
+            ):
+                down_block_res_sample += down_block_additional_residual
+                new_down_block_res_samples += (down_block_res_sample,)
+
+            down_block_res_samples = new_down_block_res_samples
+
         # 4. mid
         if self.mid_block is not None:
             sample = self.mid_block(
@@ -599,6 +612,9 @@
                 cross_attention_kwargs=cross_attention_kwargs,
             )

+        if mid_block_additional_residual is not None:
+            sample += mid_block_additional_residual
+
         # 5. up
         for i, upsample_block in enumerate(self.up_blocks):
             is_final_block = i == len(self.up_blocks) - 1
@@ -625,6 +641,7 @@
                 sample = upsample_block(
                     hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                 )
+
         # 6. post-process
         if self.conv_norm_out:
             sample = self.conv_norm_out(sample)
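The merge added above is a plain elementwise sum, one ControlNet residual per UNet skip connection, rebuilt into a fresh tuple. A toy, self-contained illustration of the same logic (shapes are arbitrary stand-ins):

import torch

# Stand-ins for the UNet skip connections and the ControlNet residuals.
down_block_res_samples = (torch.ones(1, 320, 64, 64), torch.ones(1, 320, 32, 32))
down_block_additional_residuals = (
    torch.full((1, 320, 64, 64), 0.5),
    torch.full((1, 320, 32, 32), 0.5),
)

# Same merge as in the diff: pairwise elementwise addition, collected
# back into a tuple that replaces the original skip connections.
new_down_block_res_samples = ()
for res_sample, additional in zip(down_block_res_samples, down_block_additional_residuals):
    new_down_block_res_samples += (res_sample + additional,)

assert all(torch.allclose(t, torch.full_like(t, 1.5)) for t in new_down_block_res_samples)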

pipelines/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@
     from .stable_diffusion import (
         CycleDiffusionPipeline,
         StableDiffusionAttendAndExcitePipeline,
+        StableDiffusionControlNetPipeline,
         StableDiffusionDepth2ImgPipeline,
         StableDiffusionImageVariationPipeline,
         StableDiffusionImg2ImgPipeline,

pipelines/stable_diffusion/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ class StableDiffusionPipelineOutput(BaseOutput):
 from .pipeline_cycle_diffusion import CycleDiffusionPipeline
 from .pipeline_stable_diffusion import StableDiffusionPipeline
 from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
+from .pipeline_stable_diffusion_controlnet import StableDiffusionControlNetPipeline
 from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
 from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
 from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy

pipelines/stable_diffusion/convert_from_ckpt.py

Lines changed: 107 additions & 19 deletions
@@ -34,6 +34,7 @@

 from diffusers import (
     AutoencoderKL,
+    ControlNetModel,
     DDIMScheduler,
     DDPMScheduler,
     DPMSolverMultistepScheduler,
@@ -44,6 +45,7 @@
     LMSDiscreteScheduler,
     PNDMScheduler,
     PriorTransformer,
+    StableDiffusionControlNetPipeline,
     StableDiffusionPipeline,
     StableUnCLIPImg2ImgPipeline,
     StableUnCLIPPipeline,
@@ -224,11 +226,15 @@ def conv_attn_to_linear(checkpoint):
         checkpoint[key] = checkpoint[key][:, :, 0]


-def create_unet_diffusers_config(original_config, image_size: int):
+def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
     """
     Creates a config for the diffusers based on the config of the LDM model.
     """
-    unet_params = original_config.model.params.unet_config.params
+    if controlnet:
+        unet_params = original_config.model.params.control_stage_config.params
+    else:
+        unet_params = original_config.model.params.unet_config.params
+
     vae_params = original_config.model.params.first_stage_config.params.ddconfig

     block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
@@ -272,9 +278,7 @@
     config = dict(
         sample_size=image_size // vae_scale_factor,
         in_channels=unet_params.in_channels,
-        out_channels=unet_params.out_channels,
         down_block_types=tuple(down_block_types),
-        up_block_types=tuple(up_block_types),
         block_out_channels=tuple(block_out_channels),
         layers_per_block=unet_params.num_res_blocks,
         cross_attention_dim=unet_params.context_dim,
@@ -284,6 +288,10 @@
         projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
     )

+    if not controlnet:
+        config["out_channels"] = unet_params.out_channels
+        config["up_block_types"] = tuple(up_block_types)
+
     return config


@@ -331,7 +339,7 @@ def create_ldm_bert_config(original_config):
     return config


-def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
+def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
     """
     Takes a state dict and a config, and returns a converted checkpoint.
     """
@@ -340,7 +348,11 @@
     unet_state_dict = {}
     keys = list(checkpoint.keys())

-    unet_key = "model.diffusion_model."
+    if controlnet:
+        unet_key = "control_model."
+    else:
+        unet_key = "model.diffusion_model."
+
     # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
     if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
         print(f"Checkpoint {path} has both EMA and non-EMA weights.")
@@ -384,10 +396,11 @@
     new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
     new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

-    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
-    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
-    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
-    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+    if not controlnet:
+        new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+        new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+        new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+        new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

     # Retrieves the keys for the input blocks only
     num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
@@ -512,6 +525,48 @@

                 new_checkpoint[new_path] = unet_state_dict[old_path]

+    if controlnet:
+        # conditioning embedding
+
+        orig_index = 0
+
+        new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
+            f"input_hint_block.{orig_index}.weight"
+        )
+        new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
+            f"input_hint_block.{orig_index}.bias"
+        )
+
+        orig_index += 2
+
+        diffusers_index = 0
+
+        while diffusers_index < 6:
+            new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
+                f"input_hint_block.{orig_index}.weight"
+            )
+            new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
+                f"input_hint_block.{orig_index}.bias"
+            )
+            diffusers_index += 1
+            orig_index += 2
+
+        new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
+            f"input_hint_block.{orig_index}.weight"
+        )
+        new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
+            f"input_hint_block.{orig_index}.bias"
+        )
+
+        # down blocks
+        for i in range(num_input_blocks):
+            new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
+            new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
+
+        # mid block
+        new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
+        new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
+
     return new_checkpoint


@@ -912,6 +967,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     stable_unclip: Optional[str] = None,
     stable_unclip_prior: Optional[str] = None,
     clip_stats_path: Optional[str] = None,
+    controlnet: Optional[bool] = None,
 ) -> StableDiffusionPipeline:
     """
     Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@@ -1093,6 +1149,12 @@
         model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
         logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")

+    if controlnet is None:
+        controlnet = "control_stage_config" in original_config.model.params
+
+    if controlnet and model_type != "FrozenCLIPEmbedder":
+        raise ValueError("`controlnet`=True only supports `model_type`='FrozenCLIPEmbedder'")
+
     if model_type == "FrozenOpenCLIPEmbedder":
         text_model = convert_open_clip_checkpoint(checkpoint)
         tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
@@ -1180,15 +1242,41 @@
         tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
         safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
         feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
-        pipe = StableDiffusionPipeline(
-            vae=vae,
-            text_encoder=text_model,
-            tokenizer=tokenizer,
-            unet=unet,
-            scheduler=scheduler,
-            safety_checker=safety_checker,
-            feature_extractor=feature_extractor,
-        )
+
+        if controlnet:
+            # Convert the ControlNetModel model.
+            ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
+            ctrlnet_config["upcast_attention"] = upcast_attention
+
+            ctrlnet_config.pop("sample_size")
+
+            controlnet_model = ControlNetModel(**ctrlnet_config)
+
+            converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
+                checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
+            )
+            controlnet_model.load_state_dict(converted_ctrl_checkpoint)
+
+            pipe = StableDiffusionControlNetPipeline(
+                vae=vae,
+                text_encoder=text_model,
+                tokenizer=tokenizer,
+                unet=unet,
+                controlnet=controlnet_model,
+                scheduler=scheduler,
+                safety_checker=safety_checker,
+                feature_extractor=feature_extractor,
+            )
+        else:
+            pipe = StableDiffusionPipeline(
+                vae=vae,
+                text_encoder=text_model,
+                tokenizer=tokenizer,
+                unet=unet,
+                scheduler=scheduler,
+                safety_checker=safety_checker,
+                feature_extractor=feature_extractor,
+            )
     else:
         text_config = create_ldm_bert_config(original_config)
         text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
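Putting the new pieces together, converting an original ControlNet checkpoint could look like the sketch below. The file names are placeholders; `controlnet=True` may be omitted when the yaml contains a `control_stage_config`, thanks to the auto-detection added above. The final line mirrors the commit-log note that the conversion script saves only the ControlNet model.

from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    load_pipeline_from_original_stable_diffusion_ckpt,
)

# Build the full pipeline from a CompVis-style checkpoint plus its yaml config.
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
    checkpoint_path="control_sd15_canny.pth",   # placeholder checkpoint file
    original_config_file="cldm_v15.yaml",       # placeholder LDM/ControlNet yaml
    controlnet=True,                            # or None to auto-detect
)

# Per the commit log ("Only save controlnet model in conversion script"),
# persist just the ControlNet weights in diffusers format.
pipe.controlnet.save_pretrained("./sd-controlnet-canny")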
