
Commit db71a80

a-r-r-o-w authored and patrickvonplaten committed
Addition of new callbacks to controlnets (huggingface#5812)
* add new callbacks to src/diffusers/pipelines/controlnet/pipeline_controlnet.py
* update callbacks
* fix repeated kwarg
* update

Co-authored-by: Patrick von Platen <[email protected]>
1 parent aadbce5 commit db71a80

File tree: 6 files changed, +510 −133 lines changed

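For context, a minimal sketch of how the per-step callback added in this commit might be used with StableDiffusionControlNetPipeline. The checkpoints and the placeholder conditioning image are illustrative only and are not part of this change:

import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

# Placeholder conditioning image purely for illustration; a real canny/edge map
# prepared elsewhere would be used in practice.
canny_image = Image.new("RGB", (512, 512))

def log_latents(pipe, step, timestep, callback_kwargs):
    # Receives the tensors named in `callback_on_step_end_tensor_inputs` and
    # must return a dict; returned entries replace the loop variables.
    print(f"step {step}: latents shape {tuple(callback_kwargs['latents'].shape)}")
    return callback_kwargs

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

image = pipe(
    "a photo of a cat",
    image=canny_image,
    callback_on_step_end=log_latents,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]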

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 69 additions & 9 deletions
@@ -130,6 +130,7 @@ class StableDiffusionControlNetPipeline(
     model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor"]
     _exclude_from_cpu_offload = ["safety_checker"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

     def __init__(
         self,
@@ -485,15 +486,21 @@ def check_inputs(
         controlnet_conditioning_scale=1.0,
         control_guidance_start=0.0,
         control_guidance_end=1.0,
+        callback_on_step_end_tensor_inputs=None,
     ):
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )

+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -760,13 +767,25 @@ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32
     def guidance_scale(self):
         return self._guidance_scale

+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
     # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
     # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
     # corresponds to doing no classifier free guidance.
     @property
     def do_classifier_free_guidance(self):
         return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -786,14 +805,15 @@ def __call__(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
         clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -868,6 +888,15 @@ def __call__(
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that calls at the end of each denoising steps during the inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeine class.

         Examples:

@@ -878,6 +907,23 @@ def __call__(
                 second element is a list of `bool`s indicating whether the corresponding generated image contains
                 "not-safe-for-work" (nsfw) content.
         """
+
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

         # align format for control guidance
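As an aside, since the deprecated `callback`/`callback_steps` arguments are now only drained from `**kwargs` (see the hunk above), an existing step callback can be adapted to the new hook. A hedged sketch of such a shim; the helper name is made up for illustration:

def make_legacy_step_callback(old_callback, callback_steps=1):
    """Adapt an old-style `callback(i, t, latents)` to `callback_on_step_end`."""
    def on_step_end(pipeline, step, timestep, callback_kwargs):
        if step % callback_steps == 0:
            old_callback(step, timestep, callback_kwargs["latents"])
        return callback_kwargs  # unchanged tensors are simply handed back
    return on_step_end

# Usage sketch:
# pipe(..., callback_on_step_end=make_legacy_step_callback(my_callback, 5),
#      callback_on_step_end_tensor_inputs=["latents"])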
@@ -903,9 +949,12 @@ def __call__(
             controlnet_conditioning_scale,
             control_guidance_start,
             control_guidance_end,
+            callback_on_step_end_tensor_inputs,
         )

         self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -929,7 +978,7 @@ def __call__(

         # 3. Encode input prompt
         text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
@@ -940,7 +989,7 @@ def __call__(
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=text_encoder_lora_scale,
-            clip_skip=clip_skip,
+            clip_skip=self.clip_skip,
         )
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
@@ -988,6 +1037,7 @@ def __call__(
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
+        self._num_timesteps = len(timesteps)

         # 6. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
@@ -1078,7 +1128,7 @@ def __call__(
                     t,
                     encoder_hidden_states=prompt_embeds,
                     timestep_cond=timestep_cond,
-                    cross_attention_kwargs=cross_attention_kwargs,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
                     return_dict=False,
@@ -1087,11 +1137,21 @@ def __call__(
                 # perform guidance
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
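Because the loop pops the returned entries back into `latents`, `prompt_embeds`, and `negative_prompt_embeds`, a callback can steer the remaining denoising steps. A minimal, hedged sketch; the step threshold and noise scale are arbitrary:

import torch

def perturb_latents(pipeline, step, timestep, callback_kwargs):
    # After a chosen step, nudge the latents; the modified tensor is what the
    # scheduler receives on the next iteration via the `pop` calls above.
    if step == 10:
        latents = callback_kwargs["latents"]
        callback_kwargs["latents"] = latents + 0.05 * torch.randn_like(latents)
    return callback_kwargs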
