Commit a2de574
hacked version of attention_processor.py that batches ip_adapters
Proof-of-concept implementation of hacking the attention processor. When it detects a batch whose size matches the number of IP adapters, it attempts to run each IP adapter on one element of the batch instead of squashing all the IP adapters together.
1 parent 1ca0a75 · commit a2de574
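
Note on the special case: the check in the diff below is batch_size == 2 * num_hidden_states, so the batch is expected to be twice the number of IP adapters, presumably because classifier-free guidance doubles the batch (one unconditional and one conditional element per prompt). A minimal sketch of the routing rule, with made-up counts used purely for illustration:

    # Hypothetical illustration of the batch-to-adapter routing in the special case.
    # Assumption: classifier-free guidance doubles the batch, so batch_size == 2 * num_adapters.
    num_adapters = 2
    batch_size = 2 * num_adapters

    for i in range(batch_size):
        adapter_index = i % num_adapters  # batch element i is attended with adapter i % num_adapters
        print(f"batch element {i} -> ip_adapter[{adapter_index}]")
    # batch element 0 -> ip_adapter[0]
    # batch element 1 -> ip_adapter[1]
    # batch element 2 -> ip_adapter[0]
    # batch element 3 -> ip_adapter[1]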

src/diffusers/models/attention_processor.py

File mode changed: 100644 → 100755
Lines changed: 87 additions & 49 deletions
@@ -3955,68 +3955,106 @@ def __call__(
         else:
             ip_adapter_masks = [None] * len(self.scale)

-        # for ip-adapter
-        for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
-            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
-        ):
-            skip = False
-            if isinstance(scale, list):
-                if all(s == 0 for s in scale):
-                    skip = True
-            elif scale == 0:
-                skip = True
-            if not skip:
-                if mask is not None:
-                    if not isinstance(scale, list):
-                        scale = [scale] * mask.shape[1]
+        # HACK: if the number of ip_adapters match the batch, process as a special case
+        num_hidden_states = len(ip_hidden_states)
+        if batch_size == 2 * num_hidden_states:
+            result_keys = []
+            result_values = []
+            for i in range(batch_size):
+                current_ip_hidden_states = ip_hidden_states[i % num_hidden_states]
+                scale = self.scale[i % num_hidden_states]
+                to_k_ip = self.to_k_ip[i % num_hidden_states]
+                to_v_ip = self.to_v_ip[i % num_hidden_states]
+                mask = ip_adapter_masks[i % num_hidden_states]
+
+                ip_key = to_k_ip(current_ip_hidden_states[i, :, :, :])
+                ip_value = to_v_ip(current_ip_hidden_states[i, :, :, :])
+
+                ip_key = ip_key.view(2, -1, attn.heads, head_dim).transpose(1, 2)
+                ip_value = ip_value.view(2, -1, attn.heads, head_dim).transpose(1, 2)
+
+                result_keys.append(ip_key[0])
+                result_values.append(ip_value[0])
+
+            ip_key = torch.stack(result_keys, dim=0)
+            ip_value = torch.stack(result_values, dim=0)
+
+            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            current_ip_hidden_states = F.scaled_dot_product_attention(
+                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+            )

-                    current_num_images = mask.shape[1]
-                    for i in range(current_num_images):
-                        ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
-                        ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+            current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                batch_size, -1, attn.heads * head_dim
+            )
+            current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+
+            hidden_states = hidden_states + scale * current_ip_hidden_states
+
+        else:
+            # for ip-adapter
+            for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+            ):
+                skip = False
+                if isinstance(scale, list):
+                    if all(s == 0 for s in scale):
+                        skip = True
+                elif scale == 0:
+                    skip = True
+                if not skip:
+                    if mask is not None:
+                        if not isinstance(scale, list):
+                            scale = [scale] * mask.shape[1]
+
+                        current_num_images = mask.shape[1]
+                        for i in range(current_num_images):
+                            ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+                            ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+                            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                            # TODO: add support for attn.scale when we move to Torch 2.1
+                            _current_ip_hidden_states = F.scaled_dot_product_attention(
+                                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                            )
+
+                            _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
+                                batch_size, -1, attn.heads * head_dim
+                            )
+                            _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
+
+                            mask_downsample = IPAdapterMaskProcessor.downsample(
+                                mask[:, i, :, :],
+                                batch_size,
+                                _current_ip_hidden_states.shape[1],
+                                _current_ip_hidden_states.shape[2],
+                            )
+
+                            mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+                            hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+                    else:
+                        ip_key = to_k_ip(current_ip_hidden_states)
+                        ip_value = to_v_ip(current_ip_hidden_states)

                         ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                         ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

                         # the output of sdp = (batch, num_heads, seq_len, head_dim)
                         # TODO: add support for attn.scale when we move to Torch 2.1
-                        _current_ip_hidden_states = F.scaled_dot_product_attention(
+                        current_ip_hidden_states = F.scaled_dot_product_attention(
                             query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
                         )

-                        _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
+                        current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
                             batch_size, -1, attn.heads * head_dim
                         )
-                        _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
-
-                        mask_downsample = IPAdapterMaskProcessor.downsample(
-                            mask[:, i, :, :],
-                            batch_size,
-                            _current_ip_hidden_states.shape[1],
-                            _current_ip_hidden_states.shape[2],
-                        )
-
-                        mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
-                        hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
-                    else:
-                        ip_key = to_k_ip(current_ip_hidden_states)
-                        ip_value = to_v_ip(current_ip_hidden_states)
-
-                        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                        current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)

-                        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-                        # TODO: add support for attn.scale when we move to Torch 2.1
-                        current_ip_hidden_states = F.scaled_dot_product_attention(
-                            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                        )
-
-                        current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                            batch_size, -1, attn.heads * head_dim
-                        )
-                        current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
-
-                        hidden_states = hidden_states + scale * current_ip_hidden_states
+                        hidden_states = hidden_states + scale * current_ip_hidden_states

         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
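
For readers who want to see the shape flow of the new branch in isolation, here is a rough, self-contained sketch of the pattern it follows: project one key/value per batch element with that element's own adapter weights, stack the results, and run a single F.scaled_dot_product_attention call. Every size and layer below is invented for illustration, and the head-splitting is simplified compared to the view(2, ...) / ip_key[0] maneuver in the commit itself:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # made-up sizes, for illustration only
    batch_size, heads, head_dim = 4, 8, 64
    ip_seq_len, query_len = 4, 77
    hidden = heads * head_dim
    num_adapters = batch_size // 2

    # one K/V projection per adapter, standing in for to_k_ip / to_v_ip
    to_k_ip = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_adapters))
    to_v_ip = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_adapters))
    # one image-embedding tensor per adapter, shaped (batch, ip_seq_len, hidden)
    ip_hidden_states = [torch.randn(batch_size, ip_seq_len, hidden) for _ in range(num_adapters)]
    query = torch.randn(batch_size, heads, query_len, head_dim)

    result_keys, result_values = [], []
    for i in range(batch_size):
        adapter = i % num_adapters
        # project only batch element i, using that adapter's own weights
        ip_key = to_k_ip[adapter](ip_hidden_states[adapter][i])    # (ip_seq_len, hidden)
        ip_value = to_v_ip[adapter](ip_hidden_states[adapter][i])
        # split heads: (heads, ip_seq_len, head_dim)
        result_keys.append(ip_key.view(ip_seq_len, heads, head_dim).transpose(0, 1))
        result_values.append(ip_value.view(ip_seq_len, heads, head_dim).transpose(0, 1))

    ip_key = torch.stack(result_keys, dim=0)      # (batch, heads, ip_seq_len, head_dim)
    ip_value = torch.stack(result_values, dim=0)

    out = F.scaled_dot_product_attention(query, ip_key, ip_value)
    print(out.shape)  # torch.Size([4, 8, 77, 64])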
