2727import numpy as np
2828import pytest
2929import torch
30+ import yaml
31+ from click .testing import CliRunner
3032
33+ from tiatoolbox import cli
3134from tiatoolbox .models import (
3235 IOSegmentorConfig ,
3336 NucleusInstanceSegmentor ,
3437 SemanticSegmentor ,
3538)
39+ from tiatoolbox .models .architecture import fetch_pretrained_weights
3640from tiatoolbox .models .engine .nucleus_instance_segmentor import (
3741 _process_tile_predictions ,
3842)
3943from tiatoolbox .utils .metrics import f1_detection
4044from tiatoolbox .utils .misc import imwrite
41- from tiatoolbox .wsicore .wsireader import get_wsireader
45+ from tiatoolbox .wsicore .wsireader import WSIReader
4246
43- BATCH_SIZE = 2
47+ BATCH_SIZE = 1
4448ON_TRAVIS = True
4549ON_GPU = not ON_TRAVIS and torch .cuda .is_available ()
4650
@@ -57,11 +61,8 @@ def _crash_func(x):
5761 raise ValueError ("Propataion Crash." )
5862
5963
60- # ----------------------------------------------------
61-
62-
63- def test_get_tile_info ():
64- """Test for getting tile info."""
64+ def helper_tile_info ():
65+ """Helper function for tile information."""
6566 predictor = NucleusInstanceSegmentor (model = "A" )
6667 # ! assuming the tiles organized as follows (coming out from
6768 # ! PatchExtractor). If this is broken, need to check back
@@ -76,7 +77,6 @@ def test_get_tile_info():
7677 # ---------------------
7778 # | 12 | 13 | 14 | 15 |
7879 # ---------------------
79-
8080 # ! assume flag index ordering: left right top bottom
8181 ioconfig = IOSegmentorConfig (
8282 input_resolutions = [{"units" : "mpp" , "resolution" : 0.25 }],
@@ -91,8 +91,17 @@ def test_get_tile_info():
9191 patch_input_shape = [4 , 4 ],
9292 patch_output_shape = [4 , 4 ],
9393 )
94- info = predictor ._get_tile_info ([16 , 16 ], ioconfig )
95- boxes , flag = info [0 ] # index 0 should be full grid, removal
94+
95+ return predictor ._get_tile_info ([16 , 16 ], ioconfig )
96+
97+
98+ # ----------------------------------------------------
99+
100+
101+ def test_get_tile_info ():
102+ """Test for getting tile info."""
103+ info = helper_tile_info ()
104+ _ , flag = info [0 ] # index 0 should be full grid, removal
96105 # removal flag at top edges
97106 assert (
98107 np .sum (
@@ -124,7 +133,10 @@ def test_get_tile_info():
124133 == 0
125134 ), "Fail Right"
126135
127- # test for vertical boundary boxes
136+
137+ def test_vertical_boundary_boxes ():
138+ """Test for vertical boundary boxes."""
139+ info = helper_tile_info ()
128140 _boxes = np .array (
129141 [
130142 [3 , 0 , 5 , 4 ],
@@ -161,7 +173,10 @@ def test_get_tile_info():
161173 assert np .sum (_boxes - boxes ) == 0 , "Wrong Vertical Bounds"
162174 assert np .sum (flag - _flag ) == 0 , "Fail Vertical Flag"
163175
164- # test for horizontal boundary boxes
176+
177+ def test_horizontal_boundary_boxes ():
178+ """Test for horizontal boundary boxes."""
179+ info = helper_tile_info ()
165180 _boxes = np .array (
166181 [
167182 [0 , 3 , 4 , 5 ],
@@ -198,7 +213,10 @@ def test_get_tile_info():
198213 assert np .sum (_boxes - boxes ) == 0 , "Wrong Horizontal Bounds"
199214 assert np .sum (flag - _flag ) == 0 , "Fail Horizontal Flag"
200215
201- # test for cross-section boundary boxes
216+
217+ def test_cross_section_boundary_boxes ():
218+ """Test for cross-section boundary boxes."""
219+ info = helper_tile_info ()
202220 _boxes = np .array (
203221 [
204222 [2 , 2 , 6 , 6 ],
@@ -233,8 +251,11 @@ def test_get_tile_info():
233251def test_crash_segmentor (remote_sample , tmp_path ):
234252 """Test engine crash when given malformed input."""
235253 root_save_dir = pathlib .Path (tmp_path )
236- sample_wsi_svs = pathlib .Path (remote_sample ("wsi2_4k_4k_svs" ))
237- sample_wsi_msk = pathlib .Path (remote_sample ("wsi2_4k_4k_msk" ))
254+ sample_wsi_svs = pathlib .Path (remote_sample ("svs-1-small" ))
255+ sample_wsi_msk = remote_sample ("small_svs_tissue_mask" )
256+ sample_wsi_msk = np .load (sample_wsi_msk ).astype (np .uint8 )
257+ imwrite (f"{ tmp_path } /small_svs_tissue_mask.jpg" , sample_wsi_msk )
258+ sample_wsi_msk = tmp_path .joinpath ("small_svs_tissue_mask.jpg" )
238259
239260 save_dir = f"{ root_save_dir } /instance/"
240261
@@ -278,48 +299,16 @@ def test_crash_segmentor(remote_sample, tmp_path):
278299def test_functionality_travis (remote_sample , tmp_path ):
279300 """Functionality test for nuclei instance segmentor."""
280301 root_save_dir = pathlib .Path (tmp_path )
281- save_dir = pathlib .Path (f"{ tmp_path } /output" )
282- mini_wsi_svs = pathlib .Path (remote_sample ("wsi4_1k_1k_svs" ))
302+ mini_wsi_svs = pathlib .Path (remote_sample ("wsi4_512_512_svs" ))
283303
284304 resolution = 2.0
285305
286- reader = get_wsireader (mini_wsi_svs )
306+ reader = WSIReader . open (mini_wsi_svs )
287307 thumb = reader .slide_thumbnail (resolution = resolution , units = "mpp" )
288308 mini_wsi_jpg = f"{ tmp_path } /mini_svs.jpg"
289309 imwrite (mini_wsi_jpg , thumb )
290310
291- # resolution for travis testing, not the correct ones
292- ioconfig = IOSegmentorConfig (
293- input_resolutions = [{"units" : "mpp" , "resolution" : resolution }],
294- output_resolutions = [
295- {"units" : "mpp" , "resolution" : resolution },
296- {"units" : "mpp" , "resolution" : resolution },
297- {"units" : "mpp" , "resolution" : resolution },
298- ],
299- margin = 128 ,
300- tile_shape = [512 , 512 ],
301- patch_input_shape = [256 , 256 ],
302- patch_output_shape = [164 , 164 ],
303- stride_shape = [164 , 164 ],
304- )
305-
306311 save_dir = f"{ root_save_dir } /instance/"
307- # * test run on tile, run without worker first
308- _rm_dir (save_dir )
309- inst_segmentor = NucleusInstanceSegmentor (
310- batch_size = 1 ,
311- num_loader_workers = 0 ,
312- num_postproc_workers = 0 ,
313- pretrained_model = "hovernet_fast-pannuke" ,
314- )
315- inst_segmentor .predict (
316- [mini_wsi_jpg ],
317- mode = "tile" ,
318- ioconfig = ioconfig ,
319- on_gpu = ON_GPU ,
320- crash_on_exception = True ,
321- save_dir = save_dir ,
322- )
323312
324313 # * test run on wsi, test run with worker
325314 # resolution for travis testing, not the correct ones
@@ -337,6 +326,7 @@ def test_functionality_travis(remote_sample, tmp_path):
337326 )
338327
339328 _rm_dir (save_dir )
329+
340330 inst_segmentor = NucleusInstanceSegmentor (
341331 batch_size = 1 ,
342332 num_loader_workers = 0 ,
@@ -359,7 +349,7 @@ def test_functionality_travis(remote_sample, tmp_path):
359349def test_functionality_merge_tile_predictions_travis (remote_sample , tmp_path ):
360350 """Functional tests for merging tile predictions."""
361351 save_dir = pathlib .Path (f"{ tmp_path } /output" )
362- mini_wsi_svs = pathlib .Path (remote_sample ("wsi4_1k_1k_svs " ))
352+ mini_wsi_svs = pathlib .Path (remote_sample ("wsi4_512_512_svs " ))
363353
364354 resolution = 0.5
365355 ioconfig = IOSegmentorConfig (
@@ -439,6 +429,7 @@ def test_functionality_merge_tile_predictions_travis(remote_sample, tmp_path):
439429 )
440430
441431 # test exception flag
432+ tile_flag = [0 , 0 , 0 , 0 ]
442433 with pytest .raises (ValueError , match = r".*Unknown tile mode.*" ):
443434 _process_tile_predictions (
444435 ioconfig = ioconfig ,
@@ -525,3 +516,97 @@ def test_functionality_local(remote_sample, tmp_path):
525516 score = f1_detection (inst_coords_b , inst_coords_a , radius = 1.0 )
526517 assert score > 0.9 , "Heavy loss of precision!"
527518 _rm_dir (tmp_path )
519+
520+
521+ def test_cli_nucleus_instance_segment_ioconfig (remote_sample , tmp_path ):
522+ """Test for nucleus segmentation with IOconfig."""
523+ mini_wsi_svs = pathlib .Path (remote_sample ("wsi4_512_512_svs" ))
524+ output_path = tmp_path / "output"
525+
526+ resolution = 2.0
527+
528+ reader = WSIReader .open (mini_wsi_svs )
529+ thumb = reader .slide_thumbnail (resolution = resolution , units = "mpp" )
530+ mini_wsi_jpg = f"{ tmp_path } /mini_svs.jpg"
531+ imwrite (mini_wsi_jpg , thumb )
532+
533+ fetch_pretrained_weights (
534+ "hovernet_fast-pannuke" , str (tmp_path .joinpath ("hovernet_fast-pannuke.pth" ))
535+ )
536+
537+ # resolution for travis testing, not the correct ones
538+ config = {
539+ "input_resolutions" : [{"units" : "mpp" , "resolution" : resolution }],
540+ "output_resolutions" : [
541+ {"units" : "mpp" , "resolution" : resolution },
542+ {"units" : "mpp" , "resolution" : resolution },
543+ {"units" : "mpp" , "resolution" : resolution },
544+ ],
545+ "margin" : 128 ,
546+ "tile_shape" : [512 , 512 ],
547+ "patch_input_shape" : [256 , 256 ],
548+ "patch_output_shape" : [164 , 164 ],
549+ "stride_shape" : [164 , 164 ],
550+ "save_resolution" : {"units" : "mpp" , "resolution" : 8.0 },
551+ }
552+
553+ with open (tmp_path .joinpath ("config.yaml" ), "w" ) as fptr :
554+ yaml .dump (config , fptr )
555+
556+ runner = CliRunner ()
557+ nucleus_instance_segment_result = runner .invoke (
558+ cli .main ,
559+ [
560+ "nucleus-instance-segment" ,
561+ "--img-input" ,
562+ str (mini_wsi_jpg ),
563+ "--pretrained-weights" ,
564+ str (tmp_path .joinpath ("hovernet_fast-pannuke.pth" )),
565+ "--num-loader-workers" ,
566+ str (0 ),
567+ "--num-postproc-workers" ,
568+ str (0 ),
569+ "--mode" ,
570+ "tile" ,
571+ "--output-path" ,
572+ str (output_path ),
573+ "--yaml-config-path" ,
574+ tmp_path .joinpath ("config.yaml" ),
575+ ],
576+ )
577+
578+ assert nucleus_instance_segment_result .exit_code == 0
579+ assert output_path .joinpath ("0.dat" ).exists ()
580+ assert output_path .joinpath ("file_map.dat" ).exists ()
581+ assert output_path .joinpath ("results.json" ).exists ()
582+ _rm_dir (tmp_path )
583+
584+
585+ def test_cli_nucleus_instance_segment (remote_sample , tmp_path ):
586+ """Test for nucleus segmentation."""
587+ mini_wsi_svs = pathlib .Path (remote_sample ("wsi4_512_512_svs" ))
588+ output_path = tmp_path / "output"
589+
590+ runner = CliRunner ()
591+ nucleus_instance_segment_result = runner .invoke (
592+ cli .main ,
593+ [
594+ "nucleus-instance-segment" ,
595+ "--img-input" ,
596+ str (mini_wsi_svs ),
597+ "--mode" ,
598+ "wsi" ,
599+ "--num-loader-workers" ,
600+ str (0 ),
601+ "--num-postproc-workers" ,
602+ str (0 ),
603+ "--output-path" ,
604+ str (output_path ),
605+ ],
606+ )
607+
608+ assert nucleus_instance_segment_result .exit_code == 0
609+ assert output_path .joinpath ("0.dat" ).exists ()
610+ assert output_path .joinpath ("file_map.dat" ).exists ()
611+ assert output_path .joinpath ("results.json" ).exists ()
612+ _rm_dir (tmp_path )
0 commit comments