@@ -90,7 +90,7 @@ def forward(self, x):
         b, c, h, w = x.shape
         qkv = self.to_qkv(x)
         q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
-        k = k.softmax(dim=-1)
+        k = k.softmax(dim=-1)
         context = torch.einsum('bhdn,bhen->bhde', k, v)
         out = torch.einsum('bhde,bhdn->bhen', context, q)
         out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
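Reviewer note (not part of the diff): the hunk above is LinearAttention.forward, where softmax is taken over the key positions so a small c-by-c context matrix can be formed once and applied to every query. A minimal, self-contained sketch of that computation with made-up shapes (4 heads of dim 32 on a 16x16 feature map, all illustrative):

    import torch
    from einops import rearrange

    b, heads, c, h, w = 1, 4, 32, 16, 16          # illustrative shapes only
    qkv = torch.randn(b, 3 * heads * c, h, w)     # stand-in for self.to_qkv(x)
    q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=heads, qkv=3)
    k = k.softmax(dim=-1)                             # normalize keys over the h*w positions
    context = torch.einsum('bhdn,bhen->bhde', k, v)   # c x c summary, independent of h*w
    out = torch.einsum('bhde,bhdn->bhen', context, q) # apply the summary to every query
    print(out.shape)                                  # torch.Size([1, 4, 32, 256])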
@@ -167,101 +167,85 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
             nn.Linear(inner_dim, query_dim),
             nn.Dropout(dropout)
         )
-
-        if torch.cuda.is_available():
-            self.einsum_op = self.einsum_op_cuda
-        else:
-            self.mem_total = psutil.virtual_memory().total / (1024 ** 3)
-            self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
-
-    def einsum_op_compvis(self, q, k, v, r1):
-        s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale  # faster
-        s2 = s1.softmax(dim=-1, dtype=q.dtype)
-        del s1
-        r1 = einsum('b i j, b j d -> b i d', s2, v)
-        del s2
-        return r1
-
-    def einsum_op_mps_v1(self, q, k, v, r1):
+
+        self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+
+    def einsum_op_compvis(self, q, k, v):
+        s = einsum('b i d, b j d -> b i j', q, k)
+        s = s.softmax(dim=-1, dtype=s.dtype)
+        return einsum('b i j, b j d -> b i d', s, v)
+
+    # Slice along dim 0 (the batch*heads dimension).
+    def einsum_op_slice_0(self, q, k, v, slice_size):
+        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        for i in range(0, q.shape[0], slice_size):
+            end = i + slice_size
+            r[i:end] = self.einsum_op_compvis(q[i:end], k[i:end], v[i:end])
+        return r
+
+    # Slice along dim 1 (the query-token dimension); k and v are passed whole.
+    def einsum_op_slice_1(self, q, k, v, slice_size):
+        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        for i in range(0, q.shape[1], slice_size):
+            end = i + slice_size
+            r[:, i:end] = self.einsum_op_compvis(q[:, i:end], k, v)
+        return r
+
+    def einsum_op_mps_v1(self, q, k, v):
         if q.shape[1] <= 4096:  # (512x512) max q.shape[1]: 4096
-            r1 = self.einsum_op_compvis(q, k, v, r1)
+            return self.einsum_op_compvis(q, k, v)
         else:
             slice_size = math.floor(2 ** 30 / (q.shape[0] * q.shape[1]))
-            for i in range(0, q.shape[1], slice_size):
-                end = i + slice_size
-                s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
-                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-                del s1
-                r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-                del s2
-        return r1
-
-    def einsum_op_mps_v2(self, q, k, v, r1):
-        if self.mem_total >= 8 and q.shape[1] <= 4096:
-            r1 = self.einsum_op_compvis(q, k, v, r1)
+            return self.einsum_op_slice_1(q, k, v, slice_size)
+
+    def einsum_op_mps_v2(self, q, k, v):
+        if self.mem_total_gb > 8 and q.shape[1] <= 4096:
+            return self.einsum_op_compvis(q, k, v)
         else:
-            slice_size = 1
-            for i in range(0, q.shape[0], slice_size):
-                end = min(q.shape[0], i + slice_size)
-                s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
-                s1 *= self.scale
-                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-                del s1
-                r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
-                del s2
-        return r1
-
-    def einsum_op_cuda(self, q, k, v, r1):
+            return self.einsum_op_slice_0(q, k, v, 1)
+
+    def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
+        # Size of the full attention matrix (q.shape[0] x q.shape[1] x k.shape[1]) in MB.
+        size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
+        if size_mb <= max_tensor_mb:
+            return self.einsum_op_compvis(q, k, v)
+        # Power-of-two divisor chosen so each slice fits within max_tensor_mb.
+        div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
+        if div <= q.shape[0]:
+            return self.einsum_op_slice_0(q, k, v, q.shape[0] // div)
+        return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
+
+    def einsum_op_cuda(self, q, k, v):
         stats = torch.cuda.memory_stats(q.device)
         mem_active = stats['active_bytes.all.current']
         mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
         mem_free_torch = mem_reserved - mem_active
         mem_free_total = mem_free_cuda + mem_free_torch
+        # Divide by a safety factor, since there is copying and fragmentation overhead.
+        return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))

-        gb = 1024 ** 3
-        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
-        mem_required = tensor_size * 2.5
-        steps = 1
+    def einsum_op(self, q, k, v):
+        if q.device.type == 'cuda':
+            return self.einsum_op_cuda(q, k, v)

-        if mem_required > mem_free_total:
-            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
+        if q.device.type == 'mps':
+            if self.mem_total_gb >= 32:
+                return self.einsum_op_mps_v1(q, k, v)
+            return self.einsum_op_mps_v2(q, k, v)

-        if steps > 64:
-            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
-            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
-                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have: {mem_free_total / gb:0.1f}GB free')
-
-        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
-        for i in range(0, q.shape[1], slice_size):
-            end = min(q.shape[1], i + slice_size)
-            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
-            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-            del s1
-            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-            del s2
-        return r1
+        # Smaller slices are faster due to L2/L3/SLC caches.
+        # Tested on an i7 with 8MB of L3 cache.
+        return self.einsum_op_tensor_mem(q, k, v, 32)

     def forward(self, x, context=None, mask=None):
         h = self.heads

-        q_in = self.to_q(x)
+        q = self.to_q(x)
         context = default(context, x)
-        k_in = self.to_k(context)
-        v_in = self.to_v(context)
-        device_type = 'mps' if x.device.type == 'mps' else 'cuda'
+        k = self.to_k(context) * self.scale  # fold the 1/sqrt(dim_head) scale into k once
+        v = self.to_v(context)
         del context, x

-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
-        del q_in, k_in, v_in
-        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-        r1 = self.einsum_op(q, k, v, r1)
-        del q, k, v
-
-        r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
-        del r1
-
-        return self.to_out(r2)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+        r = self.einsum_op(q, k, v)
+        return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))


 class BasicTransformerBlock(nn.Module):
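Reviewer note (not part of the diff): both the CUDA and CPU paths now funnel into einsum_op_tensor_mem, which caps the size of the b-by-i-by-j attention matrix and picks a power-of-two divisor for slicing. A quick standalone check of that arithmetic, with all values made up for illustration (fp16, so element_size() == 2; 16 batch*heads; 4096 tokens):

    bh, tokens, ctx_len = 16, 4096, 4096
    size_mb = bh * tokens * ctx_len * 2 // (1 << 20)   # full attention matrix: 512 MB
    max_tensor_mb = 32                                 # the CPU budget used in einsum_op
    div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
    print(size_mb, div)   # 512 16 -> einsum_op_slice_0 with slice_size = 16 // 16 = 1
    # Each slice is then 1 * 4096 * 4096 * 2 / 2**20 = 32 MB, within budget.

On CUDA the budget is instead mem_free_total / 3.3 / (1 << 20), i.e. roughly a third of the currently free device memory, expressed in MB.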