@@ -168,30 +168,72 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, context=None, mask=None):
-        h = self.heads
-
-        q_in = self.to_q(x)
-        context = default(context, x)
-        k_in = self.to_k(context)
-        v_in = self.to_v(context)
-        device_type = 'mps' if x.device.type == 'mps' else 'cuda'
-        del context, x
-
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
-        del q_in, k_in, v_in
-
-        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        if not torch.cuda.is_available():
+            mem_av = psutil.virtual_memory().available / (1024**3)
+            if mem_av > 32:
+                self.einsum_op = self.einsum_op_v1
+            elif mem_av > 12:
+                self.einsum_op = self.einsum_op_v2
+            else:
+                self.einsum_op = self.einsum_op_v3
+            del mem_av
+        else:
+            self.einsum_op = self.einsum_op_v4
 
-        if device_type == 'mps':
-            mem_free_total = psutil.virtual_memory().available
+    # mps 64-128 GB
+    def einsum_op_v1(self, q, k, v, r1):
+        if q.shape[1] <= 4096:  # for 512x512: the max q.shape[1] is 4096
+            s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale  # aggressive/faster: operation in one go
+            s2 = s1.softmax(dim=-1, dtype=q.dtype)
+            del s1
+            r1 = einsum('b i j, b j d -> b i d', s2, v)
+            del s2
         else:
-            stats = torch.cuda.memory_stats(q.device)
-            mem_active = stats['active_bytes.all.current']
-            mem_reserved = stats['reserved_bytes.all.current']
-            mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-            mem_free_torch = mem_reserved - mem_active
-            mem_free_total = mem_free_cuda + mem_free_torch
+            # q.shape[0] * q.shape[1] * slice_size >= 2**31 throws err
+            # needs around half of that slice_size to not generate noise
+            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+            for i in range(0, q.shape[1], slice_size):
+                end = i + slice_size
+                s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
+                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+                del s1
+                r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+                del s2
+        return r1
+
+    # mps 16-32 GB (can be optimized)
+    def einsum_op_v2(self, q, k, v, r1):
+        slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+        for i in range(0, q.shape[1], slice_size):  # conservative/less mem: operation in steps
+            end = i + slice_size
+            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
+            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+            del s1
+            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+            del s2
+        return r1
+
+    # mps 8 GB
+    def einsum_op_v3(self, q, k, v, r1):
+        slice_size = 1
+        for i in range(0, q.shape[0], slice_size):  # iterate over q.shape[0]
+            end = min(q.shape[0], i + slice_size)
+            s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])  # adapted einsum for mem
+            s1 *= self.scale
+            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+            del s1
+            r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])  # adapted einsum for mem
+            del s2
+        return r1
+
+    # cuda
+    def einsum_op_v4(self, q, k, v, r1):
+        stats = torch.cuda.memory_stats(q.device)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
 
         gb = 1024 ** 3
         tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
@@ -200,25 +242,36 @@ def forward(self, x, context=None, mask=None):
 
         if mem_required > mem_free_total:
             steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
-            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
-            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
 
         if steps > 64:
             max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
             raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
-                               f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
-
-        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+                               f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
+
+        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
         for i in range(0, q.shape[1], slice_size):
-            end = i + slice_size
+            end = min(q.shape[1], i + slice_size)
             s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
-
             s2 = s1.softmax(dim=-1, dtype=r1.dtype)
             del s1
-
             r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-            del s2
+            del s2
+        return r1
+
+    def forward(self, x, context=None, mask=None):
+        h = self.heads
 
+        q_in = self.to_q(x)
+        context = default(context, x)
+        k_in = self.to_k(context)
+        v_in = self.to_v(context)
+        device_type = 'mps' if x.device.type == 'mps' else 'cuda'
+        del context, x
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+        del q_in, k_in, v_in
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        r1 = self.einsum_op(q, k, v, r1)
         del q, k, v
 
         r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
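
Note on the slicing arithmetic used by einsum_op_v1/einsum_op_v2 above: the attention matrix built by einsum('b i d, b j d -> b i j', q, k) has q.shape[0] * q.shape[1] * k.shape[1] elements, and the in-code comments indicate MPS errors out once a single operation reaches roughly 2**31 elements, so query rows are processed in bands of slice_size = floor(2**30 / (q.shape[0] * q.shape[1])). The standalone sketch below restates that pattern outside the class; the function name, the max_elems default, and the example shapes are illustrative assumptions, not part of the commit.

import math
import torch
from torch import einsum

def sliced_attention(q, k, v, scale, max_elems=2**30):
    # q, k, v: (batch*heads, n, d). Only a (slice_size, n) band of the
    # attention matrix exists at any one time, mirroring einsum_op_v2 above.
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    slice_size = max(1, math.floor(max_elems / (q.shape[0] * q.shape[1])))
    for i in range(0, q.shape[1], slice_size):
        end = min(q.shape[1], i + slice_size)
        s = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
        r[:, i:end] = einsum('b i j, b j d -> b i d', s.softmax(dim=-1), v)
    return r

# A 512x512 image gives a 64x64 latent, so n = 4096; with 8 heads and batch 1,
# slice_size = floor(2**30 / (8 * 4096)) = 32768 >= 4096, i.e. a single pass.
out = sliced_attention(torch.randn(8, 4096, 40), torch.randn(8, 4096, 40),
                       torch.randn(8, 4096, 40), scale=40 ** -0.5)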