Commit a812fb6

19and99 and sayakpaul authored

Text2video zero refinements (#3733)

* fix docs typos. add frame_ids argument to text2video-zero pipeline call
* make style && make quality
* add support of pytorch 2.0 scaled_dot_product_attention for CrossFrameAttnProcessor
* add chunk-by-chunk processing to text2video-zero docs
* make style && make quality
* Update docs/source/en/api/pipelines/text_to_video_zero.mdx

Co-authored-by: Sayak Paul <[email protected]>
1 parent f46b22b commit a812fb6
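In short, the pipeline call gains an optional `frame_ids` argument so that longer videos can be generated chunk-by-chunk, with frame 0 re-attached to every chunk as the Cross Frame Attention anchor. Below is a minimal sketch of the new call shape, assuming a diffusers build that includes this commit; the checkpoint and prompt are taken from the docs example further down.

```python
import torch
from diffusers import TextToVideoZeroPipeline

# Checkpoint used in the docs example; any Stable Diffusion v1.5-style model should work.
pipe = TextToVideoZeroPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Generate only frames 4..7 of a longer clip, prepending frame 0 so that
# Cross Frame Attention still has its anchor frame (this mirrors the chunked
# loop added to the docs in this commit).
frame_ids = [0, 4, 5, 6, 7]
output = pipe(
    prompt="A panda is playing guitar on times square",
    video_length=len(frame_ids),  # must equal len(frame_ids) per the new assert
    frame_ids=frame_ids,
)
frames = output.images[1:]  # drop the duplicated anchor frame
```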

File tree

  docs/source/en/api/pipelines/text_to_video_zero.mdx
  src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py

2 files changed: +130 -9 lines changed

docs/source/en/api/pipelines/text_to_video_zero.mdx

Lines changed: 38 additions & 3 deletions
@@ -80,6 +80,41 @@ You can change these parameters in the pipeline call:
 * Video length:
     * `video_length`, the number of frames video_length to be generated. Default: `video_length=8`
 
+We can also generate longer videos by doing the processing in a chunk-by-chunk manner:
+```python
+import torch
+import imageio
+from diffusers import TextToVideoZeroPipeline
+import numpy as np
+
+model_id = "runwayml/stable-diffusion-v1-5"
+pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+seed = 0
+video_length = 8
+chunk_size = 4
+prompt = "A panda is playing guitar on times square"
+
+# Generate the video chunk-by-chunk
+result = []
+chunk_ids = np.arange(0, video_length, chunk_size - 1)
+generator = torch.Generator(device="cuda")
+for i in range(len(chunk_ids)):
+    print(f"Processing chunk {i + 1} / {len(chunk_ids)}")
+    ch_start = chunk_ids[i]
+    ch_end = video_length if i == len(chunk_ids) - 1 else chunk_ids[i + 1]
+    # Attach the first frame for Cross Frame Attention
+    frame_ids = [0] + list(range(ch_start, ch_end))
+    # Fix the seed for the temporal consistency
+    generator.manual_seed(seed)
+    output = pipe(prompt=prompt, video_length=len(frame_ids), generator=generator, frame_ids=frame_ids)
+    result.append(output.images[1:])
+
+# Concatenate chunks and save
+result = np.concatenate(result)
+result = [(r * 255).astype("uint8") for r in result]
+imageio.mimsave("video.mp4", result, fps=4)
+```
+
 
 ### Text-To-Video with Pose Control
 To generate a video from prompt with additional pose control
@@ -202,7 +237,7 @@ can run with custom [DreamBooth](../training/dreambooth) models, as shown below
 
 reader = imageio.get_reader(video_path, "ffmpeg")
 frame_count = 8
-video = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]
+canny_edges = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]
 ```
 
 3. Run `StableDiffusionControlNetPipeline` with custom trained DreamBooth model
@@ -223,10 +258,10 @@ can run with custom [DreamBooth](../training/dreambooth) models, as shown below
 pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
 
 # fix latents for all frames
-latents = torch.randn((1, 4, 64, 64), device="cuda", dtype=torch.float16).repeat(len(pose_images), 1, 1, 1)
+latents = torch.randn((1, 4, 64, 64), device="cuda", dtype=torch.float16).repeat(len(canny_edges), 1, 1, 1)
 
 prompt = "oil painting of a beautiful girl avatar style"
-result = pipe(prompt=[prompt] * len(pose_images), image=pose_images, latents=latents).images
+result = pipe(prompt=[prompt] * len(canny_edges), image=canny_edges, latents=latents).images
 imageio.mimsave("video.mp4", result, fps=4)
 ```
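A note on the chunking arithmetic in the docs example above: each chunk prepends frame 0 as the attention anchor, so a call covering `chunk_size` frame ids only yields `chunk_size - 1` new frames. That is why `chunk_ids` uses a stride of `chunk_size - 1` and why `output.images[1:]` drops the first image of every chunk. A small standalone sketch (plain NumPy, no pipeline needed) shows the `frame_ids` each iteration produces for the defaults above:

```python
import numpy as np

video_length, chunk_size = 8, 4
chunk_ids = np.arange(0, video_length, chunk_size - 1)  # [0, 3, 6]
for i in range(len(chunk_ids)):
    ch_start = chunk_ids[i]
    ch_end = video_length if i == len(chunk_ids) - 1 else chunk_ids[i + 1]
    # Frame 0 is always re-attached as the Cross Frame Attention anchor.
    frame_ids = [0] + list(range(ch_start, ch_end))
    print(frame_ids)
# [0, 0, 1, 2]
# [0, 3, 4, 5]
# [0, 6, 7]
```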

src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py

Lines changed: 92 additions & 6 deletions
@@ -38,12 +38,12 @@ def rearrange_4(tensor):
 
 class CrossFrameAttnProcessor:
     """
-    Cross frame attention processor. For each frame the self-attention is replaced with attention with first frame
+    Cross frame attention processor. Each frame attends the first frame.
 
     Args:
         batch_size: The number that represents actual batch size, other than the frames.
-            For example, using calling unet with a single prompt and num_images_per_prompt=1, batch_size should be
-            equal to 2, due to classifier-free guidance.
+            For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to
+            2, due to classifier-free guidance.
     """
 
     def __init__(self, batch_size=2):
@@ -63,7 +63,7 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
-        # Sparse Attention
+        # Cross Frame Attention
         if not is_cross_attention:
             video_length = key.size()[0] // self.batch_size
             first_frame_index = [0] * video_length
@@ -95,6 +95,81 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
         return hidden_states
 
 
+class CrossFrameAttnProcessor2_0:
+    """
+    Cross frame attention processor with scaled_dot_product attention of Pytorch 2.0.
+
+    Args:
+        batch_size: The number that represents actual batch size, other than the frames.
+            For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to
+            2, due to classifier-free guidance.
+    """
+
+    def __init__(self, batch_size=2):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+        self.batch_size = batch_size
+
+    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        inner_dim = hidden_states.shape[-1]
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        query = attn.to_q(hidden_states)
+
+        is_cross_attention = encoder_hidden_states is not None
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        # Cross Frame Attention
+        if not is_cross_attention:
+            video_length = key.size()[0] // self.batch_size
+            first_frame_index = [0] * video_length
+
+            # rearrange keys to have batch and frames in the 1st and 2nd dims respectively
+            key = rearrange_3(key, video_length)
+            key = key[:, first_frame_index]
+            # rearrange values to have batch and frames in the 1st and 2nd dims respectively
+            value = rearrange_3(value, video_length)
+            value = value[:, first_frame_index]
+
+            # rearrange back to original shape
+            key = rearrange_4(key)
+            value = rearrange_4(value)
+
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
+
+
 @dataclass
 class TextToVideoPipelineOutput(BaseOutput):
     images: Union[List[PIL.Image.Image], np.ndarray]
@@ -227,7 +302,12 @@ def __init__(
         super().__init__(
            vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
        )
-        self.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
+        processor = (
+            CrossFrameAttnProcessor2_0(batch_size=2)
+            if hasattr(F, "scaled_dot_product_attention")
+            else CrossFrameAttnProcessor(batch_size=2)
+        )
+        self.unet.set_attn_processor(processor)
 
     def forward_loop(self, x_t0, t0, t1, generator):
         """
@@ -338,6 +418,7 @@ def __call__(
         callback_steps: Optional[int] = 1,
         t0: int = 44,
         t1: int = 47,
+        frame_ids: Optional[List[int]] = None,
     ):
         """
         Function invoked when calling the pipeline for generation.
@@ -399,6 +480,9 @@ def __call__(
             t1 (`int`, *optional*, defaults to 47):
                 Timestep t0. Should be in the range [t0 + 1, num_inference_steps - 1]. See the
                 [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
+            frame_ids (`List[int]`, *optional*):
+                Indexes of the frames that are being generated. This is used when generating longer videos
+                chunk-by-chunk.
 
         Returns:
             [`~pipelines.text_to_video_synthesis.TextToVideoPipelineOutput`]:
@@ -407,7 +491,9 @@ def __call__(
                 likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
         """
         assert video_length > 0
-        frame_ids = list(range(video_length))
+        if frame_ids is None:
+            frame_ids = list(range(video_length))
+        assert len(frame_ids) == video_length
 
         assert num_videos_per_prompt == 1
 