Skip to content

Commit d98d04b

Browse files
NEW: Add Feature HoVerNet+ model (#179)
* Added HoVerNet+ model for the simultaneous semantic segmentation of layers and nuclear instance segmentation/classification. * Added HoVerNet+ model for the simultaneous semantic segmentation of layers and nuclear instance segmentation/classification. * Added HoVerNet+ model for the simultaneous semantic segmentation of layers and nuclear instance segmentation/classification. * Added HoVerNet+ model for the simultaneous semantic segmentation of layers and nuclear instance segmentation/classification. * BUG: Fixed bug in forward * BUG: Fixed bug in forward * BUG: Fixed bug in forward * BUG: Changed name of TP branch * DEV: Including model weight info * DOC: Added to model class docstrings. * DOC: Added to model class docstrings. * BUG: Removed unused imports. * DOC: Added to model class docstrings. * DOC: Added to model class docstrings. * FIX: Removed unnecessary dependencies. * FIX: Removed unnecessary dependencies. * UPD: Changed image used for testing. * UPD: Updated crop_op function.. * UPD: Updated test image. * FIX: Corrected formatting issue. * UPD: Updated unit testing for center_crop_to_shape. * UPD: Added unit testing for blocks. * UPD: Added unit testing for HoVer-Net+ model. * UPD: Added unit testing for HoVer-Net+ model. * UPD: Added unit testing for HoVer-Net+ model. * UPD: Correction to formatting. * UPD: Correction to formatting. * UPD: Updated unit testing. * UPD: Updated HoVer-Net+ unit testing. * UPD: Updated HoVer-Net+ unit testing. * UPD: Updated HoVer-Net+ unit testing. * UPD: Changed HoVer-Net+ to be a subclass of HoVer-Net. * UPD: Changed HoVer-Net+ to be a subclass of HoVer-Net. * UPD: Changed HoVer-Net+ to be a subclass of HoVer-Net. * UPD: Changed HoVer-Net+ to be a subclass of HoVer-Net. * UPD: Updated HoVer-Net+ unit testing. * UPD: Updated HoVer-Net+ and HoVer-Net docstrings. * BUG: Fixed error in formatting. * BUG: Removed repeated function in model utils. * UPD: Removed fast/original mode and extra branch options for HoVer-Net+. * UPD: Removed pred_dict extra branch options for HoVer-Net+. * FIX: Resolved utils.misc conflict. * UPD: Changed some functions to private static methods. * UPD: Changed private static methods to protected statis methods. * UPD: Added test for HoVer-Net+ with semantic segmentor. * UPD: Spelling mistake. * FIX: Added garbage collection to nuclear instance functionality test. * FIX: Updated TIALab to Centre * FIX: Updated TIALab to Centre * UPD: Updated HoVerNet/HoVerNet+ docstrings with Examples. Co-authored-by: Shan E Ahmed Raza <[email protected]>
1 parent b87ac85 commit d98d04b

File tree

6 files changed

+548
-84
lines changed

6 files changed

+548
-84
lines changed

tests/models/test_hovernet_plus.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# ***** BEGIN GPL LICENSE BLOCK *****
2+
#
3+
# This program is free software; you can redistribute it and/or
4+
# modify it under the terms of the GNU General Public License
5+
# as published by the Free Software Foundation; either version 2
6+
# of the License, or (at your option) any later version.
7+
#
8+
# This program is distributed in the hope that it will be useful,
9+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11+
# GNU General Public License for more details.
12+
#
13+
# You should have received a copy of the GNU General Public License
14+
# along with this program; if not, write to the Free Software Foundation,
15+
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
16+
#
17+
# The Original Code is Copyright (C) 2021, TIA Centre, University of Warwick
18+
# All rights reserved.
19+
# ***** END GPL LICENSE BLOCK *****
20+
21+
"""Unit test package for HoVerNet+."""
22+
23+
import pytest
24+
import torch
25+
26+
from tiatoolbox.models.architecture import fetch_pretrained_weights
27+
from tiatoolbox.models.architecture.hovernet_plus import HoVerNetPlus
28+
from tiatoolbox.utils.misc import imread
29+
from tiatoolbox.utils.transforms import imresize
30+
31+
32+
def test_functionality(remote_sample, tmp_path):
33+
"""Functionality test."""
34+
tmp_path = str(tmp_path)
35+
sample_patch = str(remote_sample("stainnorm-source"))
36+
patch_pre = imread(sample_patch)
37+
patch_pre = imresize(patch_pre, scale_factor=0.5)
38+
patch = patch_pre[0:256, 0:256]
39+
batch = torch.from_numpy(patch)[None]
40+
41+
# Test functionality with both nuclei and layer segmentation
42+
model = HoVerNetPlus(num_types=3, num_layers=5, mode="fast")
43+
# Test decoder as expected
44+
assert len(model.decoder["np"]) > 0, "Decoder must contain np branch."
45+
assert len(model.decoder["hv"]) > 0, "Decoder must contain hv branch."
46+
assert len(model.decoder["tp"]) > 0, "Decoder must contain tp branch."
47+
assert len(model.decoder["ls"]) > 0, "Decoder must contain ls branch."
48+
fetch_pretrained_weights("hovernetplus-oed", f"{tmp_path}/weigths.pth")
49+
pretrained = torch.load(f"{tmp_path}/weigths.pth")
50+
model.load_state_dict(pretrained)
51+
output = model.infer_batch(model, batch, on_gpu=False)
52+
assert len(output) == 4, "Must contain predictions for: np, hv, tp and ls branches."
53+
output = [v[0] for v in output]
54+
output = model.postproc(output)
55+
assert len(output[1]) > 0 and len(output[3]) > 0, "Must have some nuclei/layers."
56+
57+
# test crash when providing exotic mode
58+
with pytest.raises(ValueError, match=r".*Invalid mode.*"):
59+
model = HoVerNetPlus(num_types=None, num_layers=None, mode="super")
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# ***** BEGIN GPL LICENSE BLOCK *****
2+
#
3+
# This program is free software; you can redistribute it and/or
4+
# modify it under the terms of the GNU General Public License
5+
# as published by the Free Software Foundation; either version 2
6+
# of the License, or (at your option) any later version.
7+
#
8+
# This program is distributed in the hope that it will be useful,
9+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11+
# GNU General Public License for more details.
12+
#
13+
# You should have received a copy of the GNU General Public License
14+
# along with this program; if not, write to the Free Software Foundation,
15+
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
16+
#
17+
# The Original Code is Copyright (C) 2021, TIA Centre, University of Warwick
18+
# All rights reserved.
19+
# ***** END GPL LICENSE BLOCK *****
20+
21+
"""Unit test package for HoVerNet+."""
22+
23+
import pathlib
24+
import shutil
25+
26+
import numpy as np
27+
import pytest
28+
import torch
29+
30+
from tiatoolbox.models import SemanticSegmentor
31+
32+
BATCH_SIZE = 1
33+
ON_TRAVIS = True
34+
ON_GPU = not ON_TRAVIS and torch.cuda.is_available()
35+
36+
# ----------------------------------------------------
37+
38+
39+
def _rm_dir(path):
40+
"""Helper func to remove directory."""
41+
shutil.rmtree(path, ignore_errors=True)
42+
43+
44+
@pytest.mark.skip(reason="Local manual test, not applicable for travis.")
45+
def test_functionality_local(remote_sample, tmp_path):
46+
"""Local functionality test for multi task segmentor. Currently only
47+
testing HoVer-Net+ with semantic segmentor.
48+
"""
49+
root_save_dir = pathlib.Path(tmp_path)
50+
mini_wsi_svs = pathlib.Path(remote_sample("CMU-1-Small-Region.svs"))
51+
52+
save_dir = f"{root_save_dir}/semantic/"
53+
_rm_dir(save_dir)
54+
semantic_segmentor = SemanticSegmentor(
55+
pretrained_model="hovernetplus-oed",
56+
batch_size=BATCH_SIZE,
57+
num_postproc_workers=2,
58+
)
59+
output = semantic_segmentor.predict(
60+
[mini_wsi_svs],
61+
mode="wsi",
62+
on_gpu=True,
63+
crash_on_exception=True,
64+
save_dir=save_dir,
65+
)
66+
67+
raw_maps = [np.load(f"{output[0][1]}.raw.{head_idx}.npy") for head_idx in range(4)]
68+
inst_map, inst_dict, layer_map, layer_dict = semantic_segmentor.model.postproc(
69+
raw_maps
70+
)
71+
assert len(inst_dict) > 0 and len(layer_dict) > 0, "Must have some nuclei/layers."
72+
assert (
73+
inst_map.shape == layer_map.shape
74+
), "Output instance and layer maps must be same shape"
75+
_rm_dir(tmp_path)

tests/models/test_nucleus_instance_segmentor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,7 @@ def test_functionality_travis(remote_sample, tmp_path):
352352

353353
def test_functionality_merge_tile_predictions_travis(remote_sample, tmp_path):
354354
"""Functional tests for merging tile predictions."""
355+
gc.collect() # Force clean up everything on hold
355356
save_dir = pathlib.Path(f"{tmp_path}/output")
356357
mini_wsi_svs = pathlib.Path(remote_sample("wsi4_512_512_svs"))
357358

tiatoolbox/data/pretrained_model.yaml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,3 +323,28 @@ hovernet_original-kumar:
323323
patch_output_shape: [80, 80]
324324
stride_shape: [80, 80]
325325
save_resolution: {'units': 'mpp', 'resolution': 0.25}
326+
327+
hovernetplus-oed:
328+
url: https://tiatoolbox.dcs.warwick.ac.uk/models/seg/hovernetplus-oed.pth
329+
architecture:
330+
class: hovernet_plus.HoVerNetPlus
331+
kwargs:
332+
num_types: 3
333+
num_layers: 5
334+
mode: "fast"
335+
ioconfig:
336+
class: semantic_segmentor.IOSegmentorConfig
337+
kwargs:
338+
input_resolutions:
339+
- {"units": "mpp", "resolution": 0.50}
340+
output_resolutions:
341+
- {"units": "mpp", "resolution": 0.50}
342+
- {"units": "mpp", "resolution": 0.50}
343+
- {"units": "mpp", "resolution": 0.50}
344+
- {"units": "mpp", "resolution": 0.50}
345+
margin: 128
346+
tile_shape: [1024, 1024]
347+
patch_input_shape: [256, 256]
348+
patch_output_shape: [164, 164]
349+
stride_shape: [164, 164]
350+
save_resolution: {'units': 'mpp', 'resolution': 0.50}

0 commit comments

Comments
 (0)