66import pytest
77import torch
88
9- from tiatoolbox import utils
109from tiatoolbox .models import MicroNet , SemanticSegmentor
1110from tiatoolbox .models .architecture import fetch_pretrained_weights
1211from tiatoolbox .utils import env_detection as toolbox_env
12+ from tiatoolbox .utils .misc import select_device
1313from tiatoolbox .wsicore .wsireader import WSIReader
1414
15+ ON_GPU = toolbox_env .has_gpu ()
16+
1517
1618def test_functionality (remote_sample , tmp_path ):
1719 """Functionality test."""
@@ -28,10 +30,10 @@ def test_functionality(remote_sample, tmp_path):
2830 patch = model .preproc (patch )
2931 batch = torch .from_numpy (patch )[None ]
3032 fetch_pretrained_weights ("micronet-consep" , f"{ tmp_path } /weights.pth" )
31- map_location = utils . misc . select_device (utils . env_detection . has_gpu () )
33+ map_location = select_device (ON_GPU )
3234 pretrained = torch .load (f"{ tmp_path } /weights.pth" , map_location = map_location )
3335 model .load_state_dict (pretrained )
34- output = model .infer_batch (model , batch , on_gpu = False )
36+ output = model .infer_batch (model , batch , on_gpu = ON_GPU )
3537 output , _ = model .postproc (output [0 ])
3638 assert np .max (np .unique (output )) == 46
3739
@@ -43,7 +45,7 @@ def test_value_error():
4345
4446
4547@pytest .mark .skipif (
46- toolbox_env .running_on_ci () or not toolbox_env . has_gpu () ,
48+ toolbox_env .running_on_ci () or not ON_GPU ,
4749 reason = "Local test on machine with GPU." ,
4850)
4951def test_micronet_output (remote_sample , tmp_path ):
0 commit comments