
Commit 9cdf3ac

Any-Winter-4079 authored and lstein committed

Update attention.py

Performance improvements to generate larger images on M1 (CompVis#431). Added dtype=r1.dtype to softmax.
1 parent 49a96b9 commit 9cdf3ac
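
A minimal sketch of the dtype change named in the commit message. The shapes and head dimension below are invented for illustration; only the softmax call mirrors the diff. torch.Tensor.softmax accepts a dtype argument, so the attention weights can be produced directly in the result buffer's dtype before being written into it.

import torch
from torch import einsum

# Hypothetical shapes: (batch*heads, tokens, head_dim)
q = torch.randn(2, 1024, 40)
k = torch.randn(2, 1024, 40)
v = torch.randn(2, 1024, 40)
scale = 40 ** -0.5

# Preallocated result buffer, as in the patched forward()
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], dtype=q.dtype)

s1 = einsum('b i d, b j d -> b i j', q, k) * scale
s2 = s1.softmax(dim=-1, dtype=r1.dtype)  # attention weights cast to r1's dtype, per the commit
r1[:] = einsum('b i j, b j d -> b i d', s2, v)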

File tree

1 file changed: +84 −31 lines changed

ldm/modules/attention.py

Lines changed: 84 additions & 31 deletions
@@ -168,30 +168,72 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, context=None, mask=None):
-        h = self.heads
-
-        q_in = self.to_q(x)
-        context = default(context, x)
-        k_in = self.to_k(context)
-        v_in = self.to_v(context)
-        device_type = 'mps' if x.device.type == 'mps' else 'cuda'
-        del context, x
-
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
-        del q_in, k_in, v_in
-
-        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        if not torch.cuda.is_available():
+            mem_av = psutil.virtual_memory().available / (1024**3)
+            if mem_av > 32:
+                self.einsum_op = self.einsum_op_v1
+            elif mem_av > 12:
+                self.einsum_op = self.einsum_op_v2
+            else:
+                self.einsum_op = self.einsum_op_v3
+            del mem_av
+        else:
+            self.einsum_op = self.einsum_op_v4
 
-        if device_type == 'mps':
-            mem_free_total = psutil.virtual_memory().available
+    # mps 64-128 GB
+    def einsum_op_v1(self, q, k, v, r1):
+        if q.shape[1] <= 4096: # for 512x512: the max q.shape[1] is 4096
+            s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # aggressive/faster: operation in one go
+            s2 = s1.softmax(dim=-1, dtype=q.dtype)
+            del s1
+            r1 = einsum('b i j, b j d -> b i d', s2, v)
+            del s2
         else:
-            stats = torch.cuda.memory_stats(q.device)
-            mem_active = stats['active_bytes.all.current']
-            mem_reserved = stats['reserved_bytes.all.current']
-            mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-            mem_free_torch = mem_reserved - mem_active
-            mem_free_total = mem_free_cuda + mem_free_torch
+            # q.shape[0] * q.shape[1] * slice_size >= 2**31 throws err
+            # needs around half of that slice_size to not generate noise
+            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+            for i in range(0, q.shape[1], slice_size):
+                end = i + slice_size
+                s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
+                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+                del s1
+                r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+                del s2
+        return r1
+
+    # mps 16-32 GB (can be optimized)
+    def einsum_op_v2(self, q, k, v, r1):
+        slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+        for i in range(0, q.shape[1], slice_size): # conservative/less mem: operation in steps
+            end = i + slice_size
+            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
+            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+            del s1
+            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+            del s2
+        return r1
+
+    # mps 8 GB
+    def einsum_op_v3(self, q, k, v, r1):
+        slice_size = 1
+        for i in range(0, q.shape[0], slice_size): # iterate over q.shape[0]
+            end = min(q.shape[0], i + slice_size)
+            s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) # adapted einsum for mem
+            s1 *= self.scale
+            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+            del s1
+            r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) # adapted einsum for mem
+            del s2
+        return r1
+
+    # cuda
+    def einsum_op_v4(self, q, k, v, r1):
+        stats = torch.cuda.memory_stats(q.device)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
 
         gb = 1024 ** 3
         tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
@@ -200,25 +242,36 @@ def forward(self, x, context=None, mask=None):
 
         if mem_required > mem_free_total:
             steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
-            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
-            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
 
         if steps > 64:
             max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
             raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
-                              f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
-
-        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+                               f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
+
+        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
         for i in range(0, q.shape[1], slice_size):
-            end = i + slice_size
+            end = min(q.shape[1], i + slice_size)
             s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
-
             s2 = s1.softmax(dim=-1, dtype=r1.dtype)
             del s1
-
             r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-            del s2
+            del s2
+        return r1
+
+    def forward(self, x, context=None, mask=None):
+        h = self.heads
 
+        q_in = self.to_q(x)
+        context = default(context, x)
+        k_in = self.to_k(context)
+        v_in = self.to_v(context)
+        device_type = 'mps' if x.device.type == 'mps' else 'cuda'
+        del context, x
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+        del q_in, k_in, v_in
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        r1 = self.einsum_op(q, k, v, r1)
         del q, k, v
 
         r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
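
A side note on the slicing heuristic in einsum_op_v1 and einsum_op_v2 above: per the in-diff comment, the MPS path errors once q.shape[0] * q.shape[1] * slice_size reaches 2**31 elements, so the slice is sized against 2**30. The sketch below works through that arithmetic; the 16-row batch*heads figure is an assumption for illustration, while the 4096-token figure for 512x512 comes from the comment in einsum_op_v1.

import math

def mps_slice_size(q_rows, q_tokens):
    # mirrors: slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
    return math.floor(2**30 / (q_rows * q_tokens))

print(mps_slice_size(16, 4096))    # 16384 -> the whole 4096-token sequence fits in one step
print(mps_slice_size(16, 16384))   # 4096  -> a 4x larger token count is processed in 4 slices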
