Skip to content

Commit d98e091

Browse files
authored
NEW: Feature CLI For Nucleus Instance Segmentation (#194)
- Add cli for nucleus instance segmentation - Add tests for cli and remove tests if cli is covering these. - Fix link for pretrained models in semantic segmentation - Rearrange add commands in cli - Replace `get_wsireader` with `WSIReader.open` as `get_wsireader` is being depreciated. - Update installation instructions for pygeos. - Fix OOM errors using garbage collection `gc.collect()`. Co-authored by: @shaneahmed
1 parent 27280fa commit d98e091

File tree

8 files changed

+344
-66
lines changed

8 files changed

+344
-66
lines changed

docs/installation.rst

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,19 @@ Windows
2929
1. Download OpenSlide binaries from `this page <https://openslide.org/download/>`_. Extract the folder and add `bin` and `lib` subdirectories to
3030
Windows `system path <https://docs.microsoft.com/en-us/previous-versions/office/developer/sharepoint-2010/ee537574(v=office.14)>`_.
3131

32-
2. Install OpenJPEG. The easiest way is to install OpenJpeg is through conda
32+
2. Install
33+
TIAToolbox.
34+
35+
.. code-block:: console
36+
37+
$ pip install tiatoolbox
38+
39+
3. Install OpenJPEG. The easiest way is to install OpenJpeg is through conda
3340
using
3441

3542
.. code-block:: console
3643
37-
C:\> conda install -c conda-forge openjpeg>=2.3.0
44+
C:\> conda install -c conda-forge openjpeg pygeos
3845
3946
Linux (Ubuntu)
4047
^^^^^^^^^^^^^^

tests/models/test_nucleus_instance_segmentor.py

Lines changed: 135 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,24 @@
2727
import numpy as np
2828
import pytest
2929
import torch
30+
import yaml
31+
from click.testing import CliRunner
3032

33+
from tiatoolbox import cli
3134
from tiatoolbox.models import (
3235
IOSegmentorConfig,
3336
NucleusInstanceSegmentor,
3437
SemanticSegmentor,
3538
)
39+
from tiatoolbox.models.architecture import fetch_pretrained_weights
3640
from tiatoolbox.models.engine.nucleus_instance_segmentor import (
3741
_process_tile_predictions,
3842
)
3943
from tiatoolbox.utils.metrics import f1_detection
4044
from tiatoolbox.utils.misc import imwrite
41-
from tiatoolbox.wsicore.wsireader import get_wsireader
45+
from tiatoolbox.wsicore.wsireader import WSIReader
4246

43-
BATCH_SIZE = 2
47+
BATCH_SIZE = 1
4448
ON_TRAVIS = True
4549
ON_GPU = not ON_TRAVIS and torch.cuda.is_available()
4650

@@ -57,11 +61,8 @@ def _crash_func(x):
5761
raise ValueError("Propataion Crash.")
5862

5963

60-
# ----------------------------------------------------
61-
62-
63-
def test_get_tile_info():
64-
"""Test for getting tile info."""
64+
def helper_tile_info():
65+
"""Helper function for tile information."""
6566
predictor = NucleusInstanceSegmentor(model="A")
6667
# ! assuming the tiles organized as follows (coming out from
6768
# ! PatchExtractor). If this is broken, need to check back
@@ -76,7 +77,6 @@ def test_get_tile_info():
7677
# ---------------------
7778
# | 12 | 13 | 14 | 15 |
7879
# ---------------------
79-
8080
# ! assume flag index ordering: left right top bottom
8181
ioconfig = IOSegmentorConfig(
8282
input_resolutions=[{"units": "mpp", "resolution": 0.25}],
@@ -91,8 +91,17 @@ def test_get_tile_info():
9191
patch_input_shape=[4, 4],
9292
patch_output_shape=[4, 4],
9393
)
94-
info = predictor._get_tile_info([16, 16], ioconfig)
95-
boxes, flag = info[0] # index 0 should be full grid, removal
94+
95+
return predictor._get_tile_info([16, 16], ioconfig)
96+
97+
98+
# ----------------------------------------------------
99+
100+
101+
def test_get_tile_info():
102+
"""Test for getting tile info."""
103+
info = helper_tile_info()
104+
_, flag = info[0] # index 0 should be full grid, removal
96105
# removal flag at top edges
97106
assert (
98107
np.sum(
@@ -124,7 +133,10 @@ def test_get_tile_info():
124133
== 0
125134
), "Fail Right"
126135

127-
# test for vertical boundary boxes
136+
137+
def test_vertical_boundary_boxes():
138+
"""Test for vertical boundary boxes."""
139+
info = helper_tile_info()
128140
_boxes = np.array(
129141
[
130142
[3, 0, 5, 4],
@@ -161,7 +173,10 @@ def test_get_tile_info():
161173
assert np.sum(_boxes - boxes) == 0, "Wrong Vertical Bounds"
162174
assert np.sum(flag - _flag) == 0, "Fail Vertical Flag"
163175

164-
# test for horizontal boundary boxes
176+
177+
def test_horizontal_boundary_boxes():
178+
"""Test for horizontal boundary boxes."""
179+
info = helper_tile_info()
165180
_boxes = np.array(
166181
[
167182
[0, 3, 4, 5],
@@ -198,7 +213,10 @@ def test_get_tile_info():
198213
assert np.sum(_boxes - boxes) == 0, "Wrong Horizontal Bounds"
199214
assert np.sum(flag - _flag) == 0, "Fail Horizontal Flag"
200215

201-
# test for cross-section boundary boxes
216+
217+
def test_cross_section_boundary_boxes():
218+
"""Test for cross-section boundary boxes."""
219+
info = helper_tile_info()
202220
_boxes = np.array(
203221
[
204222
[2, 2, 6, 6],
@@ -233,8 +251,11 @@ def test_get_tile_info():
233251
def test_crash_segmentor(remote_sample, tmp_path):
234252
"""Test engine crash when given malformed input."""
235253
root_save_dir = pathlib.Path(tmp_path)
236-
sample_wsi_svs = pathlib.Path(remote_sample("wsi2_4k_4k_svs"))
237-
sample_wsi_msk = pathlib.Path(remote_sample("wsi2_4k_4k_msk"))
254+
sample_wsi_svs = pathlib.Path(remote_sample("svs-1-small"))
255+
sample_wsi_msk = remote_sample("small_svs_tissue_mask")
256+
sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8)
257+
imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk)
258+
sample_wsi_msk = tmp_path.joinpath("small_svs_tissue_mask.jpg")
238259

239260
save_dir = f"{root_save_dir}/instance/"
240261

@@ -278,48 +299,16 @@ def test_crash_segmentor(remote_sample, tmp_path):
278299
def test_functionality_travis(remote_sample, tmp_path):
279300
"""Functionality test for nuclei instance segmentor."""
280301
root_save_dir = pathlib.Path(tmp_path)
281-
save_dir = pathlib.Path(f"{tmp_path}/output")
282-
mini_wsi_svs = pathlib.Path(remote_sample("wsi4_1k_1k_svs"))
302+
mini_wsi_svs = pathlib.Path(remote_sample("wsi4_512_512_svs"))
283303

284304
resolution = 2.0
285305

286-
reader = get_wsireader(mini_wsi_svs)
306+
reader = WSIReader.open(mini_wsi_svs)
287307
thumb = reader.slide_thumbnail(resolution=resolution, units="mpp")
288308
mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg"
289309
imwrite(mini_wsi_jpg, thumb)
290310

291-
# resolution for travis testing, not the correct ones
292-
ioconfig = IOSegmentorConfig(
293-
input_resolutions=[{"units": "mpp", "resolution": resolution}],
294-
output_resolutions=[
295-
{"units": "mpp", "resolution": resolution},
296-
{"units": "mpp", "resolution": resolution},
297-
{"units": "mpp", "resolution": resolution},
298-
],
299-
margin=128,
300-
tile_shape=[512, 512],
301-
patch_input_shape=[256, 256],
302-
patch_output_shape=[164, 164],
303-
stride_shape=[164, 164],
304-
)
305-
306311
save_dir = f"{root_save_dir}/instance/"
307-
# * test run on tile, run without worker first
308-
_rm_dir(save_dir)
309-
inst_segmentor = NucleusInstanceSegmentor(
310-
batch_size=1,
311-
num_loader_workers=0,
312-
num_postproc_workers=0,
313-
pretrained_model="hovernet_fast-pannuke",
314-
)
315-
inst_segmentor.predict(
316-
[mini_wsi_jpg],
317-
mode="tile",
318-
ioconfig=ioconfig,
319-
on_gpu=ON_GPU,
320-
crash_on_exception=True,
321-
save_dir=save_dir,
322-
)
323312

324313
# * test run on wsi, test run with worker
325314
# resolution for travis testing, not the correct ones
@@ -337,6 +326,7 @@ def test_functionality_travis(remote_sample, tmp_path):
337326
)
338327

339328
_rm_dir(save_dir)
329+
340330
inst_segmentor = NucleusInstanceSegmentor(
341331
batch_size=1,
342332
num_loader_workers=0,
@@ -359,7 +349,7 @@ def test_functionality_travis(remote_sample, tmp_path):
359349
def test_functionality_merge_tile_predictions_travis(remote_sample, tmp_path):
360350
"""Functional tests for merging tile predictions."""
361351
save_dir = pathlib.Path(f"{tmp_path}/output")
362-
mini_wsi_svs = pathlib.Path(remote_sample("wsi4_1k_1k_svs"))
352+
mini_wsi_svs = pathlib.Path(remote_sample("wsi4_512_512_svs"))
363353

364354
resolution = 0.5
365355
ioconfig = IOSegmentorConfig(
@@ -439,6 +429,7 @@ def test_functionality_merge_tile_predictions_travis(remote_sample, tmp_path):
439429
)
440430

441431
# test exception flag
432+
tile_flag = [0, 0, 0, 0]
442433
with pytest.raises(ValueError, match=r".*Unknown tile mode.*"):
443434
_process_tile_predictions(
444435
ioconfig=ioconfig,
@@ -525,3 +516,97 @@ def test_functionality_local(remote_sample, tmp_path):
525516
score = f1_detection(inst_coords_b, inst_coords_a, radius=1.0)
526517
assert score > 0.9, "Heavy loss of precision!"
527518
_rm_dir(tmp_path)
519+
520+
521+
def test_cli_nucleus_instance_segment_ioconfig(remote_sample, tmp_path):
522+
"""Test for nucleus segmentation with IOconfig."""
523+
mini_wsi_svs = pathlib.Path(remote_sample("wsi4_512_512_svs"))
524+
output_path = tmp_path / "output"
525+
526+
resolution = 2.0
527+
528+
reader = WSIReader.open(mini_wsi_svs)
529+
thumb = reader.slide_thumbnail(resolution=resolution, units="mpp")
530+
mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg"
531+
imwrite(mini_wsi_jpg, thumb)
532+
533+
fetch_pretrained_weights(
534+
"hovernet_fast-pannuke", str(tmp_path.joinpath("hovernet_fast-pannuke.pth"))
535+
)
536+
537+
# resolution for travis testing, not the correct ones
538+
config = {
539+
"input_resolutions": [{"units": "mpp", "resolution": resolution}],
540+
"output_resolutions": [
541+
{"units": "mpp", "resolution": resolution},
542+
{"units": "mpp", "resolution": resolution},
543+
{"units": "mpp", "resolution": resolution},
544+
],
545+
"margin": 128,
546+
"tile_shape": [512, 512],
547+
"patch_input_shape": [256, 256],
548+
"patch_output_shape": [164, 164],
549+
"stride_shape": [164, 164],
550+
"save_resolution": {"units": "mpp", "resolution": 8.0},
551+
}
552+
553+
with open(tmp_path.joinpath("config.yaml"), "w") as fptr:
554+
yaml.dump(config, fptr)
555+
556+
runner = CliRunner()
557+
nucleus_instance_segment_result = runner.invoke(
558+
cli.main,
559+
[
560+
"nucleus-instance-segment",
561+
"--img-input",
562+
str(mini_wsi_jpg),
563+
"--pretrained-weights",
564+
str(tmp_path.joinpath("hovernet_fast-pannuke.pth")),
565+
"--num-loader-workers",
566+
str(0),
567+
"--num-postproc-workers",
568+
str(0),
569+
"--mode",
570+
"tile",
571+
"--output-path",
572+
str(output_path),
573+
"--yaml-config-path",
574+
tmp_path.joinpath("config.yaml"),
575+
],
576+
)
577+
578+
assert nucleus_instance_segment_result.exit_code == 0
579+
assert output_path.joinpath("0.dat").exists()
580+
assert output_path.joinpath("file_map.dat").exists()
581+
assert output_path.joinpath("results.json").exists()
582+
_rm_dir(tmp_path)
583+
584+
585+
def test_cli_nucleus_instance_segment(remote_sample, tmp_path):
586+
"""Test for nucleus segmentation."""
587+
mini_wsi_svs = pathlib.Path(remote_sample("wsi4_512_512_svs"))
588+
output_path = tmp_path / "output"
589+
590+
runner = CliRunner()
591+
nucleus_instance_segment_result = runner.invoke(
592+
cli.main,
593+
[
594+
"nucleus-instance-segment",
595+
"--img-input",
596+
str(mini_wsi_svs),
597+
"--mode",
598+
"wsi",
599+
"--num-loader-workers",
600+
str(0),
601+
"--num-postproc-workers",
602+
str(0),
603+
"--output-path",
604+
str(output_path),
605+
],
606+
)
607+
608+
assert nucleus_instance_segment_result.exit_code == 0
609+
assert output_path.joinpath("0.dat").exists()
610+
assert output_path.joinpath("file_map.dat").exists()
611+
assert output_path.joinpath("results.json").exists()
612+
_rm_dir(tmp_path)

tests/models/test_semantic_segmentation.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import pathlib
2525
import shutil
2626

27+
# ! The garbage collector
28+
import gc
2729
import numpy as np
2830
import pytest
2931
import torch
@@ -77,8 +79,7 @@ def __init__(self):
7779

7880
def forward(self, img):
7981
"""Define how to use layer."""
80-
output = self.conv(img)
81-
return output
82+
return self.conv(img)
8283

8384
@staticmethod
8485
def infer_batch(model, batch_data, on_gpu):
@@ -136,22 +137,23 @@ def test_segmentor_ioconfig():
136137
patch_output_shape=[1024, 1024],
137138
stride_shape=[512, 512],
138139
)
140+
139141
# error when uniform resolution units are not uniform
140142
with pytest.raises(ValueError, match=r".*Invalid resolution units.*"):
141143
xconfig = copy.deepcopy(default_config)
142144
xconfig["input_resolutions"] = [
143145
{"units": "mpp", "resolution": 0.25},
144146
{"units": "power", "resolution": 0.50},
145147
]
146-
ioconfig = IOSegmentorConfig(**xconfig)
148+
_ = IOSegmentorConfig(**xconfig)
147149
# error when uniform resolution units are not supported
148150
with pytest.raises(ValueError, match=r".*Invalid resolution units.*"):
149151
xconfig = copy.deepcopy(default_config)
150152
xconfig["input_resolutions"] = [
151153
{"units": "alpha", "resolution": 0.25},
152154
{"units": "alpha", "resolution": 0.50},
153155
]
154-
ioconfig = IOSegmentorConfig(**xconfig)
156+
_ = IOSegmentorConfig(**xconfig)
155157

156158
ioconfig = IOSegmentorConfig(
157159
input_resolutions=[
@@ -209,7 +211,8 @@ def test_segmentor_ioconfig():
209211

210212
def test_functional_wsi_stream_dataset(remote_sample):
211213
"""Functional test for WSIStreamDataset."""
212-
mini_wsi_svs = pathlib.Path(remote_sample("wsi2_4k_4k_svs"))
214+
gc.collect() # Force clean up everything on hold
215+
mini_wsi_svs = pathlib.Path(remote_sample("wsi4_512_512_svs"))
213216

214217
ioconfig = IOSegmentorConfig(
215218
input_resolutions=[

tiatoolbox/cli/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import click
2626

2727
from tiatoolbox import __version__
28+
from tiatoolbox.cli.nucleus_instance_segment import nucleus_instance_segment
2829
from tiatoolbox.cli.patch_predictor import patch_predictor
2930
from tiatoolbox.cli.read_bounds import read_bounds
3031
from tiatoolbox.cli.save_tiles import save_tiles
@@ -52,15 +53,15 @@ def main():
5253
return 0
5354

5455

55-
main.add_command(slide_info)
56+
main.add_command(nucleus_instance_segment)
57+
main.add_command(patch_predictor)
5658
main.add_command(read_bounds)
57-
main.add_command(slide_thumbnail)
5859
main.add_command(save_tiles)
60+
main.add_command(semantic_segment)
61+
main.add_command(slide_info)
62+
main.add_command(slide_thumbnail)
5963
main.add_command(tissue_mask)
60-
main.add_command(patch_predictor)
6164
main.add_command(stain_norm)
62-
main.add_command(semantic_segment)
63-
6465

6566
if __name__ == "__main__":
6667
sys.exit(main()) # pragma: no cover

0 commit comments

Comments
 (0)