5959"""
6060
6161
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    """
+    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
+
+
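For intuition, here is a minimal sketch of what the rescaling does, using the function as added above but with random tensors in place of real UNet predictions (the shapes and the 7.5 guidance scale are illustrative assumptions, not part of this change):

import torch

# stand-ins for the conditional / unconditional UNet outputs
noise_pred_text = torch.randn(2, 4, 64, 64)
noise_pred_uncond = torch.randn(2, 4, 64, 64)

# classifier-free guidance inflates the per-sample std of the combined prediction
noise_cfg = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)

# rescale_noise_cfg pulls it back toward the std of the text-conditioned branch,
# mixed by guidance_rescale (0.0 = unchanged, 1.0 = fully rescaled)
rescaled = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)
print(noise_cfg.std().item(), rescaled.std().item())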
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    **kwargs,
+):
+    """
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used,
+            `timesteps` must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+            must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and
+        the second element is the number of inference steps.
+    """
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
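A rough usage sketch of the helper (the scheduler and step values are illustrative assumptions):

from diffusers import DDIMScheduler

scheduler = DDIMScheduler()  # default config, purely for illustration

# default spacing: the scheduler lays out the schedule itself
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=50, device="cpu")
assert num_inference_steps == 50 and len(timesteps) == 50

# custom spacing would instead be requested via `timesteps=[...]`; schedulers whose
# `set_timesteps` lacks a `timesteps` parameter (DDIM among them) raise a ValueError
# there instead of silently ignoring the request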
 @dataclass
 class LDM3DPipelineOutput(BaseOutput):
     """
@@ -125,6 +185,7 @@ class StableDiffusionLDM3DPipeline(
     model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

     def __init__(
         self,
@@ -582,6 +643,66 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
         latents = latents * self.scheduler.init_noise_sigma
         return latents

+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                Guidance scale values at which to generate embedding vectors.
+            embedding_dim (`int`, *optional*, defaults to 512):
+                Dimension of the embeddings to generate.
+            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+                Data type of the generated embeddings.
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
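A quick shape check, assuming `pipe` is an instantiated pipeline (batch size and embedding dimension below are illustrative):

import torch

# one guidance value per sample; __call__ below passes w = guidance_scale - 1
w = torch.tensor([6.5, 6.5])

emb = pipe.get_guidance_scale_embedding(w, embedding_dim=256)
assert emb.shape == (2, 256)
# the first half holds sin terms and the second half cos terms of w * 1000 at
# log-spaced frequencies, i.e. a standard sinusoidal timestep-style embedding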
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
+
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
+    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier-free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
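Note the second clause of `do_classifier_free_guidance`: a UNet with `time_cond_proj_dim` set is a guidance-distilled (LCM-style) model, so classifier-free guidance is skipped and the scale is injected through `timestep_cond` instead. A small sketch of the two paths, assuming `pipe` is an instantiated pipeline:

pipe._guidance_scale = 7.5

if pipe.unet.config.time_cond_proj_dim is None:
    # ordinary UNet: two forward passes plus CFG mixing in the denoising loop
    assert pipe.do_classifier_free_guidance
else:
    # guidance-distilled UNet: CFG is off; see get_guidance_scale_embedding above
    assert not pipe.do_classifier_free_guidance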
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -590,6 +711,7 @@ def __call__(
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 49,
+        timesteps: List[int] = None,
         guidance_scale: float = 5.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
@@ -602,10 +724,12 @@ def __call__(
         ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guidance_rescale: float = 0.0,
         clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -656,18 +780,21 @@ def __call__(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                 plain tuple.
-            callback (`Callable`, *optional*):
-                A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
-            callback_steps (`int`, *optional*, defaults to 1):
-                The frequency at which the `callback` function is called. If not specified, the callback is called at
-                every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                 [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
+                the `._callback_tensor_inputs` attribute of your pipeline class.
         Examples:

         Returns:
@@ -677,6 +804,22 @@ def __call__(
                 second element is a list of `bool`s indicating whether the corresponding generated image contains
                 "not-safe-for-work" (nsfw) content.
         """
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -692,8 +835,15 @@ def __call__(
             negative_prompt_embeds,
             ip_adapter_image,
             ip_adapter_image_embeds,
+            callback_on_step_end_tensor_inputs,
         )

+        self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -703,26 +853,22 @@ def __call__(
             batch_size = prompt_embeds.shape[0]

         device = self._execution_device
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0

         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
                 ip_adapter_image,
                 ip_adapter_image_embeds,
                 device,
                 batch_size * num_images_per_prompt,
-                do_classifier_free_guidance,
+                self.do_classifier_free_guidance,
             )

         # 3. Encode input prompt
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
             num_images_per_prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
@@ -731,12 +877,11 @@ def __call__(
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
         # to avoid doing two forward passes
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

         # 4. Prepare timesteps
-        self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.scheduler.timesteps
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)

         # 5. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
@@ -757,32 +902,59 @@ def __call__(
         # 6.1 Add image embeds for IP-Adapter
         added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

+        # 6.2 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
         # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                     cross_attention_kwargs=cross_attention_kwargs,
                     added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
                 )[0]

                 # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

+                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
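To round out the change, a hedged sketch of the new callback API in use (the checkpoint name and callback body are illustrative assumptions, not part of this commit):

import torch
from diffusers import StableDiffusionLDM3DPipeline

pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d-4c", torch_dtype=torch.float16).to("cuda")

def on_step_end(pipeline, step, timestep, callback_kwargs):
    # only tensors named in callback_on_step_end_tensor_inputs arrive here;
    # returned entries are popped back into the denoising loop above
    latents = callback_kwargs["latents"]
    if step == 10:
        pipeline._interrupt = True  # early stop via the new `interrupt` property
    return {"latents": latents}

output = pipe(
    "a photo of an astronaut",
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["latents"],
)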