Skip to content

Commit 2e9802b

Browse files
meastyshaneahmed
andauthored
❇️ Add Convert patches Output to AnnotationStore (#718)
- Adds a function `patch_pred_store` to convert the output from `PatchPredictor` into an `AnnotationStore`. --------- Co-authored-by: Shan E Ahmed Raza <[email protected]>
1 parent 931de99 commit 2e9802b

File tree

2 files changed

+135
-2
lines changed

2 files changed

+135
-2
lines changed

tests/test_utils.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from tests.test_annotation_stores import cell_polygon
2020
from tiatoolbox import utils
21+
from tiatoolbox.annotation.storage import SQLiteStore
2122
from tiatoolbox.models.architecture import fetch_pretrained_weights
2223
from tiatoolbox.utils import misc
2324
from tiatoolbox.utils.exceptions import FileNotSupportedError
@@ -734,6 +735,7 @@ def test_sub_pixel_read_incorrect_read_func_return() -> None:
734735
image = np.ones((10, 10))
735736

736737
def read_func(*args: tuple, **kwargs: dict) -> np.ndarray: # noqa: ARG001
738+
"""Dummy read function for tests."""
737739
return np.ones((5, 5))
738740

739741
with pytest.raises(ValueError, match="incorrect size"):
@@ -752,6 +754,7 @@ def test_sub_pixel_read_empty_read_func_return() -> None:
752754
image = np.ones((10, 10))
753755

754756
def read_func(*args: tuple, **kwargs: dict) -> np.ndarray: # noqa: ARG001
757+
"""Dummy read function for tests."""
755758
return np.ones((0, 0))
756759

757760
with pytest.raises(ValueError, match="is empty"):
@@ -1642,3 +1645,69 @@ def test_imwrite(tmp_path: Path) -> NoReturn:
16421645
tmp_path / "thisfolderdoesnotexist" / "test_imwrite.jpg",
16431646
img,
16441647
)
1648+
1649+
1650+
def test_patch_pred_store() -> None:
1651+
"""Test patch_pred_store."""
1652+
# Define a mock patch_output
1653+
patch_output = {
1654+
"predictions": [1, 0, 1],
1655+
"coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)],
1656+
"other": "other",
1657+
}
1658+
1659+
store = misc.patch_pred_store(patch_output, (1.0, 1.0))
1660+
1661+
# Check that its an SQLiteStore containing the expected annotations
1662+
assert isinstance(store, SQLiteStore)
1663+
assert len(store) == 3
1664+
for annotation in store.values():
1665+
assert annotation.geometry.area == 1
1666+
assert annotation.properties["type"] in [0, 1]
1667+
assert "other" not in annotation.properties
1668+
1669+
patch_output.pop("coordinates")
1670+
# check correct error is raised if coordinates are missing
1671+
with pytest.raises(ValueError, match="coordinates"):
1672+
misc.patch_pred_store(patch_output, (1.0, 1.0))
1673+
1674+
1675+
def test_patch_pred_store_cdict() -> None:
1676+
"""Test patch_pred_store with a class dict."""
1677+
# Define a mock patch_output
1678+
patch_output = {
1679+
"predictions": [1, 0, 1],
1680+
"coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)],
1681+
"probabilities": [[0.1, 0.9], [0.9, 0.1], [0.4, 0.6]],
1682+
"labels": [1, 0, 1],
1683+
"other": "other",
1684+
}
1685+
class_dict = {0: "class0", 1: "class1"}
1686+
store = misc.patch_pred_store(patch_output, (1.0, 1.0), class_dict=class_dict)
1687+
1688+
# Check that its an SQLiteStore containing the expected annotations
1689+
assert isinstance(store, SQLiteStore)
1690+
assert len(store) == 3
1691+
for annotation in store.values():
1692+
assert annotation.geometry.area == 1
1693+
assert annotation.properties["label"] in ["class0", "class1"]
1694+
assert annotation.properties["type"] in ["class0", "class1"]
1695+
assert "other" not in annotation.properties
1696+
1697+
1698+
def test_patch_pred_store_sf() -> None:
1699+
"""Test patch_pred_store with scale factor."""
1700+
# Define a mock patch_output
1701+
patch_output = {
1702+
"predictions": [1, 0, 1],
1703+
"coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)],
1704+
"probabilities": [[0.1, 0.9], [0.9, 0.1], [0.4, 0.6]],
1705+
"labels": [1, 0, 1],
1706+
}
1707+
store = misc.patch_pred_store(patch_output, (2.0, 2.0))
1708+
1709+
# Check that its an SQLiteStore containing the expected annotations
1710+
assert isinstance(store, SQLiteStore)
1711+
assert len(store) == 3
1712+
for annotation in store.values():
1713+
assert annotation.geometry.area == 4

tiatoolbox/utils/misc.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import yaml
1919
from filelock import FileLock
2020
from shapely.affinity import translate
21+
from shapely.geometry import Polygon
2122
from shapely.geometry import shape as feature2geometry
2223
from skimage import exposure
2324

@@ -860,7 +861,8 @@ def select_device(*, on_gpu: bool) -> str:
860861
"""Selects the appropriate device as requested.
861862
862863
Args:
863-
on_gpu (bool): Selects gpu if True.
864+
on_gpu (bool):
865+
Selects gpu if True.
864866
865867
Returns:
866868
str:
@@ -883,7 +885,6 @@ def model_to(model: torch.nn.Module, *, on_gpu: bool) -> torch.nn.Module:
883885
Returns:
884886
torch.nn.Module:
885887
The model after being moved to cpu/gpu.
886-
887888
"""
888889
if on_gpu: # DataParallel work only for cuda
889890
model = torch.nn.DataParallel(model)
@@ -1194,3 +1195,66 @@ def add_from_dat(
11941195

11951196
logger.info("Added %d annotations.", len(anns))
11961197
store.append_many(anns)
1198+
1199+
1200+
def patch_pred_store(
1201+
patch_output: dict,
1202+
scale_factor: tuple[int, int],
1203+
class_dict: dict | None = None,
1204+
) -> AnnotationStore:
1205+
"""Create an SQLiteStore containing Annotations for each patch.
1206+
1207+
Args:
1208+
patch_output (dict): A dictionary of patch prediction information. Important
1209+
keys are "probabilities", "predictions", "coordinates", and "labels".
1210+
scale_factor (tuple[int, int]): The scale factor to use when loading the
1211+
annotations. All coordinates will be multiplied by this factor to allow
1212+
conversion of annotations saved at non-baseline resolution to baseline.
1213+
Should be model_mpp/slide_mpp.
1214+
class_dict (dict): Optional dictionary mapping class indices to class names.
1215+
1216+
Returns:
1217+
SQLiteStore: An SQLiteStore containing Annotations for each patch.
1218+
1219+
"""
1220+
if "coordinates" not in patch_output:
1221+
# we cant create annotations without coordinates
1222+
msg = "Patch output must contain coordinates."
1223+
raise ValueError(msg)
1224+
# get relevant keys
1225+
class_probs = patch_output.get("probabilities", [])
1226+
preds = patch_output.get("predictions", [])
1227+
patch_coords = np.array(patch_output.get("coordinates", []))
1228+
if not np.all(np.array(scale_factor) == 1):
1229+
patch_coords = patch_coords * (np.tile(scale_factor, 2)) # to baseline mpp
1230+
labels = patch_output.get("labels", [])
1231+
# get classes to consider
1232+
if len(class_probs) == 0:
1233+
classes_predicted = np.unique(preds).tolist()
1234+
else:
1235+
classes_predicted = range(len(class_probs[0]))
1236+
if class_dict is None:
1237+
# if no class dict create a default one
1238+
class_dict = {i: i for i in np.unique(preds + labels).tolist()}
1239+
annotations = []
1240+
# find what keys we need to save
1241+
keys = ["predictions"]
1242+
keys = keys + [key for key in ["probabilities", "labels"] if key in patch_output]
1243+
1244+
# put patch predictions into a store
1245+
annotations = []
1246+
for i, pred in enumerate(preds):
1247+
if "probabilities" in keys:
1248+
props = {
1249+
f"prob_{class_dict[j]}": class_probs[i][j] for j in classes_predicted
1250+
}
1251+
else:
1252+
props = {}
1253+
if "labels" in keys:
1254+
props["label"] = class_dict[labels[i]]
1255+
props["type"] = class_dict[pred]
1256+
annotations.append(Annotation(Polygon.from_bounds(*patch_coords[i]), props))
1257+
store = SQLiteStore()
1258+
keys = store.append_many(annotations, [str(i) for i in range(len(annotations))])
1259+
1260+
return store

0 commit comments

Comments
 (0)