@@ -83,14 +83,16 @@ def forward(
         hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         image_rotary_emb=None,
+        joint_attention_kwargs=None,
     ):
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
-
+        joint_attention_kwargs = joint_attention_kwargs or {}
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
         )

         hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
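
The `or {}` default plus the `**`-splat is what keeps this change backward compatible: callers that never pass `joint_attention_kwargs` splat an empty dict into `self.attn`, while new callers can thread extra keyword arguments through to the attention layer. A minimal, self-contained sketch of that pattern; `attn`, `block_forward`, and the `scale` kwarg are hypothetical names used only to illustrate the mechanism, not the diffusers API:

# Sketch of the pass-through pattern above; all names here are
# illustrative stand-ins, not the actual block or attention API.
def attn(hidden_states, image_rotary_emb=None, scale=1.0):
    return [h * scale for h in hidden_states]

def block_forward(hidden_states, image_rotary_emb=None, joint_attention_kwargs=None):
    # `None` becomes {}, so `**joint_attention_kwargs` is a no-op for old callers.
    joint_attention_kwargs = joint_attention_kwargs or {}
    return attn(hidden_states, image_rotary_emb=image_rotary_emb, **joint_attention_kwargs)

print(block_forward([1.0, 2.0]))                                         # old call site
print(block_forward([1.0, 2.0], joint_attention_kwargs={"scale": 0.5}))  # new kwarg flows through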
@@ -161,18 +163,20 @@ def forward(
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         image_rotary_emb=None,
+        joint_attention_kwargs=None,
     ):
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
             encoder_hidden_states, emb=temb
         )
-
+        joint_attention_kwargs = joint_attention_kwargs or {}
         # Attention.
         attn_output, context_attn_output = self.attn(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
         )

         # Process attention outputs for the `hidden_states`.
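
On the consumer side, `self.attn` hands these kwargs to its attention processor, so any processor whose `__call__` accepts extra keyword arguments can pick them up. A hedged sketch of such a consumer, assuming torch is available; the class and its `out_scale` kwarg are hypothetical, not one of the shipped diffusers processors, and the attention math is a bare single-head stand-in for the real projected multi-head computation:

import torch

# Hypothetical processor reading a kwarg threaded through joint_attention_kwargs;
# `out_scale` is an assumption made for illustration only.
class ScaledSelfAttnProcessor:
    def __call__(self, hidden_states, encoder_hidden_states=None,
                 image_rotary_emb=None, out_scale=1.0, **kwargs):
        # Plain scaled dot-product self-attention over the sequence dim.
        scores = hidden_states @ hidden_states.transpose(-1, -2)
        scores = scores / hidden_states.shape[-1] ** 0.5
        probs = scores.softmax(dim=-1)
        return (probs @ hidden_states) * out_scale

processor = ScaledSelfAttnProcessor()
out = processor(torch.randn(1, 4, 8), out_scale=0.5)  # kwarg arrives via the **splat upstream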
@@ -497,6 +501,7 @@ def custom_forward(*inputs):
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )

             # controlnet residual
@@ -533,6 +538,7 @@ def custom_forward(*inputs):
                     hidden_states=hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )

             # controlnet residual
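
With both the context blocks and the single blocks forwarding the dict, anything passed to the transformer's top-level `forward` now reaches every block-level attention call. A hedged usage sketch at the pipeline level, assuming this diff is against diffusers' Flux transformer (which the block structure suggests) and that the FLUX.1-dev checkpoint is available; `scale` is the LoRA-scale key diffusers pipelines read out of `joint_attention_kwargs`, and any other key is only honored if the installed attention processors accept it:

import torch
from diffusers import FluxPipeline

# Hedged sketch: joint_attention_kwargs flows pipeline -> transformer -> blocks.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    "a photo of an astronaut riding a horse",
    joint_attention_kwargs={"scale": 0.8},  # dampens any loaded LoRA layers
).images[0]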