Skip to content

Commit 3263379

Browse files
authored
🧪 Improve micronet tests (#630)
- Improve micronet tests
1 parent d0d4ed6 commit 3263379

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

tests/models/test_arch_micronet.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
import pytest
77
import torch
88

9-
from tiatoolbox import utils
109
from tiatoolbox.models import MicroNet, SemanticSegmentor
1110
from tiatoolbox.models.architecture import fetch_pretrained_weights
1211
from tiatoolbox.utils import env_detection as toolbox_env
12+
from tiatoolbox.utils.misc import select_device
1313
from tiatoolbox.wsicore.wsireader import WSIReader
1414

15+
ON_GPU = toolbox_env.has_gpu()
16+
1517

1618
def test_functionality(remote_sample, tmp_path):
1719
"""Functionality test."""
@@ -28,10 +30,10 @@ def test_functionality(remote_sample, tmp_path):
2830
patch = model.preproc(patch)
2931
batch = torch.from_numpy(patch)[None]
3032
fetch_pretrained_weights("micronet-consep", f"{tmp_path}/weights.pth")
31-
map_location = utils.misc.select_device(utils.env_detection.has_gpu())
33+
map_location = select_device(ON_GPU)
3234
pretrained = torch.load(f"{tmp_path}/weights.pth", map_location=map_location)
3335
model.load_state_dict(pretrained)
34-
output = model.infer_batch(model, batch, on_gpu=False)
36+
output = model.infer_batch(model, batch, on_gpu=ON_GPU)
3537
output, _ = model.postproc(output[0])
3638
assert np.max(np.unique(output)) == 46
3739

@@ -43,7 +45,7 @@ def test_value_error():
4345

4446

4547
@pytest.mark.skipif(
46-
toolbox_env.running_on_ci() or not toolbox_env.has_gpu(),
48+
toolbox_env.running_on_ci() or not ON_GPU,
4749
reason="Local test on machine with GPU.",
4850
)
4951
def test_micronet_output(remote_sample, tmp_path):

0 commit comments

Comments
 (0)