-import gc
 from inspect import isfunction
 import math
 import torch
 
 from ldm.modules.diffusionmodules.util import checkpoint
 
+import psutil
+
 
 def exists(val):
     return val is not None
@@ -151,14 +152,13 @@ def forward(self, x):
 
 
 class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., att_step=1):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
 
         self.scale = dim_head ** -0.5
         self.heads = heads
-        self.att_step = att_step
 
         self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
@@ -169,23 +169,50 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, context=None, mask=None):
-        h = self.heads
-
-        q_in = self.to_q(x)
-        context = default(context, x)
-
-        k_in = self.to_k(context)
-        v_in = self.to_v(context)
-
-        del context, x
-
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
-        del q_in, k_in, v_in
-
-
-        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
-
+        if torch.cuda.is_available():
+            self.einsum_op = self.einsum_op_cuda
+        else:
+            self.mem_total = psutil.virtual_memory().total / (1024 ** 3)
+            self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
+
+    def einsum_op_compvis(self, q, k, v, r1):
+        s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale  # faster
+        s2 = s1.softmax(dim=-1, dtype=q.dtype)
+        del s1
+        r1 = einsum('b i j, b j d -> b i d', s2, v)
+        del s2
+        return r1
+
+    def einsum_op_mps_v1(self, q, k, v, r1):
+        if q.shape[1] <= 4096:  # (512x512) max q.shape[1]: 4096
+            r1 = self.einsum_op_compvis(q, k, v, r1)
+        else:
+            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+            for i in range(0, q.shape[1], slice_size):
+                end = i + slice_size
+                s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
+                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+                del s1
+                r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+                del s2
+        return r1
+
+    def einsum_op_mps_v2(self, q, k, v, r1):
+        if self.mem_total >= 8 and q.shape[1] <= 4096:
+            r1 = self.einsum_op_compvis(q, k, v, r1)
+        else:
+            slice_size = 1
+            for i in range(0, q.shape[0], slice_size):
+                end = min(q.shape[0], i + slice_size)
+                s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
+                s1 *= self.scale
+                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+                del s1
+                r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
+                del s2
+        return r1
+
+    def einsum_op_cuda(self, q, k, v, r1):
         stats = torch.cuda.memory_stats(q.device)
         mem_active = stats['active_bytes.all.current']
         mem_reserved = stats['reserved_bytes.all.current']
@@ -200,30 +227,39 @@ def forward(self, x, context=None, mask=None):
 
         if mem_required > mem_free_total:
             steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
-            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
-            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
 
         if steps > 64:
             max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
             raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
-                               f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
+                               f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
 
-        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
         for i in range(0, q.shape[1], slice_size):
-            end = i + slice_size
+            end = min(q.shape[1], i + slice_size)
             s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
-
-            s2 = s1.softmax(dim=-1)
+            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
             del s1
-
             r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-            del s2
+            del s2
+        return r1
 
-        del q, k, v
+    def forward(self, x, context=None, mask=None):
+        h = self.heads
 
+        q = self.to_q(x)
+        context = default(context, x)
+        del x
+        k = self.to_k(context)
+        v = self.to_v(context)
+        del context
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        r1 = self.einsum_op(q, k, v, r1)
+        del q, k, v
         r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
         del r1
-
         return self.to_out(r2)
 
 
@@ -243,9 +279,10 @@ def forward(self, x, context=None):
         return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
 
     def _forward(self, x, context=None):
-        x = self.attn1(self.norm1(x)) + x
-        x = self.attn2(self.norm2(x), context=context) + x
-        x = self.ff(self.norm3(x)) + x
+        x = x.contiguous() if x.device.type == 'mps' else x
+        x += self.attn1(self.norm1(x))
+        x += self.attn2(self.norm2(x), context=context)
+        x += self.ff(self.norm3(x))
         return x
 
 
@@ -292,4 +329,4 @@ def forward(self, x, context=None):
         x = block(x, context=context)
         x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
         x = self.proj_out(x)
-        return x + x_in
+        return x + x_in
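
For reference, below is a minimal standalone sketch (not part of the diff above; toy tensor shapes and slice size chosen arbitrarily) of the query-row slicing used by the einsum_op_mps_v1 and einsum_op_cuda fallbacks: each slice of softmax(QK^T)V is computed independently and written into the preallocated r1, so the full n x n attention matrix never has to exist at once, and the result matches the unsliced CompVis-style path.

import torch
from torch import einsum

b_h, n, d = 2, 1024, 64            # (batch * heads), tokens, head dim -- toy sizes
scale = d ** -0.5
q = torch.randn(b_h, n, d)
k = torch.randn(b_h, n, d)
v = torch.randn(b_h, n, d)

# Unsliced path: materialises the full (b h) x n x n attention matrix at once.
ref = einsum('b i j, b j d -> b i d',
             (einsum('b i d, b j d -> b i j', q, k) * scale).softmax(dim=-1), v)

# Sliced path: only a (b h) x slice_size x n block of the attention matrix is live at a time.
r1 = torch.zeros_like(ref)
slice_size = 256
for i in range(0, n, slice_size):
    end = min(n, i + slice_size)
    s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
    s2 = s1.softmax(dim=-1)
    del s1
    r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
    del s2

print(torch.allclose(ref, r1, atol=1e-5))  # expected: True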