@@ -906,6 +906,177 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return self.processor(self, hidden_states)
 
 
+class MochiAttention(nn.Module):
+    def __init__(
+        self,
+        query_dim: int,
+        added_kv_proj_dim: int,
+        processor: "MochiAttnProcessor2_0",
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        added_proj_bias: bool = True,
+        out_dim: Optional[int] = None,
+        out_context_dim: Optional[int] = None,
+        out_bias: bool = True,
+        context_pre_only: bool = False,
+        eps: float = 1e-5,
+    ):
+        super().__init__()
+        from .normalization import MochiRMSNorm
+
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.out_context_dim = out_context_dim if out_context_dim else query_dim
+        self.context_pre_only = context_pre_only
+
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+
+        self.norm_q = MochiRMSNorm(dim_head, eps, True)
+        self.norm_k = MochiRMSNorm(dim_head, eps, True)
+        self.norm_added_q = MochiRMSNorm(dim_head, eps, True)
+        self.norm_added_k = MochiRMSNorm(dim_head, eps, True)
+
+        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+        self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        if self.context_pre_only is not None:
+            self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(nn.Dropout(dropout))
+
+        if not self.context_pre_only:
+            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+
+        self.processor = processor
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        return self.processor(
+            self,
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            **kwargs,
+        )
+
+
+class MochiAttnProcessor2_0:
+    """Attention processor used in Mochi."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: "MochiAttention",
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        encoder_query = attn.add_q_proj(encoder_hidden_states)
+        encoder_key = attn.add_k_proj(encoder_hidden_states)
+        encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
+
+        if attn.norm_added_q is not None:
+            encoder_query = attn.norm_added_q(encoder_query)
+        if attn.norm_added_k is not None:
+            encoder_key = attn.norm_added_k(encoder_key)
+
+        if image_rotary_emb is not None:
+
+            def apply_rotary_emb(x, freqs_cos, freqs_sin):
+                x_even = x[..., 0::2].float()
+                x_odd = x[..., 1::2].float()
+
+                cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
+                sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
+
+                return torch.stack([cos, sin], dim=-1).flatten(-2)
+
+            query = apply_rotary_emb(query, *image_rotary_emb)
+            key = apply_rotary_emb(key, *image_rotary_emb)
+
+        query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
+        encoder_query, encoder_key, encoder_value = (
+            encoder_query.transpose(1, 2),
+            encoder_key.transpose(1, 2),
+            encoder_value.transpose(1, 2),
+        )
+
+        sequence_length = query.size(2)
+        encoder_sequence_length = encoder_query.size(2)
+        total_length = sequence_length + encoder_sequence_length
+
+        batch_size, heads, _, dim = query.shape
+        attn_outputs = []
+        for idx in range(batch_size):
+            mask = attention_mask[idx][None, :]
+            valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
+
+            valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :]
+            valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :]
+            valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :]
+
+            valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2)
+            valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
+            valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)
+
+            attn_output = F.scaled_dot_product_attention(
+                valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False
+            )
+            valid_sequence_length = attn_output.size(2)
+            attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
+            attn_outputs.append(attn_output)
+
+        hidden_states = torch.cat(attn_outputs, dim=0)
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+
+        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
+            (sequence_length, encoder_sequence_length), dim=1
+        )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if hasattr(attn, "to_add_out"):
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        return hidden_states, encoder_hidden_states
+
+
 class AttnProcessor:
     r"""
     Default processor for performing attention-related computations.
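
The hunk above introduces MochiAttention alongside its dedicated processor. For orientation, a minimal usage sketch follows; the dimensions and the import path are illustrative assumptions, not values taken from the diff or from the Mochi checkpoints.

    import torch
    from diffusers.models.attention_processor import MochiAttention, MochiAttnProcessor2_0  # path assumed

    attn = MochiAttention(
        query_dim=3072,          # assumed transformer width
        added_kv_proj_dim=1536,  # assumed text-embedding width
        heads=24,
        dim_head=128,
        out_dim=3072,
        out_context_dim=1536,
        context_pre_only=False,
        processor=MochiAttnProcessor2_0(),
    )

    hidden_states = torch.randn(2, 256, 3072)         # (batch, video tokens, dim)
    encoder_hidden_states = torch.randn(2, 77, 1536)  # (batch, prompt tokens, dim)
    attention_mask = torch.ones(2, 77, dtype=torch.bool)

    hidden_states, encoder_hidden_states = attn(
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        attention_mask=attention_mask,
    )

The module mirrors the usual Attention layout (to_q/to_k/to_v, add_*_proj, to_out) and simply delegates its forward pass to the processor, returning both the video and the prompt streams.
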
@@ -3868,94 +4039,6 @@ def __call__(
         return hidden_states
 
 
-class MochiAttnProcessor2_0:
-    """Attention processor used in Mochi."""
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        query = query.unflatten(2, (attn.heads, -1))
-        key = key.unflatten(2, (attn.heads, -1))
-        value = value.unflatten(2, (attn.heads, -1))
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        encoder_query = attn.add_q_proj(encoder_hidden_states)
-        encoder_key = attn.add_k_proj(encoder_hidden_states)
-        encoder_value = attn.add_v_proj(encoder_hidden_states)
-
-        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
-        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
-        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
-
-        if attn.norm_added_q is not None:
-            encoder_query = attn.norm_added_q(encoder_query)
-        if attn.norm_added_k is not None:
-            encoder_key = attn.norm_added_k(encoder_key)
-
-        if image_rotary_emb is not None:
-
-            def apply_rotary_emb(x, freqs_cos, freqs_sin):
-                x_even = x[..., 0::2].float()
-                x_odd = x[..., 1::2].float()
-
-                cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
-                sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
-
-                return torch.stack([cos, sin], dim=-1).flatten(-2)
-
-            query = apply_rotary_emb(query, *image_rotary_emb)
-            key = apply_rotary_emb(key, *image_rotary_emb)
-
-        query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
-        encoder_query, encoder_key, encoder_value = (
-            encoder_query.transpose(1, 2),
-            encoder_key.transpose(1, 2),
-            encoder_value.transpose(1, 2),
-        )
-
-        sequence_length = query.size(2)
-        encoder_sequence_length = encoder_query.size(2)
-
-        query = torch.cat([query, encoder_query], dim=2)
-        key = torch.cat([key, encoder_key], dim=2)
-        value = torch.cat([value, encoder_value], dim=2)
-
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
-        hidden_states = hidden_states.to(query.dtype)
-
-        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
-            (sequence_length, encoder_sequence_length), dim=1
-        )
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if getattr(attn, "to_add_out", None) is not None:
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-        return hidden_states, encoder_hidden_states
-
-
 class FusedAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
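
The processor removed above concatenated the full prompt sequence onto the video tokens and ignored attention_mask; the relocated version at the top of the file instead keeps only each sample's valid prompt tokens before calling scaled_dot_product_attention, then pads the result back to a common length. A standalone sketch of that gather-then-pad pattern, with assumed tensor names and shapes:

    import torch
    import torch.nn.functional as F

    def masked_joint_attention(query, key, value, enc_query, enc_key, enc_value, prompt_mask):
        # query/key/value: (batch, heads, video_len, head_dim)
        # enc_*:           (batch, heads, prompt_len, head_dim)
        # prompt_mask:     (batch, prompt_len) bool, True where the prompt token is real
        batch_size = query.size(0)
        total_length = query.size(2) + enc_query.size(2)
        outputs = []
        for idx in range(batch_size):
            keep = prompt_mask[idx].nonzero(as_tuple=False).flatten()
            q = torch.cat([query[idx : idx + 1], enc_query[idx : idx + 1, :, keep]], dim=2)
            k = torch.cat([key[idx : idx + 1], enc_key[idx : idx + 1, :, keep]], dim=2)
            v = torch.cat([value[idx : idx + 1], enc_value[idx : idx + 1, :, keep]], dim=2)
            out = F.scaled_dot_product_attention(q, k, v)
            # pad the joint sequence back to video_len + prompt_len so the
            # per-sample results can be concatenated into one batch tensor
            outputs.append(F.pad(out, (0, 0, 0, total_length - out.size(2))))
        return torch.cat(outputs, dim=0)

Trimming per sample keeps padded prompt tokens out of the attention entirely, without having to materialize a full additive attention bias.
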
@@ -5668,13 +5751,13 @@ def __call__(
         AttnProcessorNPU,
         AttnProcessor2_0,
         MochiVaeAttnProcessor2_0,
+        MochiAttnProcessor2_0,
         StableAudioAttnProcessor2_0,
         HunyuanAttnProcessor2_0,
         FusedHunyuanAttnProcessor2_0,
         PAGHunyuanAttnProcessor2_0,
         PAGCFGHunyuanAttnProcessor2_0,
         LuminaAttnProcessor2_0,
-        MochiAttnProcessor2_0,
         FusedAttnProcessor2_0,
         CustomDiffusionXFormersAttnProcessor,
         CustomDiffusionAttnProcessor2_0,