@@ -130,6 +130,7 @@ class StableDiffusionControlNetPipeline(
     model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor"]
     _exclude_from_cpu_offload = ["safety_checker"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
 
     def __init__(
         self,
@@ -485,15 +486,21 @@ def check_inputs(
         controlnet_conditioning_scale=1.0,
         control_guidance_start=0.0,
         control_guidance_end=1.0,
+        callback_on_step_end_tensor_inputs=None,
     ):
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
 
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
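The new `callback_on_step_end_tensor_inputs` validation fails fast at call time. A minimal sketch of how it surfaces to callers (the `pipe` instance and the `canny_image` conditioning image are assumed to exist already; they are not part of this diff):

```python
# Requesting a tensor name that is not declared in `_callback_tensor_inputs`
# raises in `check_inputs`, before any encoding or denoising work is done.
try:
    pipe(
        "a photo of a cat",
        image=canny_image,
        callback_on_step_end_tensor_inputs=["noise_pred"],  # not an allowed name
    )
except ValueError as err:
    print(err)  # message lists the allowed names and the offending entries
```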
@@ -760,13 +767,25 @@ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32
     def guidance_scale(self):
         return self._guidance_scale
 
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
     # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
     # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
     # corresponds to doing no classifier free guidance.
     @property
     def do_classifier_free_guidance(self):
         return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
 
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -786,14 +805,15 @@ def __call__(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
         clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -868,6 +888,15 @@ def __call__(
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
+                the `._callback_tensor_inputs` attribute of your pipeline class.
 
         Examples:
 
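To make the documented callback signature concrete, here is a minimal sketch of a compatible callback. The pipeline passes itself as the first argument and reads any returned tensors back (see the denoising-loop change further down); the function name and the logging are illustrative, not part of the API:

```python
from typing import Any, Dict


def log_step_callback(pipe, step: int, timestep: int, callback_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Only the tensors requested via `callback_on_step_end_tensor_inputs` are present,
    # e.g. {"latents": <torch.FloatTensor>} with the default ["latents"].
    latents = callback_kwargs["latents"]
    print(f"step {step} (t={int(timestep)}): latents shape {tuple(latents.shape)}")
    # Return the dict (with possibly modified tensors) so the pipeline can pick the values back up.
    return callback_kwargs
```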
@@ -878,6 +907,23 @@ def __call__(
                 second element is a list of `bool`s indicating whether the corresponding generated image contains
                 "not-safe-for-work" (nsfw) content.
         """
+
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
 
         # align format for control guidance
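For callers migrating off the deprecated arguments, the mapping is roughly as in the sketch below (pipeline setup, `prompt`, and `canny_image` are assumed; the no-op callback is only a placeholder):

```python
# Old style: still works but emits a deprecation warning, slated for removal in 1.0.0.
# image = pipe(prompt, image=canny_image, callback=old_callback, callback_steps=1).images[0]

# New style: the callback runs after every scheduler step and may return updated tensors.
def noop_callback(pipe, step, timestep, callback_kwargs):
    return callback_kwargs


image = pipe(
    prompt,
    image=canny_image,
    callback_on_step_end=noop_callback,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
```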
@@ -903,9 +949,12 @@ def __call__(
             controlnet_conditioning_scale,
             control_guidance_start,
             control_guidance_end,
+            callback_on_step_end_tensor_inputs,
         )
 
         self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -929,7 +978,7 @@ def __call__(
 
         # 3. Encode input prompt
         text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
@@ -940,7 +989,7 @@ def __call__(
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=text_encoder_lora_scale,
-            clip_skip=clip_skip,
+            clip_skip=self.clip_skip,
         )
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
@@ -988,6 +1037,7 @@ def __call__(
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
+        self._num_timesteps = len(timesteps)
 
         # 6. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
@@ -1078,7 +1128,7 @@ def __call__(
                     t,
                     encoder_hidden_states=prompt_embeds,
                     timestep_cond=timestep_cond,
-                    cross_attention_kwargs=cross_attention_kwargs,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
                     return_dict=False,
@@ -1087,11 +1137,21 @@ def __call__(
                 # perform guidance
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
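Putting the pieces together, the sketch below wires the new callback into a full ControlNet run. The checkpoints, the conditioning-image URL, and the callback behavior (adding a small amount of noise to the latents during the first steps) are illustrative assumptions, not part of this commit:

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

# Any compatible SD 1.5 + ControlNet pair should work here.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# Assumed to point at a pre-computed canny edge map.
canny_image = load_image("https://example.com/canny_edges.png")


def jitter_latents(pipe, step, timestep, callback_kwargs):
    # Runs after every scheduler step; only tensors requested via
    # `callback_on_step_end_tensor_inputs` show up in `callback_kwargs`.
    latents = callback_kwargs["latents"]
    if step < 3:
        latents = latents + 0.01 * torch.randn_like(latents)
    callback_kwargs["latents"] = latents  # returned tensors replace the pipeline's copies
    return callback_kwargs


image = pipe(
    "a futuristic city at dusk",
    image=canny_image,
    num_inference_steps=30,
    callback_on_step_end=jitter_latents,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
image.save("controlnet_callback_demo.png")
```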