@@ -3955,68 +3955,106 @@ def __call__(
         else:
             ip_adapter_masks = [None] * len(self.scale)
 
-        # for ip-adapter
-        for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
-            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
-        ):
-            skip = False
-            if isinstance(scale, list):
-                if all(s == 0 for s in scale):
-                    skip = True
-            elif scale == 0:
-                skip = True
-            if not skip:
-                if mask is not None:
-                    if not isinstance(scale, list):
-                        scale = [scale] * mask.shape[1]
+        # HACK: if the number of ip_adapters matches the batch, process as a special case
+        num_hidden_states = len(ip_hidden_states)
+        if batch_size == 2 * num_hidden_states:
+            result_keys = []
+            result_values = []
+            for i in range(batch_size):
+                current_ip_hidden_states = ip_hidden_states[i % num_hidden_states]
+                scale = self.scale[i % num_hidden_states]
+                to_k_ip = self.to_k_ip[i % num_hidden_states]
+                to_v_ip = self.to_v_ip[i % num_hidden_states]
+                mask = ip_adapter_masks[i % num_hidden_states]
+
+                ip_key = to_k_ip(current_ip_hidden_states[i, :, :, :])
+                ip_value = to_v_ip(current_ip_hidden_states[i, :, :, :])
+
+                ip_key = ip_key.view(2, -1, attn.heads, head_dim).transpose(1, 2)
+                ip_value = ip_value.view(2, -1, attn.heads, head_dim).transpose(1, 2)
+
+                result_keys.append(ip_key[0])
+                result_values.append(ip_value[0])
+
+            ip_key = torch.stack(result_keys, dim=0)
+            ip_value = torch.stack(result_values, dim=0)
+
+            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            current_ip_hidden_states = F.scaled_dot_product_attention(
+                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+            )
 
-                    current_num_images = mask.shape[1]
-                    for i in range(current_num_images):
-                        ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
-                        ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+            current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                batch_size, -1, attn.heads * head_dim
+            )
+            current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+
+            hidden_states = hidden_states + scale * current_ip_hidden_states
+
+        else:
+            # for ip-adapter
+            for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+            ):
+                skip = False
+                if isinstance(scale, list):
+                    if all(s == 0 for s in scale):
+                        skip = True
+                elif scale == 0:
+                    skip = True
+                if not skip:
+                    if mask is not None:
+                        if not isinstance(scale, list):
+                            scale = [scale] * mask.shape[1]
+
+                        current_num_images = mask.shape[1]
+                        for i in range(current_num_images):
+                            ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+                            ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+                            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                            # TODO: add support for attn.scale when we move to Torch 2.1
+                            _current_ip_hidden_states = F.scaled_dot_product_attention(
+                                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                            )
+
+                            _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
+                                batch_size, -1, attn.heads * head_dim
+                            )
+                            _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
+
+                            mask_downsample = IPAdapterMaskProcessor.downsample(
+                                mask[:, i, :, :],
+                                batch_size,
+                                _current_ip_hidden_states.shape[1],
+                                _current_ip_hidden_states.shape[2],
+                            )
+
+                            mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+                            hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+                    else:
+                        ip_key = to_k_ip(current_ip_hidden_states)
+                        ip_value = to_v_ip(current_ip_hidden_states)
 
                         ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                         ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
                         # the output of sdp = (batch, num_heads, seq_len, head_dim)
                         # TODO: add support for attn.scale when we move to Torch 2.1
-                        _current_ip_hidden_states = F.scaled_dot_product_attention(
+                        current_ip_hidden_states = F.scaled_dot_product_attention(
                             query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
                         )
 
-                        _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
+                        current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
                             batch_size, -1, attn.heads * head_dim
                         )
-                        _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
-
-                        mask_downsample = IPAdapterMaskProcessor.downsample(
-                            mask[:, i, :, :],
-                            batch_size,
-                            _current_ip_hidden_states.shape[1],
-                            _current_ip_hidden_states.shape[2],
-                        )
-
-                        mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
-                        hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
-                else:
-                    ip_key = to_k_ip(current_ip_hidden_states)
-                    ip_value = to_v_ip(current_ip_hidden_states)
-
-                    ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                    ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                        current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
 
-                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
-                    # TODO: add support for attn.scale when we move to Torch 2.1
-                    current_ip_hidden_states = F.scaled_dot_product_attention(
-                        query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                    )
-
-                    current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                        batch_size, -1, attn.heads * head_dim
-                    )
-                    current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
-
-                    hidden_states = hidden_states + scale * current_ip_hidden_states
+                        hidden_states = hidden_states + scale * current_ip_hidden_states
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
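
The new `batch_size == 2 * num_hidden_states` branch pairs batch element `i` with adapter `i % num_hidden_states` and keeps only the first half of each reshaped key/value pair. Below is a minimal, self-contained sketch of just that routing idea with toy tensors; the names, shapes, and the assumption that the batch holds one CFG pair per adapter are illustrative, not the diffusers implementation:

```python
# Sketch: route each batch element to the key stream of "its" adapter.
# Assumes (not guaranteed by the source) a batch of 2 * num_adapters
# samples, so element i belongs to adapter i % num_adapters.
import torch

num_adapters = 2
batch_size = 2 * num_adapters        # the condition the new branch checks
seq_len, dim = 8, 16

# one key stream per adapter (stand-ins for to_k_ip(current_ip_hidden_states))
adapter_keys = [torch.randn(seq_len, dim) for _ in range(num_adapters)]

per_sample_keys = []
for i in range(batch_size):
    # batch element i attends only to the keys of adapter i % num_adapters
    per_sample_keys.append(adapter_keys[i % num_adapters])

ip_key = torch.stack(per_sample_keys, dim=0)
print(ip_key.shape)  # torch.Size([4, 8, 16]) -> one key set per batch element
```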
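The unchanged masked path blends each adapter's attention output into `hidden_states` through a per-query weight produced by `IPAdapterMaskProcessor.downsample`. A toy illustration of that blend, assuming the mask has already been downsampled to one weight per query token (shapes and values are made up for the example):

```python
# Toy version of: hidden_states += scale[i] * (_current_ip_hidden_states * mask_downsample)
import torch

batch, num_queries, inner_dim = 2, 6, 4
hidden_states = torch.zeros(batch, num_queries, inner_dim)
ip_out = torch.ones(batch, num_queries, inner_dim)   # stand-in for the SDPA output

# 1.0 where the adapter should apply, 0.0 elsewhere; broadcast over channels
mask_downsample = torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0, 0.0]).view(1, num_queries, 1)

scale_i = 0.5
hidden_states = hidden_states + scale_i * (ip_out * mask_downsample)
print(hidden_states[0, :, 0])  # tensor([0.5, 0.5, 0.5, 0.0, 0.0, 0.0])
```

The per-image mask thus gates where in the latent each IP-adapter image contributes, while `scale[i]` controls how strongly.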