66
77from tiatoolbox .models .architecture .vanilla import CNNModel , TimmModel
88from tiatoolbox .models .models_abc import model_to
9- from tiatoolbox .utils .misc import select_device
109
1110ON_GPU = False
1211RNG = np .random .default_rng () # Numpy Random Generator
@@ -46,7 +45,7 @@ def test_functional() -> None:
4645 for backbone in backbones :
4746 model = CNNModel (backbone , num_classes = 1 )
4847 model_ = model_to (device = device , model = model )
49- model .infer_batch (model_ , samples , device = select_device ( on_gpu = ON_GPU ) )
48+ model .infer_batch (model_ , samples , device = device )
5049 except ValueError as exc :
5150 msg = f"Model { backbone } failed."
5251 raise AssertionError (msg ) from exc
@@ -72,8 +71,8 @@ def test_timm_functional() -> None:
7271 try :
7372 for backbone in backbones :
7473 model = TimmModel (backbone = backbone , num_classes = 1 , pretrained = False )
75- model_ = model_to (on_gpu = ON_GPU , model = model )
76- model .infer_batch (model_ , samples , on_gpu = ON_GPU )
74+ model_ = model_to (device = device , model = model )
75+ model .infer_batch (model_ , samples , device = device )
7776 except ValueError as exc :
7877 msg = f"Model { backbone } failed."
7978 raise AssertionError (msg ) from exc
0 commit comments