Skip to content

Commit 11a0acb

Browse files
authored
ENH: Swap pygeos to shapely (#202)
- Replace pygeos with shapely to resolve the GEO binary installation issue. Co-authored-by: @vqdang
1 parent 03f25a9 commit 11a0acb

File tree

9 files changed

+71
-69
lines changed

9 files changed

+71
-69
lines changed

docs/installation.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,19 @@ Windows
2929
1. Download OpenSlide binaries from `this page <https://openslide.org/download/>`_. Extract the folder and add `bin` and `lib` subdirectories to
3030
Windows `system path <https://docs.microsoft.com/en-us/previous-versions/office/developer/sharepoint-2010/ee537574(v=office.14)>`_.
3131

32-
2. Install
33-
TIAToolbox.
32+
2. Install OpenJPEG. The easiest way to install OpenJPEG is through conda
33+
using
3434

3535
.. code-block:: console
3636
37-
$ pip install tiatoolbox
37+
C:\> conda install -c conda-forge openjpeg
3838
39-
3. Install OpenJPEG. The easiest way to install OpenJPEG is through conda
40-
using
39+
3. Install
40+
TIAToolbox.
4141

4242
.. code-block:: console
4343
44-
C:\> conda install -c conda-forge openjpeg pygeos
44+
C:\> pip install tiatoolbox
4545
4646
Linux (Ubuntu)
4747
^^^^^^^^^^^^^^

requirements.conda.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ dependencies:
2323
- pip
2424
- pixman<0.38.0
2525
- python=3.7
26-
- pygeos
2726
- python=3.7
2827
- pytorch
2928
- pyyaml>=5.1

requirements.dev.conda.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ dependencies:
2626
- pip
2727
- pixman<0.38.0
2828
- pre-commit
29-
- pygeos
3029
- pytest-cov
3130
- pytest-runner==5.2
3231
- pytest==6.2.5

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ opencv-python>=4.0
1010
openslide-python==1.1.2
1111
pandas
1212
pillow
13-
pygeos
1413
pyyaml>=5.1
1514
requests
1615
scikit-image

requirements.win64.conda.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ dependencies:
2121
- pillow
2222
- pip
2323
- pixman<0.38.0
24-
- pygeos
2524
- python=3.7
2625
- pytorch
2726
- pyyaml>=5.1

requirements_dev.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ pandas
1616
pillow
1717
pip>=20.0.2
1818
pre-commit
19-
pygeos
2019
pytest-cov==2.9.0
2120
pytest-runner==5.2
2221
pytest==6.2.5

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
"glymur",
2424
"scikit-learn>=0.23.2",
2525
"scikit-image>=0.17",
26+
"shapely",
2627
"torchvision==0.10.1",
2728
"torch==1.9.1",
2829
"tqdm==4.60.0",

tests/models/test_nucleus_instance_segmentor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
"""Tests for Nucleus Instance Segmentor."""
2121

2222
import copy
23+
24+
# ! gc is imported so a test can force a garbage-collection pass via gc.collect()
25+
import gc
2326
import pathlib
2427
import shutil
2528

@@ -298,6 +301,7 @@ def test_crash_segmentor(remote_sample, tmp_path):
298301

299302
def test_functionality_travis(remote_sample, tmp_path):
300303
"""Functionality test for nuclei instance segmentor."""
304+
gc.collect()
301305
root_save_dir = pathlib.Path(tmp_path)
302306
mini_wsi_svs = pathlib.Path(remote_sample("wsi4_512_512_svs"))
303307

@@ -319,7 +323,7 @@ def test_functionality_travis(remote_sample, tmp_path):
319323
{"units": "mpp", "resolution": resolution},
320324
],
321325
margin=128,
322-
tile_shape=[512, 512],
326+
tile_shape=[1024, 1024],
323327
patch_input_shape=[256, 256],
324328
patch_output_shape=[164, 164],
325329
stride_shape=[164, 164],

tiatoolbox/models/engine/nucleus_instance_segmentor.py

Lines changed: 59 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@
2626
# replace with the sql database once the PR in place
2727
import joblib
2828
import numpy as np
29-
import pygeos
3029
import torch
3130
import tqdm
31+
from shapely.geometry import box as shapely_box
32+
from shapely.strtree import STRtree
3233

3334
from tiatoolbox.models.engine.semantic_segmentor import (
3435
IOSegmentorConfig,
@@ -134,32 +135,31 @@ def _process_tile_predictions(
134135
if len(inst_dict) == 0:
135136
return {}, []
136137

137-
# ! DEPRECATION:
138-
# ! will be deprecated upon finalization of SQL annotation store
138+
# !
139139
m = ioconfig.margin
140140
w, h = tile_shape
141141
inst_boxes = [v["box"] for v in inst_dict.values()]
142142
inst_boxes = np.array(inst_boxes)
143-
tile_rtree = pygeos.STRtree(
144-
pygeos.box(
145-
inst_boxes[:, 0], inst_boxes[:, 1], inst_boxes[:, 2], inst_boxes[:, 3]
146-
)
147-
)
143+
144+
geometries = [shapely_box(*bounds) for bounds in inst_boxes]
145+
# Map id(geometry) -> its position in `geometries`: the tree query yields geometry objects, not indices
146+
index_by_id = {id(geo): idx for idx, geo in enumerate(geometries)}
147+
tile_rtree = STRtree(geometries)
148148
# !
149149

150150
# create margin bounding box, ordering should match with
151151
# created tile info flag (top, bottom, left, right)
152152
boundary_lines = [
153-
pygeos.box(0, 0, w, 1), # noqa top egde
154-
pygeos.box(0, h - 1, w, h), # noqa bottom edge
155-
pygeos.box(0, 0, 1, h), # noqa left
156-
pygeos.box(w - 1, 0, w, h), # noqa right
153+
shapely_box(0, 0, w, 1), # noqa top egde
154+
shapely_box(0, h - 1, w, h), # noqa bottom edge
155+
shapely_box(0, 0, 1, h), # noqa left
156+
shapely_box(w - 1, 0, w, h), # noqa right
157157
]
158158
margin_boxes = [
159-
pygeos.box(0, 0, w, m), # noqa top egde
160-
pygeos.box(0, h - m, w, h), # noqa bottom edge
161-
pygeos.box(0, 0, m, h), # noqa left
162-
pygeos.box(w - m, 0, w, h), # noqa right
159+
shapely_box(0, 0, w, m), # noqa top egde
160+
shapely_box(0, h - m, w, h), # noqa bottom edge
161+
shapely_box(0, 0, m, h), # noqa left
162+
shapely_box(w - m, 0, w, h), # noqa right
163163
]
164164
# ! this is wrt to WSI coord space, not tile
165165
margin_lines = [
@@ -169,7 +169,7 @@ def _process_tile_predictions(
169169
[[w - m, m], [w - m, h - m]], # noqa right
170170
]
171171
margin_lines = np.array(margin_lines) + tile_tl[None, None]
172-
margin_lines = [pygeos.box(*v.flatten().tolist()) for v in margin_lines]
172+
margin_lines = [shapely_box(*v.flatten().tolist()) for v in margin_lines]
173173

174174
# the ids within this match with those within `inst_map`, not UUID
175175
sel_indices = []
@@ -182,8 +182,12 @@ def _process_tile_predictions(
182182
for idx, box in enumerate(margin_boxes)
183183
if tile_flag[idx] or tile_mode == 3
184184
]
185+
185186
sel_indices = [
186-
tile_rtree.query(bounds, predicate="contains") for bounds in sel_boxes
187+
index_by_id[id(geo)]
188+
for bounds in sel_boxes
189+
for geo in tile_rtree.query(bounds)
190+
if bounds.contains(geo)
187191
]
188192
elif tile_mode in [1, 2]:
189193
# for `horizontal/vertical strip` tiles
@@ -196,8 +200,11 @@ def _process_tile_predictions(
196200
margin_boxes[idx] if flag else boundary_lines[idx]
197201
for idx, flag in enumerate(tile_flag)
198202
]
203+
199204
sel_indices = [
200-
tile_rtree.query(bounds, predicate="intersects") for bounds in sel_boxes
205+
index_by_id[id(geo)]
206+
for bounds in sel_boxes
207+
for geo in tile_rtree.query(bounds)
201208
]
202209
else:
203210
raise ValueError(f"Unknown tile mode {tile_mode}.")
@@ -208,7 +215,6 @@ def retrieve_sel_uids(sel_indices, inst_dict):
208215
if len(sel_indices) > 0:
209216
# not sure how costly this is in large dict
210217
inst_uids = list(inst_dict.keys())
211-
sel_indices = [idx for sub_sel in sel_indices for idx in sub_sel]
212218
sel_uids = [inst_uids[idx] for idx in sel_indices]
213219
return sel_uids
214220

@@ -218,26 +224,19 @@ def retrieve_sel_uids(sel_indices, inst_dict):
218224
# this one should contain UUID with the reference database
219225
remove_insts_in_orig = []
220226
if tile_mode == 3:
221-
# ! DEPRECATION:
222-
# ! will be deprecated upon finalization of SQL annotation store
223227
inst_boxes = [v["box"] for v in ref_inst_dict.values()]
224228
inst_boxes = np.array(inst_boxes)
225-
ref_inst_rtree = pygeos.STRtree(
226-
pygeos.box(
227-
inst_boxes[:, 0],
228-
inst_boxes[:, 1],
229-
inst_boxes[:, 2],
230-
inst_boxes[:, 3],
231-
)
232-
)
233-
# !
234229

235-
# remove existing instances in old prediction which intersect
236-
# with the margin lines
230+
geometries = [shapely_box(*bounds) for bounds in inst_boxes]
231+
# An auxiliary dictionary to actually query the index within the source list
232+
index_by_id = {id(geo): idx for idx, geo in enumerate(geometries)}
233+
ref_inst_rtree = STRtree(geometries)
237234
sel_indices = [
238-
ref_inst_rtree.query(bounds, predicate="intersects")
235+
index_by_id[id(geo)]
239236
for bounds in margin_lines
237+
for geo in ref_inst_rtree.query(bounds)
240238
]
239+
241240
remove_insts_in_orig = retrieve_sel_uids(sel_indices, ref_inst_dict)
242241

243242
# move inst position from tile space back to WSI space
@@ -390,24 +389,30 @@ def _get_tile_info(
390389
# * === be removed in postproc callback
391390
boxes = tile_outputs
392391

392+
# This saves computation time if the image is smaller than the expected tile
393+
if np.all(image_shape <= tile_shape):
394+
flag = np.zeros([boxes.shape[0], 4], dtype=np.int32)
395+
return [[boxes, flag]]
396+
393397
# * remove all sides for boxes
394398
# unset for those lie within the selection
395399
def unset_removal_flag(boxes, removal_flag):
396400
"""Unset removal flags for tiles intersecting image boundaries."""
397-
# ! DEPRECATION:
398-
# ! will be deprecated upon finalization of SQL annotation store
399401
sel_boxes = [
400-
pygeos.box(0, 0, w, 0), # top edge
401-
pygeos.box(0, h, w, h), # bottom edge
402-
pygeos.box(0, 0, 0, h), # left
403-
pygeos.box(w, 0, w, h), # right
402+
shapely_box(0, 0, w, 0), # top edge
403+
shapely_box(0, h, w, h), # bottom edge
404+
shapely_box(0, 0, 0, h), # left
405+
shapely_box(w, 0, w, h), # right
404406
]
405-
spatial_indexer = pygeos.STRtree(
406-
pygeos.box(boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3])
407-
)
408-
# !
407+
geometries = [shapely_box(*bounds) for bounds in boxes]
408+
# Map id(geometry) -> its position in `geometries`: the tree query yields geometry objects, not indices
409+
index_by_id = {id(geo): idx for idx, geo in enumerate(geometries)}
410+
spatial_indexer = STRtree(geometries)
411+
409412
for idx, sel_box in enumerate(sel_boxes):
410-
sel_indices = spatial_indexer.query(sel_box)
413+
sel_indices = [
414+
index_by_id[id(geo)] for geo in spatial_indexer.query(sel_box)
415+
]
411416
removal_flag[sel_indices, idx] = 0
412417
return removal_flag
413418

@@ -621,17 +626,10 @@ def _predict_one_wsi(
621626
patch_inputs = patch_inputs[sel]
622627

623628
# assume to be in [top_left_x, top_left_y, bot_right_x, bot_right_y]
624-
# ! DEPRECATION:
625-
# ! will be deprecated upon finalization of SQL annotation store
626-
spatial_indexer = pygeos.STRtree(
627-
pygeos.box(
628-
patch_outputs[:, 0],
629-
patch_outputs[:, 1],
630-
patch_outputs[:, 2],
631-
patch_outputs[:, 3],
632-
)
633-
)
634-
# !
629+
geometries = [shapely_box(*bounds) for bounds in patch_outputs]
630+
# Map id(geometry) -> its position in `geometries`: the tree query yields geometry objects, not indices
631+
index_by_id = {id(geo): idx for idx, geo in enumerate(geometries)}
632+
spatial_indexer = STRtree(geometries)
635633

636634
# * retrieve tile placement and tile info flag
637635
# tile shape will always be corrected to be multiple of output
@@ -651,7 +649,11 @@ def _predict_one_wsi(
651649

652650
# select any patches that have their output
653651
# within the current tile
654-
sel_indices = spatial_indexer.query(pygeos.box(*tile_bounds))
652+
sel_box = shapely_box(*tile_bounds)
653+
sel_indices = [
654+
index_by_id[id(geo)] for geo in spatial_indexer.query(sel_box)
655+
]
656+
655657
tile_patch_inputs = patch_inputs[sel_indices]
656658
tile_patch_outputs = patch_outputs[sel_indices]
657659
self._to_shared_space(wsi_idx, tile_patch_inputs, tile_patch_outputs)

0 commit comments

Comments
 (0)