Skip to content

Commit dba269c

Browse files
committed
🐛 Fix model_to device specification
1 parent 7b4f496 commit dba269c

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

tests/models/test_arch_vanilla.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from tiatoolbox.models.architecture.vanilla import CNNModel, TimmModel
88
from tiatoolbox.models.models_abc import model_to
9-
from tiatoolbox.utils.misc import select_device
109

1110
ON_GPU = False
1211
RNG = np.random.default_rng() # Numpy Random Generator
@@ -46,7 +45,7 @@ def test_functional() -> None:
4645
for backbone in backbones:
4746
model = CNNModel(backbone, num_classes=1)
4847
model_ = model_to(device=device, model=model)
49-
model.infer_batch(model_, samples, device=select_device(on_gpu=ON_GPU))
48+
model.infer_batch(model_, samples, device=device)
5049
except ValueError as exc:
5150
msg = f"Model {backbone} failed."
5251
raise AssertionError(msg) from exc
@@ -72,8 +71,8 @@ def test_timm_functional() -> None:
7271
try:
7372
for backbone in backbones:
7473
model = TimmModel(backbone=backbone, num_classes=1, pretrained=False)
75-
model_ = model_to(on_gpu=ON_GPU, model=model)
76-
model.infer_batch(model_, samples, on_gpu=ON_GPU)
74+
model_ = model_to(device=device, model=model)
75+
model.infer_batch(model_, samples, device=device)
7776
except ValueError as exc:
7877
msg = f"Model {backbone} failed."
7978
raise AssertionError(msg) from exc

0 commit comments

Comments
 (0)