78 changes: 31 additions & 47 deletions ldm/modules/attention.py
@@ -167,30 +167,25 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
             nn.Linear(inner_dim, query_dim),
             nn.Dropout(dropout)
         )

-        if not torch.cuda.is_available():
-            mem_av = psutil.virtual_memory().available / (1024**3)
-            if mem_av > 32:
-                self.einsum_op = self.einsum_op_v1
-            elif mem_av > 12:
-                self.einsum_op = self.einsum_op_v2
-            else:
-                self.einsum_op = self.einsum_op_v3
-            del mem_av
+        if torch.cuda.is_available():
+            self.einsum_op = self.einsum_op_cuda
         else:
-            self.einsum_op = self.einsum_op_v4
+            self.mem_total = psutil.virtual_memory().total / (1024**3)
+            self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
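For illustration (not part of the diff): a minimal standalone sketch of the dispatch rule the new __init__ uses, assuming psutil and torch are installed. The thresholds mirror the added lines above; the printed name is only for demonstration.

import psutil
import torch

# Pick the attention op once, the way the new __init__ does.
mem_total = psutil.virtual_memory().total / (1024**3)  # total RAM in GiB
if torch.cuda.is_available():
    op = 'einsum_op_cuda'
elif mem_total >= 32:
    op = 'einsum_op_mps_v1'  # high-memory machines: slice only when needed
else:
    op = 'einsum_op_mps_v2'  # under 32 GiB: slice more aggressively
print(f'{mem_total:.1f} GiB RAM -> {op}')

Note the behavioral change: the old code read available memory at import time, while the new code keys off total memory, so the choice no longer depends on what else happens to be running.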

+    def einsum_op_compvis(self, q, k, v, r1):
+        s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster
+        s2 = s1.softmax(dim=-1, dtype=q.dtype)
+        del s1
+        r1 = einsum('b i j, b j d -> b i d', s2, v)
+        del s2
+        return r1
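For illustration (not part of the diff): the shapes flowing through einsum_op_compvis, with made-up sizes. b is batch times heads, i/j are query/key token counts, d is the head dimension; at 512x512 the token count reaches 4096, so the intermediate (b, i, j) score tensor is what dominates memory.

import torch
from torch import einsum

b, i, j, d = 16, 1024, 1024, 40          # made-up sizes; 512x512 gives i = j = 4096
scale = d ** -0.5
q, k, v = torch.randn(b, i, d), torch.randn(b, j, d), torch.randn(b, j, d)

s1 = einsum('b i d, b j d -> b i j', q, k) * scale  # (16, 1024, 1024) attention scores
s2 = s1.softmax(dim=-1, dtype=q.dtype)              # attention weights, rows sum to 1
r1 = einsum('b i j, b j d -> b i d', s2, v)         # weighted sum of values
assert r1.shape == (b, i, d)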

-    # mps 64-128 GB
-    def einsum_op_v1(self, q, k, v, r1):
-        if q.shape[1] <= 4096: # for 512x512: the max q.shape[1] is 4096
-            s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # aggressive/faster: operation in one go
-            s2 = s1.softmax(dim=-1, dtype=q.dtype)
-            del s1
-            r1 = einsum('b i j, b j d -> b i d', s2, v)
-            del s2
+    def einsum_op_mps_v1(self, q, k, v, r1):
+        if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
+            r1 = self.einsum_op_compvis(q, k, v, r1)
         else:
-            # q.shape[0] * q.shape[1] * slice_size >= 2**31 throws err
-            # needs around half of that slice_size to not generate noise
             slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
             for i in range(0, q.shape[1], slice_size):
                 end = i + slice_size
@@ -201,33 +196,22 @@ def einsum_op_v1(self, q, k, v, r1):
                 del s2
         return r1
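For illustration (not part of the diff): the slice-size arithmetic behind the loop above, worked for a hypothetical 768x768 generation (96x96 latent, so q.shape[1] = 9216, with 16 batched heads). The 2**30 numerator is half of the 2**31 MPS limit the removed comments refer to, which is the margin they say is needed to avoid noisy output.

import math

b, n = 16, 9216                            # q.shape[0], q.shape[1] for 768x768
slice_size = math.floor(2**30 / (b * n))   # = 7281
assert b * n * slice_size < 2**31          # stays under the limit that throws
print(list(range(0, n, slice_size)))       # [0, 7281] -> two slices cover all 9216 rows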

-    # mps 16-32 GB (can be optimized)
-    def einsum_op_v2(self, q, k, v, r1):
-        slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
-        for i in range(0, q.shape[1], slice_size): # conservative/less mem: operation in steps
-            end = i + slice_size
-            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
-            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-            del s1
-            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-            del s2
-        return r1
-
-    # mps 8 GB
-    def einsum_op_v3(self, q, k, v, r1):
-        slice_size = 1
-        for i in range(0, q.shape[0], slice_size): # iterate over q.shape[0]
-            end = min(q.shape[0], i + slice_size)
-            s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) # adapted einsum for mem
-            s1 *= self.scale
-            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-            del s1
-            r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) # adapted einsum for mem
-            del s2
+    def einsum_op_mps_v2(self, q, k, v, r1):
+        if self.mem_total >= 8 and q.shape[1] <= 4096:
+            r1 = self.einsum_op_compvis(q, k, v, r1)
+        else:
+            slice_size = 1
+            for i in range(0, q.shape[0], slice_size):
+                end = min(q.shape[0], i + slice_size)
+                s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
+                s1 *= self.scale
+                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+                del s1
+                r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
+                del s2
         return r1
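For illustration (not part of the diff): why slicing over the batch/heads dimension, as einsum_op_mps_v2's fallback does, is the low-memory path. With made-up sizes, the peak float32 score tensor shrinks from (b, i, j) to (1, i, j), at the cost of b sequential steps.

b, i, j = 16, 4096, 4096                   # made-up: 16 batched heads at 512x512
full_gib = b * i * j * 4 / 2**30           # all heads at once: 1.0 GiB of scores
per_step_mib = i * j * 4 / 2**20           # one head per step: 64 MiB
print(f'{full_gib:.2f} GiB at once vs {per_step_mib:.0f} MiB x {b} steps')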

-    # cuda
-    def einsum_op_v4(self, q, k, v, r1):
+    def einsum_op_cuda(self, q, k, v, r1):
         stats = torch.cuda.memory_stats(q.device)
         mem_active = stats['active_bytes.all.current']
         mem_reserved = stats['reserved_bytes.all.current']
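For illustration (not part of the diff, which is truncated at this point): a standalone probe of the same CUDA memory counters, assuming a CUDA build of PyTorch. torch.cuda.memory_stats and torch.cuda.mem_get_info are documented PyTorch calls; how einsum_op_cuda turns these numbers into a slice size is not shown above.

import torch

if torch.cuda.is_available():
    stats = torch.cuda.memory_stats(torch.device('cuda'))
    mem_active = stats['active_bytes.all.current']      # bytes held by live tensors
    mem_reserved = stats['reserved_bytes.all.current']  # bytes held by the caching allocator
    mem_free, mem_total = torch.cuda.mem_get_info()     # device-level free/total bytes
    print(mem_active, mem_reserved, mem_free, mem_total)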