@@ -38,12 +38,12 @@ def rearrange_4(tensor):
3838
3939class CrossFrameAttnProcessor :
4040 """
41- Cross frame attention processor. For each frame the self-attention is replaced with attention with first frame
41+ Cross frame attention processor. Each frame attends the first frame.
4242
4343 Args:
4444 batch_size: The number that represents actual batch size, other than the frames.
45- For example, using calling unet with a single prompt and num_images_per_prompt=1, batch_size should be
46- equal to 2, due to classifier-free guidance.
45+ For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to
46+ 2, due to classifier-free guidance.
4747 """
4848
4949 def __init__ (self , batch_size = 2 ):
@@ -63,7 +63,7 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
6363 key = attn .to_k (encoder_hidden_states )
6464 value = attn .to_v (encoder_hidden_states )
6565
66- # Sparse Attention
66+ # Cross Frame Attention
6767 if not is_cross_attention :
6868 video_length = key .size ()[0 ] // self .batch_size
6969 first_frame_index = [0 ] * video_length
@@ -95,6 +95,81 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
9595 return hidden_states
9696
9797
98+ class CrossFrameAttnProcessor2_0 :
99+ """
100+ Cross frame attention processor with scaled_dot_product attention of Pytorch 2.0.
101+
102+ Args:
103+ batch_size: The number that represents actual batch size, other than the frames.
104+ For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to
105+ 2, due to classifier-free guidance.
106+ """
107+
108+ def __init__ (self , batch_size = 2 ):
109+ if not hasattr (F , "scaled_dot_product_attention" ):
110+ raise ImportError ("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." )
111+ self .batch_size = batch_size
112+
113+ def __call__ (self , attn , hidden_states , encoder_hidden_states = None , attention_mask = None ):
114+ batch_size , sequence_length , _ = (
115+ hidden_states .shape if encoder_hidden_states is None else encoder_hidden_states .shape
116+ )
117+ inner_dim = hidden_states .shape [- 1 ]
118+
119+ if attention_mask is not None :
120+ attention_mask = attn .prepare_attention_mask (attention_mask , sequence_length , batch_size )
121+ # scaled_dot_product_attention expects attention_mask shape to be
122+ # (batch, heads, source_length, target_length)
123+ attention_mask = attention_mask .view (batch_size , attn .heads , - 1 , attention_mask .shape [- 1 ])
124+
125+ query = attn .to_q (hidden_states )
126+
127+ is_cross_attention = encoder_hidden_states is not None
128+ if encoder_hidden_states is None :
129+ encoder_hidden_states = hidden_states
130+ elif attn .norm_cross :
131+ encoder_hidden_states = attn .norm_encoder_hidden_states (encoder_hidden_states )
132+
133+ key = attn .to_k (encoder_hidden_states )
134+ value = attn .to_v (encoder_hidden_states )
135+
136+ # Cross Frame Attention
137+ if not is_cross_attention :
138+ video_length = key .size ()[0 ] // self .batch_size
139+ first_frame_index = [0 ] * video_length
140+
141+ # rearrange keys to have batch and frames in the 1st and 2nd dims respectively
142+ key = rearrange_3 (key , video_length )
143+ key = key [:, first_frame_index ]
144+ # rearrange values to have batch and frames in the 1st and 2nd dims respectively
145+ value = rearrange_3 (value , video_length )
146+ value = value [:, first_frame_index ]
147+
148+ # rearrange back to original shape
149+ key = rearrange_4 (key )
150+ value = rearrange_4 (value )
151+
152+ head_dim = inner_dim // attn .heads
153+ query = query .view (batch_size , - 1 , attn .heads , head_dim ).transpose (1 , 2 )
154+ key = key .view (batch_size , - 1 , attn .heads , head_dim ).transpose (1 , 2 )
155+ value = value .view (batch_size , - 1 , attn .heads , head_dim ).transpose (1 , 2 )
156+
157+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
158+ # TODO: add support for attn.scale when we move to Torch 2.1
159+ hidden_states = F .scaled_dot_product_attention (
160+ query , key , value , attn_mask = attention_mask , dropout_p = 0.0 , is_causal = False
161+ )
162+
163+ hidden_states = hidden_states .transpose (1 , 2 ).reshape (batch_size , - 1 , attn .heads * head_dim )
164+ hidden_states = hidden_states .to (query .dtype )
165+
166+ # linear proj
167+ hidden_states = attn .to_out [0 ](hidden_states )
168+ # dropout
169+ hidden_states = attn .to_out [1 ](hidden_states )
170+ return hidden_states
171+
172+
98173@dataclass
99174class TextToVideoPipelineOutput (BaseOutput ):
100175 images : Union [List [PIL .Image .Image ], np .ndarray ]
@@ -227,7 +302,12 @@ def __init__(
227302 super ().__init__ (
228303 vae , text_encoder , tokenizer , unet , scheduler , safety_checker , feature_extractor , requires_safety_checker
229304 )
230- self .unet .set_attn_processor (CrossFrameAttnProcessor (batch_size = 2 ))
305+ processor = (
306+ CrossFrameAttnProcessor2_0 (batch_size = 2 )
307+ if hasattr (F , "scaled_dot_product_attention" )
308+ else CrossFrameAttnProcessor (batch_size = 2 )
309+ )
310+ self .unet .set_attn_processor (processor )
231311
232312 def forward_loop (self , x_t0 , t0 , t1 , generator ):
233313 """
@@ -338,6 +418,7 @@ def __call__(
338418 callback_steps : Optional [int ] = 1 ,
339419 t0 : int = 44 ,
340420 t1 : int = 47 ,
421+ frame_ids : Optional [List [int ]] = None ,
341422 ):
342423 """
343424 Function invoked when calling the pipeline for generation.
@@ -399,6 +480,9 @@ def __call__(
399480 t1 (`int`, *optional*, defaults to 47):
400481 Timestep t0. Should be in the range [t0 + 1, num_inference_steps - 1]. See the
401482 [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
483+ frame_ids (`List[int]`, *optional*):
484+ Indexes of the frames that are being generated. This is used when generating longer videos
485+ chunk-by-chunk.
402486
403487 Returns:
404488 [`~pipelines.text_to_video_synthesis.TextToVideoPipelineOutput`]:
@@ -407,7 +491,9 @@ def __call__(
407491 likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
408492 """
409493 assert video_length > 0
410- frame_ids = list (range (video_length ))
494+ if frame_ids is None :
495+ frame_ids = list (range (video_length ))
496+ assert len (frame_ids ) == video_length
411497
412498 assert num_videos_per_prompt == 1
413499
0 commit comments