@@ -501,12 +501,22 @@ def _set_sampler(self):
 
     def _load_model_from_config(self, config, ckpt):
         print(f'>> Loading model from {ckpt}')
+
+        # for usage statistics
+        device_type = choose_torch_device()
+        if device_type == 'cuda':
+            torch.cuda.reset_peak_memory_stats()
+        tic = time.time()
+
+        # this does the work
         pl_sd = torch.load(ckpt, map_location='cpu')
         sd = pl_sd['state_dict']
         model = instantiate_from_config(config.model)
         m, u = model.load_state_dict(sd, strict=False)
         model.to(self.device)
         model.eval()
+
+
         if self.full_precision:
             print(
                 '>> Using slower but more accurate full-precision math (--full_precision)'
@@ -516,6 +526,20 @@ def _load_model_from_config(self, config, ckpt):
                 '>> Using half precision math. Call with --full_precision to use more accurate but VRAM-intensive full precision.'
             )
             model.half()
+
+        # usage statistics
+        toc = time.time()
+        print(
+            '>> Model loaded in', '%4.2fs' % (toc - tic)
+        )
+        if device_type == 'cuda':
+            print(
+                '>> Max VRAM used to load the model:',
+                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
+                '\n>> Current VRAM usage:',
+                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
+            )
+
         return model
 
     def _load_img(self, path, width, height, fit=False):
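For reference, a minimal standalone sketch of the timing and peak-VRAM pattern this commit adds. It is an illustration, not the project's code: `build_model` is a hypothetical stand-in for the `torch.load` / `instantiate_from_config` work above, and the repo's `choose_torch_device()` helper is replaced by a plain `torch.cuda.is_available()` check.

```python
import time

import torch


def load_with_stats(build_model):
    # Clear the high-water mark so max_memory_allocated() reflects
    # only allocations made during this load.
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    tic = time.time()

    # Stand-in for the real work: torch.load + instantiate_from_config.
    model = build_model()

    print('>> Model loaded in', '%4.2fs' % (time.time() - tic))
    if torch.cuda.is_available():
        print(
            '>> Max VRAM used to load the model:',
            '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
        )
        print(
            '>> Current VRAM usage:',
            '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
        )
    return model


if __name__ == '__main__':
    # Usage: time a toy module instead of a Stable Diffusion checkpoint.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    load_with_stats(lambda: torch.nn.Linear(4096, 4096).to(device))
```

The reset call matters: `max_memory_allocated()` reports the high-water mark since the last reset, so clearing it first scopes the statistic to this one load. Dividing by 1e9 yields decimal gigabytes, matching the `%4.2fG` output format in the commit.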