From bcdca2aa766a266b9d40adaaecd041640ea1fb2b Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 3 Jul 2023 16:59:01 +0100 Subject: [PATCH] :test_tube: Improve micronet tests - Improve micronet tests Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tests/models/test_arch_micronet.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/models/test_arch_micronet.py b/tests/models/test_arch_micronet.py index 5a0724db2..5abc5dd16 100644 --- a/tests/models/test_arch_micronet.py +++ b/tests/models/test_arch_micronet.py @@ -6,12 +6,14 @@ import pytest import torch -from tiatoolbox import utils from tiatoolbox.models import MicroNet, SemanticSegmentor from tiatoolbox.models.architecture import fetch_pretrained_weights from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.utils.misc import select_device from tiatoolbox.wsicore.wsireader import WSIReader +ON_GPU = toolbox_env.has_gpu() + def test_functionality(remote_sample, tmp_path): """Functionality test.""" @@ -28,10 +30,10 @@ def test_functionality(remote_sample, tmp_path): patch = model.preproc(patch) batch = torch.from_numpy(patch)[None] fetch_pretrained_weights("micronet-consep", f"{tmp_path}/weights.pth") - map_location = utils.misc.select_device(utils.env_detection.has_gpu()) + map_location = select_device(ON_GPU) pretrained = torch.load(f"{tmp_path}/weights.pth", map_location=map_location) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch(model, batch, on_gpu=ON_GPU) output, _ = model.postproc(output[0]) assert np.max(np.unique(output)) == 46 @@ -43,7 +45,7 @@ def test_value_error(): @pytest.mark.skipif( - toolbox_env.running_on_ci() or not toolbox_env.has_gpu(), + toolbox_env.running_on_ci() or not ON_GPU, reason="Local test on machine with GPU.", ) def test_micronet_output(remote_sample, tmp_path):