Commit 5c43988

reduce VRAM memory usage by half during model loading
* This moves the call to half() before model.to(device) to avoid a GPU copy of the full-precision model. Improves speed and reduces memory usage dramatically.
* This fix was contributed by @mh-dm (Mihai)
1 parent 9912270 commit 5c43988

File tree

1 file changed: +2 −3 lines changed


ldm/generate.py (2 additions, 3 deletions)

@@ -536,9 +536,6 @@ def _load_model_from_config(self, config, ckpt):
         sd = pl_sd['state_dict']
         model = instantiate_from_config(config.model)
         m, u = model.load_state_dict(sd, strict=False)
-        model.to(self.device)
-        model.eval()
-
 
         if self.full_precision:
             print(
@@ -549,6 +546,8 @@ def _load_model_from_config(self, config, ckpt):
                 '>> Using half precision math. Call with --full_precision to use more accurate but VRAM-intensive full precision.'
             )
             model.half()
+        model.to(self.device)
+        model.eval()
 
         # usage statistics
         toc = time.time()
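The reason the reordering helps can be sketched with a minimal PyTorch example (a small Linear layer stands in for the full Stable Diffusion model, which is an assumption for illustration): `.to(device)` copies tensors in whatever dtype they currently are, so calling `half()` first means only fp16 weights ever have to be transferred or resident on the GPU, rather than fp32 weights that are downcast afterwards.

```python
import torch

# Stand-in for the model loaded from the checkpoint (hypothetical size).
model = torch.nn.Linear(1024, 1024)

fp32_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

model.half()            # cast weights to fp16 while still on the CPU...
# model.to(device)      # ...so a subsequent device copy moves only fp16 data
model.eval()

fp16_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(fp32_bytes // fp16_bytes)  # fp16 weights occupy half the bytes
```

In the old order, `model.to(self.device)` ran before `model.half()`, so the full fp32 weights were copied to the GPU and only then halved, with peak VRAM roughly twice the final footprint.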
