Skip to content

Commit 70aa674

Browse files
committed
merge PR CompVis#495 - keep using float16 in ldm.modules.attention
1 parent 8748370 commit 70aa674

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

ldm/modules/attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def forward(self, x, context=None, mask=None):
181181
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
182182
del q_in, k_in, v_in
183183

184-
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
184+
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
185185

186186
if device_type == 'mps':
187187
mem_free_total = psutil.virtual_memory().available
@@ -213,7 +213,7 @@ def forward(self, x, context=None, mask=None):
213213
end = i + slice_size
214214
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
215215

216-
s2 = s1.softmax(dim=-1)
216+
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
217217
del s1
218218

219219
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)

scripts/dream.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,6 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
185185
continue
186186
if opt.seed is not None and opt.seed < 0: # retrieve previous value!
187187
try:
188-
print(f'last seeds = {last_seeds}, opt.seed={opt.seed}')
189188
opt.seed = last_seeds[opt.seed]
190189
print(f'reusing previous seed {opt.seed}')
191190
except IndexError:

0 commit comments

Comments
 (0)