
Commit f2da6a2

patrickvonplaten authored and donhardman committed
[Text-to-Video] Clean up pipeline (huggingface#6213)
* make style
* make style
* make style
* make style
1 parent 66e96d5 commit f2da6a2

File tree

3 files changed: +800 −19 lines changed

src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py

Lines changed: 331 additions & 9 deletions
@@ -1,4 +1,5 @@
 import copy
+import inspect
 from dataclasses import dataclass
 from typing import Callable, List, Optional, Union

@@ -9,11 +10,18 @@
 from torch.nn.functional import grid_sample
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

-from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionSafetyChecker
-from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.utils import BaseOutput
-from diffusers.utils.torch_utils import randn_tensor
+from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models.lora import adjust_lora_scale_text_encoder
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from ..stable_diffusion import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


 def rearrange_0(tensor, f):
@@ -273,7 +281,7 @@ def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_s
     return warped_latents


-class TextToVideoZeroPipeline(StableDiffusionPipeline):
+class TextToVideoZeroPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
     r"""
     Pipeline for zero-shot text-to-video generation using Stable Diffusion.

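Note on the base-class change: `TextToVideoZeroPipeline` no longer subclasses `StableDiffusionPipeline`; it is assembled directly from `DiffusionPipeline` plus the loader mixins. A minimal sketch of what the mixins expose on the pipeline (the checkpoint id and adapter paths below are illustrative placeholders, not part of this commit):

```python
import torch
from diffusers import TextToVideoZeroPipeline

# Illustrative checkpoint; any Stable Diffusion 1.x-style checkpoint should work here.
pipe = TextToVideoZeroPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# TextualInversionLoaderMixin and LoraLoaderMixin add these loading hooks:
pipe.load_textual_inversion("sd-concepts-library/cat-toy")  # textual inversion embedding (placeholder repo)
pipe.load_lora_weights("path/to/lora_weights")              # LoRA adapter weights (placeholder path)
```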
@@ -311,8 +319,15 @@ def __init__(
         feature_extractor: CLIPImageProcessor,
         requires_safety_checker: bool = True,
     ):
-        super().__init__(
-            vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+        super().__init__()
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
         )
         processor = (
             CrossFrameAttnProcessor2_0(batch_size=2)
@@ -321,6 +336,18 @@ def __init__(
         )
         self.unet.set_attn_processor(processor)

+        if safety_checker is None and requires_safety_checker:
+            logger.warning(
+                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+            )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
     def forward_loop(self, x_t0, t0, t1, generator):
         """
         Perform DDPM forward process from time t0 to t1. This is the same as adding noise with corresponding variance.
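Because `StableDiffusionPipeline.__init__` is no longer inherited, the constructor now calls `register_modules` itself and reproduces the safety-checker warning, `vae_scale_factor`, and `image_processor` setup locally. A rough sketch of the observable effect at load time, assuming the usual SD 1.5 component layout (illustrative checkpoint id, VAE with four down blocks):

```python
import torch
from diffusers import TextToVideoZeroPipeline

# Passing safety_checker=None still loads, but now triggers the warning added above,
# exactly as StableDiffusionPipeline does.
pipe = TextToVideoZeroPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", safety_checker=None, torch_dtype=torch.float16
)

print(sorted(pipe.components))  # register_modules() exposes vae, unet, text_encoder, tokenizer, scheduler, ...
print(pipe.vae_scale_factor)    # 2 ** (len(vae.config.block_out_channels) - 1) -> 8 for SD 1.5 VAEs
```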
@@ -420,6 +447,77 @@ def backward_loop(
                 callback(step_idx, t, latents)
         return latents.clone().detach()

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+    def check_inputs(
+        self,
+        prompt,
+        height,
+        width,
+        callback_steps,
+        negative_prompt=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+    ):
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
     @torch.no_grad()
     def __call__(
         self,
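`check_inputs` and `prepare_latents` are copied verbatim from `StableDiffusionPipeline` rather than inherited. The latent shape arithmetic is worth spelling out; the numbers below are illustrative and assume an SD 1.5-style UNet (4 latent channels) together with the `vae_scale_factor` of 8 computed in `__init__`:

```python
# Illustrative values only: SD 1.5-style configuration.
batch_size, num_channels_latents = 1, 4
height, width, vae_scale_factor = 512, 512, 8

shape = (batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)
print(shape)  # (1, 4, 64, 64)

# randn_tensor(...) draws the initial noise, which is then multiplied by
# scheduler.init_noise_sigma (1.0 for DDIM/DDPM-style schedulers, larger for
# sigma-parameterized schedulers such as EulerDiscreteScheduler).
```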
@@ -539,9 +637,10 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0

         # Encode input prompt
-        prompt_embeds = self._encode_prompt(
+        prompt_embeds_tuple = self.encode_prompt(
             prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
         )
+        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])

         # Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
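`encode_prompt` returns a `(prompt_embeds, negative_prompt_embeds)` tuple instead of the pre-concatenated tensor the removed `_encode_prompt` produced, so the caller rebuilds the classifier-free-guidance batch with the unconditional embeddings first. A small sketch of the ordering, with placeholder CLIP-sized shapes (77 tokens, 768 dims):

```python
import torch

cond = torch.randn(1, 77, 768)    # prompt_embeds_tuple[0] -> conditional embeddings (illustrative shape)
uncond = torch.randn(1, 77, 768)  # prompt_embeds_tuple[1] -> negative / unconditional embeddings

# Unconditional first, conditional second: the order the later noise_pred.chunk(2)
# split in the denoising loop relies on.
prompt_embeds = torch.cat([uncond, cond])
print(prompt_embeds.shape)  # torch.Size([2, 77, 768])
```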
@@ -645,3 +744,226 @@ def __call__(
             return (image, has_nsfw_concept)

         return TextToVideoPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+    def run_safety_checker(self, image, device, dtype):
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
+        return image, has_nsfw_concept
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+    def encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            lora_scale (`float`, *optional*):
+                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+        """
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if not USE_PEFT_BACKEND:
+                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            else:
+                scale_lora_layers(self.text_encoder, lora_scale)
+
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            # textual inversion: procecss multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = text_inputs.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            if clip_skip is None:
+                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+                prompt_embeds = prompt_embeds[0]
+            else:
+                prompt_embeds = self.text_encoder(
+                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+                )
+                # Access the `hidden_states` first, that contains a tuple of
+                # all the hidden states from the encoder layers. Then index into
+                # the tuple to access the hidden states from the desired layer.
+                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+                # We also need to apply the final LayerNorm here to not mess with the
+                # representations. The `last_hidden_states` that we typically use for
+                # obtaining the final prompt representations passes through the LayerNorm
+                # layer.
+                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""] * batch_size
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt]
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
+            # textual inversion: procecss multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+            max_length = prompt_embeds.shape[1]
+            uncond_input = self.tokenizer(
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = uncond_input.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            negative_prompt_embeds = self.text_encoder(
+                uncond_input.input_ids.to(device),
+                attention_mask=attention_mask,
+            )
+            negative_prompt_embeds = negative_prompt_embeds[0]
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(self.text_encoder, lora_scale)
+
+        return prompt_embeds, negative_prompt_embeds
+
+    def decode_latents(self, latents):
+        latents = 1 / self.vae.config.scaling_factor * latents
+        image = self.vae.decode(latents, return_dict=False)[0]
+        image = (image / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        return image
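The remaining helpers (`run_safety_checker`, `prepare_extra_step_kwargs`, `encode_prompt`, `decode_latents`) reproduce behaviour the class previously inherited from `StableDiffusionPipeline`. For `decode_latents`, the post-processing maps the VAE output from [-1, 1] to a float32 NHWC NumPy array; a standalone sketch with a placeholder tensor standing in for the real `vae.decode` output:

```python
import torch

decoded = torch.rand(2, 3, 512, 512) * 2 - 1  # stand-in for vae.decode(...)[0], values in [-1, 1]

image = (decoded / 2 + 0.5).clamp(0, 1)                  # rescale to [0, 1]
image = image.cpu().permute(0, 2, 3, 1).float().numpy()  # NCHW -> NHWC, cast to float32
print(image.shape, image.dtype)                          # (2, 512, 512, 3) float32
```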

Comments (0)