Commit e0951f2

Refactor attention.CrossAttention to remove duplicate code and apply optimizations
Apply a ~6% speedup by moving * self.scale earlier, onto a smaller tensor. When we have enough VRAM, don't make a useless zeros tensor. Switch between cuda/mps/cpu based on q.device.type to allow cleaner per-architecture optimizations in the future. For cuda and cpu, keep VRAM usage and the faster slicing consistent. For cpu, use smaller slices. Tested ~20% faster on an i7: 9.8 to 7.7 s/it. Fix the = typo to self.mem_total >= 8 in einsum_op_mps_v2, as per the #582 discussion.
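A quick sanity check (not part of the commit; shapes and values are illustrative) of why pre-scaling the keys matches the old behavior: einsum is linear in each operand, so scaling k before the product equals scaling the scores after it, and k is usually the much smaller tensor (e.g. 77 context tokens vs. up to 4096 image tokens at 512x512), so fewer elements get multiplied.

import torch

q = torch.randn(2, 4096, 64)  # (batch*heads, image tokens, dim_head) - assumed shapes
k = torch.randn(2, 77, 64)    # (batch*heads, context tokens, dim_head)
scale = 64 ** -0.5

s_old = torch.einsum('b i d, b j d -> b i j', q, k) * scale   # old: scale the 2x4096x77 score tensor
s_new = torch.einsum('b i d, b j d -> b i j', q, k * scale)   # new: scale the 2x77x64 key tensor
assert torch.allclose(s_old, s_new, atol=1e-6)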
1 parent 100f2e8 commit e0951f2

File tree

1 file changed (+61, -77 lines)


ldm/modules/attention.py

Lines changed: 61 additions & 77 deletions
@@ -90,7 +90,7 @@ def forward(self, x):
         b, c, h, w = x.shape
         qkv = self.to_qkv(x)
         q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
-        k = k.softmax(dim=-1)
+        k = k.softmax(dim=-1)
         context = torch.einsum('bhdn,bhen->bhde', k, v)
         out = torch.einsum('bhde,bhdn->bhen', context, q)
         out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)

@@ -167,101 +167,85 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
             nn.Linear(inner_dim, query_dim),
             nn.Dropout(dropout)
         )
-
-        if torch.cuda.is_available():
-            self.einsum_op = self.einsum_op_cuda
-        else:
-            self.mem_total = psutil.virtual_memory().total / (1024**3)
-            self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
-
-    def einsum_op_compvis(self, q, k, v, r1):
-        s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster
-        s2 = s1.softmax(dim=-1, dtype=q.dtype)
-        del s1
-        r1 = einsum('b i j, b j d -> b i d', s2, v)
-        del s2
-        return r1
-
-    def einsum_op_mps_v1(self, q, k, v, r1):
+
+        self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+
+    def einsum_op_compvis(self, q, k, v):
+        s = einsum('b i d, b j d -> b i j', q, k)
+        s = s.softmax(dim=-1, dtype=s.dtype)
+        return einsum('b i j, b j d -> b i d', s, v)
+
+    def einsum_op_slice_0(self, q, k, v, slice_size):
+        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        for i in range(0, q.shape[0], slice_size):
+            end = i + slice_size
+            r[i:end] = self.einsum_op_compvis(q[i:end], k[i:end], v[i:end])
+        return r
+
+    def einsum_op_slice_1(self, q, k, v, slice_size):
+        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        for i in range(0, q.shape[1], slice_size):
+            end = i + slice_size
+            r[:, i:end] = self.einsum_op_compvis(q[:, i:end], k, v)
+        return r
+
+    def einsum_op_mps_v1(self, q, k, v):
         if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
-            r1 = self.einsum_op_compvis(q, k, v, r1)
+            return self.einsum_op_compvis(q, k, v)
         else:
             slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
-            for i in range(0, q.shape[1], slice_size):
-                end = i + slice_size
-                s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
-                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-                del s1
-                r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-                del s2
-        return r1
-
-    def einsum_op_mps_v2(self, q, k, v, r1):
-        if self.mem_total >= 8 and q.shape[1] <= 4096:
-            r1 = self.einsum_op_compvis(q, k, v, r1)
+            return self.einsum_op_slice_1(q, k, v, slice_size)
+
+    def einsum_op_mps_v2(self, q, k, v):
+        if self.mem_total_gb > 8 and q.shape[1] <= 4096:
+            return self.einsum_op_compvis(q, k, v)
         else:
-            slice_size = 1
-            for i in range(0, q.shape[0], slice_size):
-                end = min(q.shape[0], i + slice_size)
-                s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
-                s1 *= self.scale
-                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-                del s1
-                r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
-                del s2
-        return r1
-
-    def einsum_op_cuda(self, q, k, v, r1):
+            return self.einsum_op_slice_0(q, k, v, 1)
+
+    def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
+        size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
+        if size_mb <= max_tensor_mb:
+            return self.einsum_op_compvis(q, k, v)
+        div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
+        if div <= q.shape[0]:
+            return self.einsum_op_slice_0(q, k, v, q.shape[0] // div)
+        return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
+
+    def einsum_op_cuda(self, q, k, v):
         stats = torch.cuda.memory_stats(q.device)
         mem_active = stats['active_bytes.all.current']
         mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
         mem_free_torch = mem_reserved - mem_active
         mem_free_total = mem_free_cuda + mem_free_torch
+        # Divide factor of safety as there's copying and fragmentation
+        return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))

-        gb = 1024 ** 3
-        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
-        mem_required = tensor_size * 2.5
-        steps = 1
+    def einsum_op(self, q, k, v):
+        if q.device.type == 'cuda':
+            return self.einsum_op_cuda(q, k, v)

-        if mem_required > mem_free_total:
-            steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
+        if q.device.type == 'mps':
+            if self.mem_total_gb >= 32:
+                return self.einsum_op_mps_v1(q, k, v)
+            return self.einsum_op_mps_v2(q, k, v)

-        if steps > 64:
-            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
-            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
-                               f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
-
-        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
-        for i in range(0, q.shape[1], slice_size):
-            end = min(q.shape[1], i + slice_size)
-            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
-            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-            del s1
-            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-            del s2
-        return r1
+        # Smaller slices are faster due to L2/L3/SLC caches.
+        # Tested on i7 with 8MB L3 cache.
+        return self.einsum_op_tensor_mem(q, k, v, 32)

     def forward(self, x, context=None, mask=None):
         h = self.heads

-        q_in = self.to_q(x)
+        q = self.to_q(x)
         context = default(context, x)
-        k_in = self.to_k(context)
-        v_in = self.to_v(context)
-        device_type = 'mps' if x.device.type == 'mps' else 'cuda'
+        k = self.to_k(context) * self.scale
+        v = self.to_v(context)
         del context, x

-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
-        del q_in, k_in, v_in
-        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-        r1 = self.einsum_op(q, k, v, r1)
-        del q, k, v
-
-        r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
-        del r1
-
-        return self.to_out(r2)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+        r = self.einsum_op(q, k, v)
+        return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))


 class BasicTransformerBlock(nn.Module):
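For reference, a worked example (with assumed shapes, not from the commit) of how the new einsum_op_tensor_mem above picks a slice size: it estimates the attention-score tensor in MB, rounds the overshoot over the budget up to a power of two, and slices along dim 0 if that is enough, otherwise along dim 1.

# q and k assumed to be (16, 4096, 40) float32 self-attention tensors (batch 2 x 8 heads, 512x512).
size_mb = 16 * 4096 * 4096 * 4 // (1 << 20)                   # score tensor would be 1024 MB
max_tensor_mb = 32                                            # the CPU path's budget from einsum_op
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()    # 1 << 5 = 32
# div (32) > q.shape[0] (16), so slicing falls through to dim 1:
slice_size = max(4096 // div, 1)                              # 128 query rows per einsum_op_compvis call
print(div, slice_size)                                        # 32 128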
