Skip to content

Commit dd2aeda

Browse files
authored
report VRAM usage stats during initial model loading (invoke-ai#419)
1 parent f628477 commit dd2aeda

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

ldm/generate.py

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -501,12 +501,22 @@ def _set_sampler(self):
501501

502502
def _load_model_from_config(self, config, ckpt):
503503
print(f'>> Loading model from {ckpt}')
504+
505+
# for usage statistics
506+
device_type = choose_torch_device()
507+
if device_type == 'cuda':
508+
torch.cuda.reset_peak_memory_stats()
509+
tic = time.time()
510+
511+
# this does the work
504512
pl_sd = torch.load(ckpt, map_location='cpu')
505513
sd = pl_sd['state_dict']
506514
model = instantiate_from_config(config.model)
507515
m, u = model.load_state_dict(sd, strict=False)
508516
model.to(self.device)
509517
model.eval()
518+
519+
510520
if self.full_precision:
511521
print(
512522
'>> Using slower but more accurate full-precision math (--full_precision)'
@@ -516,6 +526,20 @@ def _load_model_from_config(self, config, ckpt):
516526
'>> Using half precision math. Call with --full_precision to use more accurate but VRAM-intensive full precision.'
517527
)
518528
model.half()
529+
530+
# usage statistics
531+
toc = time.time()
532+
print(
533+
f'>> Model loaded in', '%4.2fs' % (toc - tic)
534+
)
535+
if device_type == 'cuda':
536+
print(
537+
'>> Max VRAM used to load the model:',
538+
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
539+
'\n>> Current VRAM usage:'
540+
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
541+
)
542+
519543
return model
520544

521545
def _load_img(self, path, width, height, fit=False):

scripts/dream.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -91,11 +91,7 @@ def main():
9191
print(">> changed to seamless tiling mode")
9292

9393
# preload the model
94-
tic = time.time()
9594
t2i.load_model()
96-
print(
97-
f'>> model loaded in', '%4.2fs' % (time.time() - tic)
98-
)
9995

10096
if not infile:
10197
print(

0 commit comments

Comments (0)