|
| 1 | +# ***** BEGIN GPL LICENSE BLOCK ***** |
| 2 | +# |
| 3 | +# This program is free software; you can redistribute it and/or |
| 4 | +# modify it under the terms of the GNU General Public License |
| 5 | +# as published by the Free Software Foundation; either version 2 |
| 6 | +# of the License, or (at your option) any later version. |
| 7 | +# |
| 8 | +# This program is distributed in the hope that it will be useful, |
| 9 | +# but WITHOUT ANY WARRANTY; without even the implied warranty of |
| 10 | +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| 11 | +# GNU General Public License for more details. |
| 12 | +# |
| 13 | +# You should have received a copy of the GNU General Public License |
| 14 | +# along with this program; if not, write to the Free Software Foundation, |
| 15 | +# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. |
| 16 | +# |
| 17 | +# The Original Code is Copyright (C) 2021, TIA Centre, University of Warwick |
| 18 | +# All rights reserved. |
| 19 | +# ***** END GPL LICENSE BLOCK ***** |
| 20 | + |
| 21 | +"""Unit test package for HoVerNet+.""" |
| 22 | + |
| 23 | +import pytest |
| 24 | +import torch |
| 25 | + |
| 26 | +from tiatoolbox.models.architecture import fetch_pretrained_weights |
| 27 | +from tiatoolbox.models.architecture.hovernet_plus import HoVerNetPlus |
| 28 | +from tiatoolbox.utils.misc import imread |
| 29 | +from tiatoolbox.utils.transforms import imresize |
| 30 | + |
| 31 | + |
| 32 | +def test_functionality(remote_sample, tmp_path): |
| 33 | + """Functionality test.""" |
| 34 | + tmp_path = str(tmp_path) |
| 35 | + sample_patch = str(remote_sample("stainnorm-source")) |
| 36 | + patch_pre = imread(sample_patch) |
| 37 | + patch_pre = imresize(patch_pre, scale_factor=0.5) |
| 38 | + patch = patch_pre[0:256, 0:256] |
| 39 | + batch = torch.from_numpy(patch)[None] |
| 40 | + |
| 41 | + # Test functionality with both nuclei and layer segmentation |
| 42 | + model = HoVerNetPlus(num_types=3, num_layers=5, mode="fast") |
| 43 | + # Test decoder as expected |
| 44 | + assert len(model.decoder["np"]) > 0, "Decoder must contain np branch." |
| 45 | + assert len(model.decoder["hv"]) > 0, "Decoder must contain hv branch." |
| 46 | + assert len(model.decoder["tp"]) > 0, "Decoder must contain tp branch." |
| 47 | + assert len(model.decoder["ls"]) > 0, "Decoder must contain ls branch." |
| 48 | + fetch_pretrained_weights("hovernetplus-oed", f"{tmp_path}/weigths.pth") |
| 49 | + pretrained = torch.load(f"{tmp_path}/weigths.pth") |
| 50 | + model.load_state_dict(pretrained) |
| 51 | + output = model.infer_batch(model, batch, on_gpu=False) |
| 52 | + assert len(output) == 4, "Must contain predictions for: np, hv, tp and ls branches." |
| 53 | + output = [v[0] for v in output] |
| 54 | + output = model.postproc(output) |
| 55 | + assert len(output[1]) > 0 and len(output[3]) > 0, "Must have some nuclei/layers." |
| 56 | + |
| 57 | + # test crash when providing exotic mode |
| 58 | + with pytest.raises(ValueError, match=r".*Invalid mode.*"): |
| 59 | + model = HoVerNetPlus(num_types=None, num_layers=None, mode="super") |
0 commit comments