2626# replace with the sql database once the PR is in place
2727import joblib
2828import numpy as np
29- import pygeos
3029import torch
3130import tqdm
31+ from shapely .geometry import box as shapely_box
32+ from shapely .strtree import STRtree
3233
3334from tiatoolbox .models .engine .semantic_segmentor import (
3435 IOSegmentorConfig ,
@@ -134,32 +135,31 @@ def _process_tile_predictions(
134135 if len (inst_dict ) == 0 :
135136 return {}, []
136137
137- # ! DEPRECATION:
138- # ! will be deprecated upon finalization of SQL annotation store
138+ # !
139139 m = ioconfig .margin
140140 w , h = tile_shape
141141 inst_boxes = [v ["box" ] for v in inst_dict .values ()]
142142 inst_boxes = np .array (inst_boxes )
143- tile_rtree = pygeos . STRtree (
144- pygeos . box (
145- inst_boxes [:, 0 ], inst_boxes [:, 1 ], inst_boxes [:, 2 ], inst_boxes [:, 3 ]
146- )
147- )
143+
144+ geometries = [ shapely_box ( * bounds ) for bounds in inst_boxes ]
145+ # An auxiliary dictionary to actually query the index within the source list
146+ index_by_id = { id ( geo ): idx for idx , geo in enumerate ( geometries )}
147+ tile_rtree = STRtree ( geometries )
148148 # !
149149
150150 # create margin bounding box, ordering should match with
151151 # created tile info flag (top, bottom, left, right)
152152 boundary_lines = [
153- pygeos . box (0 , 0 , w , 1 ), # noqa top egde
154- pygeos . box (0 , h - 1 , w , h ), # noqa bottom edge
155- pygeos . box (0 , 0 , 1 , h ), # noqa left
156- pygeos . box (w - 1 , 0 , w , h ), # noqa right
153+ shapely_box (0 , 0 , w , 1 ), # noqa top egde
154+ shapely_box (0 , h - 1 , w , h ), # noqa bottom edge
155+ shapely_box (0 , 0 , 1 , h ), # noqa left
156+ shapely_box (w - 1 , 0 , w , h ), # noqa right
157157 ]
158158 margin_boxes = [
159- pygeos . box (0 , 0 , w , m ), # noqa top egde
160- pygeos . box (0 , h - m , w , h ), # noqa bottom edge
161- pygeos . box (0 , 0 , m , h ), # noqa left
162- pygeos . box (w - m , 0 , w , h ), # noqa right
159+ shapely_box (0 , 0 , w , m ), # noqa top egde
160+ shapely_box (0 , h - m , w , h ), # noqa bottom edge
161+ shapely_box (0 , 0 , m , h ), # noqa left
162+ shapely_box (w - m , 0 , w , h ), # noqa right
163163 ]
164164 # ! this is wrt WSI coord space, not tile
165165 margin_lines = [
@@ -169,7 +169,7 @@ def _process_tile_predictions(
169169 [[w - m , m ], [w - m , h - m ]], # noqa right
170170 ]
171171 margin_lines = np .array (margin_lines ) + tile_tl [None , None ]
172- margin_lines = [pygeos . box (* v .flatten ().tolist ()) for v in margin_lines ]
172+ margin_lines = [shapely_box (* v .flatten ().tolist ()) for v in margin_lines ]
173173
174174 # the ids within this match with those within `inst_map`, not UUID
175175 sel_indices = []
@@ -182,8 +182,12 @@ def _process_tile_predictions(
182182 for idx , box in enumerate (margin_boxes )
183183 if tile_flag [idx ] or tile_mode == 3
184184 ]
185+
185186 sel_indices = [
186- tile_rtree .query (bounds , predicate = "contains" ) for bounds in sel_boxes
187+ index_by_id [id (geo )]
188+ for bounds in sel_boxes
189+ for geo in tile_rtree .query (bounds )
190+ if bounds .contains (geo )
187191 ]
188192 elif tile_mode in [1 , 2 ]:
189193 # for `horizontal/vertical strip` tiles
@@ -196,8 +200,11 @@ def _process_tile_predictions(
196200 margin_boxes [idx ] if flag else boundary_lines [idx ]
197201 for idx , flag in enumerate (tile_flag )
198202 ]
203+
199204 sel_indices = [
200- tile_rtree .query (bounds , predicate = "intersects" ) for bounds in sel_boxes
205+ index_by_id [id (geo )]
206+ for bounds in sel_boxes
207+ for geo in tile_rtree .query (bounds )
201208 ]
202209 else :
203210 raise ValueError (f"Unknown tile mode { tile_mode } ." )
@@ -208,7 +215,6 @@ def retrieve_sel_uids(sel_indices, inst_dict):
208215 if len (sel_indices ) > 0 :
209216 # not sure how costly this is in large dict
210217 inst_uids = list (inst_dict .keys ())
211- sel_indices = [idx for sub_sel in sel_indices for idx in sub_sel ]
212218 sel_uids = [inst_uids [idx ] for idx in sel_indices ]
213219 return sel_uids
214220
@@ -218,26 +224,19 @@ def retrieve_sel_uids(sel_indices, inst_dict):
218224 # this one should contain UUID with the reference database
219225 remove_insts_in_orig = []
220226 if tile_mode == 3 :
221- # ! DEPRECATION:
222- # ! will be deprecated upon finalization of SQL annotation store
223227 inst_boxes = [v ["box" ] for v in ref_inst_dict .values ()]
224228 inst_boxes = np .array (inst_boxes )
225- ref_inst_rtree = pygeos .STRtree (
226- pygeos .box (
227- inst_boxes [:, 0 ],
228- inst_boxes [:, 1 ],
229- inst_boxes [:, 2 ],
230- inst_boxes [:, 3 ],
231- )
232- )
233- # !
234229
235- # remove existing instances in old prediction which intersect
236- # with the margin lines
230+ geometries = [shapely_box (* bounds ) for bounds in inst_boxes ]
231+ # An auxiliary dictionary to actually query the index within the source list
232+ index_by_id = {id (geo ): idx for idx , geo in enumerate (geometries )}
233+ ref_inst_rtree = STRtree (geometries )
237234 sel_indices = [
238- ref_inst_rtree . query ( bounds , predicate = "intersects" )
235+ index_by_id [ id ( geo )]
239236 for bounds in margin_lines
237+ for geo in ref_inst_rtree .query (bounds )
240238 ]
239+
241240 remove_insts_in_orig = retrieve_sel_uids (sel_indices , ref_inst_dict )
242241
243242 # move inst position from tile space back to WSI space
@@ -390,24 +389,30 @@ def _get_tile_info(
390389 # * === be removed in postproc callback
391390 boxes = tile_outputs
392391
392+ # This saves computation time if the image is smaller than the expected tile
393+ if np .all (image_shape <= tile_shape ):
394+ flag = np .zeros ([boxes .shape [0 ], 4 ], dtype = np .int32 )
395+ return [[boxes , flag ]]
396+
393397 # * remove all sides for boxes
394398 # unset for those that lie within the selection
395399 def unset_removal_flag (boxes , removal_flag ):
396400 """Unset removal flags for tiles intersecting image boundaries."""
397- # ! DEPRECATION:
398- # ! will be deprecated upon finalization of SQL annotation store
399401 sel_boxes = [
400- pygeos . box (0 , 0 , w , 0 ), # top edge
401- pygeos . box (0 , h , w , h ), # bottom edge
402- pygeos . box (0 , 0 , 0 , h ), # left
403- pygeos . box (w , 0 , w , h ), # right
402+ shapely_box (0 , 0 , w , 0 ), # top edge
403+ shapely_box (0 , h , w , h ), # bottom edge
404+ shapely_box (0 , 0 , 0 , h ), # left
405+ shapely_box (w , 0 , w , h ), # right
404406 ]
405- spatial_indexer = pygeos .STRtree (
406- pygeos .box (boxes [:, 0 ], boxes [:, 1 ], boxes [:, 2 ], boxes [:, 3 ])
407- )
408- # !
407+ geometries = [shapely_box (* bounds ) for bounds in boxes ]
408+ # An auxiliary dictionary to actually query the index within the source list
409+ index_by_id = {id (geo ): idx for idx , geo in enumerate (geometries )}
410+ spatial_indexer = STRtree (geometries )
411+
409412 for idx , sel_box in enumerate (sel_boxes ):
410- sel_indices = spatial_indexer .query (sel_box )
413+ sel_indices = [
414+ index_by_id [id (geo )] for geo in spatial_indexer .query (sel_box )
415+ ]
411416 removal_flag [sel_indices , idx ] = 0
412417 return removal_flag
413418
@@ -621,17 +626,10 @@ def _predict_one_wsi(
621626 patch_inputs = patch_inputs [sel ]
622627
623628 # assume to be in [top_left_x, top_left_y, bot_right_x, bot_right_y]
624- # ! DEPRECATION:
625- # ! will be deprecated upon finalization of SQL annotation store
626- spatial_indexer = pygeos .STRtree (
627- pygeos .box (
628- patch_outputs [:, 0 ],
629- patch_outputs [:, 1 ],
630- patch_outputs [:, 2 ],
631- patch_outputs [:, 3 ],
632- )
633- )
634- # !
629+ geometries = [shapely_box (* bounds ) for bounds in patch_outputs ]
630+ # An auxiliary dictionary to actually query the index within the source list
631+ index_by_id = {id (geo ): idx for idx , geo in enumerate (geometries )}
632+ spatial_indexer = STRtree (geometries )
635633
636634 # * retrieve tile placement and tile info flag
637635 # tile shape will always be corrected to be multiple of output
@@ -651,7 +649,11 @@ def _predict_one_wsi(
651649
652650 # select any patches that have their output
653651 # within the current tile
654- sel_indices = spatial_indexer .query (pygeos .box (* tile_bounds ))
652+ sel_box = shapely_box (* tile_bounds )
653+ sel_indices = [
654+ index_by_id [id (geo )] for geo in spatial_indexer .query (sel_box )
655+ ]
656+
655657 tile_patch_inputs = patch_inputs [sel_indices ]
656658 tile_patch_outputs = patch_outputs [sel_indices ]
657659 self ._to_shared_space (wsi_idx , tile_patch_inputs , tile_patch_outputs )
0 commit comments