import copy
+ import inspect
from dataclasses import dataclass
from typing import Callable, List, Optional, Union

from torch.nn.functional import grid_sample
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

- from diffusers.models import AutoencoderKL, UNet2DConditionModel
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionSafetyChecker
- from diffusers.schedulers import KarrasDiffusionSchedulers
- from diffusers.utils import BaseOutput
- from diffusers.utils.torch_utils import randn_tensor
+ from ...image_processor import VaeImageProcessor
+ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+ from ...models import AutoencoderKL, UNet2DConditionModel
+ from ...models.lora import adjust_lora_scale_text_encoder
+ from ...schedulers import KarrasDiffusionSchedulers
+ from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+ from ...utils.torch_utils import randn_tensor
+ from ..pipeline_utils import DiffusionPipeline
+ from ..stable_diffusion import StableDiffusionSafetyChecker
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def rearrange_0(tensor, f):
@@ -273,7 +281,7 @@ def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_s
    return warped_latents


- class TextToVideoZeroPipeline(StableDiffusionPipeline):
+ class TextToVideoZeroPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
    r"""
    Pipeline for zero-shot text-to-video generation using Stable Diffusion.

@@ -311,8 +319,15 @@ def __init__(
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
-         super().__init__(
-             vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+         super().__init__()
+         self.register_modules(
+             vae=vae,
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             unet=unet,
+             scheduler=scheduler,
+             safety_checker=safety_checker,
+             feature_extractor=feature_extractor,
        )
        processor = (
            CrossFrameAttnProcessor2_0(batch_size=2)
@@ -321,6 +336,18 @@ def __init__(
        )
        self.unet.set_attn_processor(processor)

+         if safety_checker is None and requires_safety_checker:
+             logger.warning(
+                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                 " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
+                 " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                 " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
+                 " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                 " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+             )
+         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
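A quick worked example of the `vae_scale_factor` computed above (a sketch, not part of the diff; the config values are assumptions matching a typical SD 1.x VAE):

block_out_channels = [128, 256, 512, 512]               # assumed SD VAE config
vae_scale_factor = 2 ** (len(block_out_channels) - 1)   # 2**3 == 8
# A 512x512 frame therefore maps to a 64x64 latent grid, which is why
# check_inputs below requires height and width to be divisible by 8.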
    def forward_loop(self, x_t0, t0, t1, generator):
        """
        Perform DDPM forward process from time t0 to t1. This is the same as adding noise with corresponding variance.
@@ -420,6 +447,77 @@ def backward_loop(
                        callback(step_idx, t, latents)
        return latents.clone().detach()

+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+     def check_inputs(
+         self,
+         prompt,
+         height,
+         width,
+         callback_steps,
+         negative_prompt=None,
+         prompt_embeds=None,
+         negative_prompt_embeds=None,
+         callback_on_step_end_tensor_inputs=None,
+     ):
+         if height % 8 != 0 or width % 8 != 0:
+             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+         if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+             raise ValueError(
+                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                 f" {type(callback_steps)}."
+             )
+         if callback_on_step_end_tensor_inputs is not None and not all(
+             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+         ):
+             raise ValueError(
+                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+             )
+
+         if prompt is not None and prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                 " only forward one of the two."
+             )
+         elif prompt is None and prompt_embeds is None:
+             raise ValueError(
+                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+             )
+         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+         if negative_prompt is not None and negative_prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+             )
+
+         if prompt_embeds is not None and negative_prompt_embeds is not None:
+             if prompt_embeds.shape != negative_prompt_embeds.shape:
+                 raise ValueError(
+                     "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                     f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                     f" {negative_prompt_embeds.shape}."
+                 )
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+         if isinstance(generator, list) and len(generator) != batch_size:
+             raise ValueError(
+                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+             )
+
+         if latents is None:
+             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+         else:
+             latents = latents.to(device)
+
+         # scale the initial noise by the standard deviation required by the scheduler
+         latents = latents * self.scheduler.init_noise_sigma
+         return latents
+
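For orientation, the latent shape this produces (a sketch with assumed values: 4 latent channels as in the SD UNet, a 512x512 output, scale factor 8):

batch_size, num_channels_latents = 1, 4
height = width = 512
vae_scale_factor = 8
shape = (batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)
assert shape == (1, 4, 64, 64)
# randn_tensor draws noise of this shape (one generator per batch item if a list is
# passed), and the noise is scaled by scheduler.init_noise_sigma before denoising starts.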
    @torch.no_grad()
    def __call__(
        self,
@@ -539,9 +637,10 @@ def __call__(
        do_classifier_free_guidance = guidance_scale > 1.0

        # Encode input prompt
-         prompt_embeds = self._encode_prompt(
+         prompt_embeds_tuple = self.encode_prompt(
            prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
        )
+         prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])

        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
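A minimal sketch of what the two changed lines above do under classifier-free guidance (names are illustrative; `encode_prompt`, added further below, returns the conditional and unconditional embeddings separately instead of pre-concatenated like the old `_encode_prompt`):

cond, uncond = pipe.encode_prompt(        # hypothetical call on an instantiated pipeline
    prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
)
prompt_embeds = torch.cat([uncond, cond])  # [negative, positive] order expected by the CFG split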
@@ -645,3 +744,226 @@ def __call__(
            return (image, has_nsfw_concept)

        return TextToVideoPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+     def run_safety_checker(self, image, device, dtype):
+         if self.safety_checker is None:
+             has_nsfw_concept = None
+         else:
+             if torch.is_tensor(image):
+                 feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+             else:
+                 feature_extractor_input = self.image_processor.numpy_to_pil(image)
+             safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+             image, has_nsfw_concept = self.safety_checker(
+                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+             )
+         return image, has_nsfw_concept
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+     def prepare_extra_step_kwargs(self, generator, eta):
+         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+         # and should be between [0, 1]
+
+         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+         extra_step_kwargs = {}
+         if accepts_eta:
+             extra_step_kwargs["eta"] = eta
+
+         # check if the scheduler accepts generator
+         accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+         if accepts_generator:
+             extra_step_kwargs["generator"] = generator
+         return extra_step_kwargs
+
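The signature-probing pattern used in `prepare_extra_step_kwargs`, shown in isolation (a generic sketch, not pipeline code):

import inspect

def filter_step_kwargs(step_fn, **candidates):
    # Keep only keyword arguments that step_fn actually declares, so schedulers
    # without `eta` or `generator` parameters are not passed unexpected kwargs.
    accepted = set(inspect.signature(step_fn).parameters)
    return {k: v for k, v in candidates.items() if k in accepted}

# e.g. filter_step_kwargs(scheduler.step, eta=0.0, generator=None) forwards `eta`
# only to schedulers (such as DDIM) whose step() accepts it.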
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+     def encode_prompt(
+         self,
+         prompt,
+         device,
+         num_images_per_prompt,
+         do_classifier_free_guidance,
+         negative_prompt=None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+         lora_scale: Optional[float] = None,
+         clip_skip: Optional[int] = None,
+     ):
+         r"""
+         Encodes the prompt into text encoder hidden states.
+
+         Args:
+             prompt (`str` or `List[str]`, *optional*):
+                 prompt to be encoded
+             device: (`torch.device`):
+                 torch device
+             num_images_per_prompt (`int`):
+                 number of images that should be generated per prompt
+             do_classifier_free_guidance (`bool`):
+                 whether to use classifier free guidance or not
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                 less than `1`).
+             prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                 provided, text embeddings will be generated from `prompt` input argument.
+             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                 argument.
+             lora_scale (`float`, *optional*):
+                 A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+             clip_skip (`int`, *optional*):
+                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                 the output of the pre-final layer will be used for computing the prompt embeddings.
+         """
+         # set lora scale so that monkey patched LoRA
+         # function of text encoder can correctly access it
+         if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+             self._lora_scale = lora_scale
+
+             # dynamically adjust the LoRA scale
+             if not USE_PEFT_BACKEND:
+                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+             else:
+                 scale_lora_layers(self.text_encoder, lora_scale)
+
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         if prompt_embeds is None:
+             # textual inversion: process multi-vector tokens if necessary
+             if isinstance(self, TextualInversionLoaderMixin):
+                 prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+             text_inputs = self.tokenizer(
+                 prompt,
+                 padding="max_length",
+                 max_length=self.tokenizer.model_max_length,
+                 truncation=True,
+                 return_tensors="pt",
+             )
+             text_input_ids = text_inputs.input_ids
+             untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+             if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                 text_input_ids, untruncated_ids
+             ):
+                 removed_text = self.tokenizer.batch_decode(
+                     untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                 )
+                 logger.warning(
+                     "The following part of your input was truncated because CLIP can only handle sequences up to"
+                     f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                 )
+
+             if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                 attention_mask = text_inputs.attention_mask.to(device)
+             else:
+                 attention_mask = None
+
+             if clip_skip is None:
+                 prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+                 prompt_embeds = prompt_embeds[0]
+             else:
+                 prompt_embeds = self.text_encoder(
+                     text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+                 )
+                 # Access the `hidden_states` first, that contains a tuple of
+                 # all the hidden states from the encoder layers. Then index into
+                 # the tuple to access the hidden states from the desired layer.
+                 prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+                 # We also need to apply the final LayerNorm here to not mess with the
+                 # representations. The `last_hidden_states` that we typically use for
+                 # obtaining the final prompt representations passes through the LayerNorm
+                 # layer.
+                 prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+         if self.text_encoder is not None:
+             prompt_embeds_dtype = self.text_encoder.dtype
+         elif self.unet is not None:
+             prompt_embeds_dtype = self.unet.dtype
+         else:
+             prompt_embeds_dtype = prompt_embeds.dtype
+
+         prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+         bs_embed, seq_len, _ = prompt_embeds.shape
+         # duplicate text embeddings for each generation per prompt, using mps friendly method
+         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+         # get unconditional embeddings for classifier free guidance
+         if do_classifier_free_guidance and negative_prompt_embeds is None:
+             uncond_tokens: List[str]
+             if negative_prompt is None:
+                 uncond_tokens = [""] * batch_size
+             elif prompt is not None and type(prompt) is not type(negative_prompt):
+                 raise TypeError(
+                     f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                     f" {type(prompt)}."
+                 )
+             elif isinstance(negative_prompt, str):
+                 uncond_tokens = [negative_prompt]
+             elif batch_size != len(negative_prompt):
+                 raise ValueError(
+                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                     f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                     " the batch size of `prompt`."
+                 )
+             else:
+                 uncond_tokens = negative_prompt
+
+             # textual inversion: process multi-vector tokens if necessary
+             if isinstance(self, TextualInversionLoaderMixin):
+                 uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+             max_length = prompt_embeds.shape[1]
+             uncond_input = self.tokenizer(
+                 uncond_tokens,
+                 padding="max_length",
+                 max_length=max_length,
+                 truncation=True,
+                 return_tensors="pt",
+             )
+
+             if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                 attention_mask = uncond_input.attention_mask.to(device)
+             else:
+                 attention_mask = None
+
+             negative_prompt_embeds = self.text_encoder(
+                 uncond_input.input_ids.to(device),
+                 attention_mask=attention_mask,
+             )
+             negative_prompt_embeds = negative_prompt_embeds[0]
+
+         if do_classifier_free_guidance:
+             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+             seq_len = negative_prompt_embeds.shape[1]
+
+             negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+         if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+             # Retrieve the original scale by scaling back the LoRA layers
+             unscale_lora_layers(self.text_encoder, lora_scale)
+
+         return prompt_embeds, negative_prompt_embeds
+
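One detail worth noting in `encode_prompt`: when `clip_skip` is set, the encoder is called with `output_hidden_states=True`, and `hidden_states[-(clip_skip + 1)]` selects the layer `clip_skip` steps before the last (so `clip_skip=1` uses the penultimate layer). `final_layer_norm` is then applied manually because the default `last_hidden_state` output has already passed through it.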
+     def decode_latents(self, latents):
+         latents = 1 / self.vae.config.scaling_factor * latents
+         image = self.vae.decode(latents, return_dict=False)[0]
+         image = (image / 2 + 0.5).clamp(0, 1)
+         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+         return image
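A minimal end-to-end usage sketch for the pipeline defined in this file (checkpoint name, dtype, and frame count are assumptions, not taken from the diff):

import torch
from diffusers import TextToVideoZeroPipeline

# Load Stable Diffusion weights into the zero-shot text-to-video pipeline (assumed checkpoint).
pipe = TextToVideoZeroPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

result = pipe(prompt="a panda playing guitar on a beach", video_length=8)
frames = result.images  # frames as numpy arrays in [0, 1]; scale to uint8 before writing a video file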