Skip to content

Commit 6a3ee7c

Browse files
authored
Hub models map_location=device (#3894)
* Hub models `map_location=device` * cleanup
1 parent 8930e22 commit 6a3ee7c

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

hubconf.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,15 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
3636

3737
fname = Path(name).with_suffix('.pt') # checkpoint filename
3838
try:
39+
device = select_device(('0' if torch.cuda.is_available() else 'cpu') if device is None else device)
40+
3941
if pretrained and channels == 3 and classes == 80:
40-
model = attempt_load(fname, map_location=torch.device('cpu')) # download/load FP32 model
42+
model = attempt_load(fname, map_location=device) # download/load FP32 model
4143
else:
4244
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
4345
model = Model(cfg, channels, classes) # create model
4446
if pretrained:
45-
ckpt = torch.load(attempt_download(fname), map_location=torch.device('cpu')) # load
47+
ckpt = torch.load(attempt_download(fname), map_location=device) # load
4648
msd = model.state_dict() # model state_dict
4749
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
4850
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
@@ -51,7 +53,6 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
5153
model.names = ckpt['model'].names # set class names attribute
5254
if autoshape:
5355
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
54-
device = select_device('0' if torch.cuda.is_available() else 'cpu') if device is None else torch.device(device)
5556
return model.to(device)
5657

5758
except Exception as e:

utils/torch_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import datetime
44
import logging
5-
import math
65
import os
76
import platform
87
import subprocess
@@ -11,6 +10,7 @@
1110
from copy import deepcopy
1211
from pathlib import Path
1312

13+
import math
1414
import torch
1515
import torch.backends.cudnn as cudnn
1616
import torch.distributed as dist
@@ -64,7 +64,8 @@ def git_describe(path=Path(__file__).parent): # path must be a directory
6464
def select_device(device='', batch_size=None):
6565
# device = 'cpu' or '0' or '0,1,2,3'
6666
s = f'YOLOv5 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string
67-
cpu = device.lower() == 'cpu'
67+
device = str(device).strip().lower().replace('cuda:', '') # to string, 'cuda:0' to '0'
68+
cpu = device == 'cpu'
6869
if cpu:
6970
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
7071
elif device: # non-cpu device requested

0 commit comments

Comments
 (0)