Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions tiatoolbox/models/architecture/hovernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from scipy.ndimage import measurements
from scipy.ndimage.morphology import binary_fill_holes
from scipy import ndimage
from skimage.morphology import remove_small_objects
from skimage.segmentation import watershed

Expand Down Expand Up @@ -512,7 +511,7 @@ def _proc_np_hv(np_map: np.ndarray, hv_map: np.ndarray, scale_factor: float = 1)
# processing
blb = np.array(blb_raw >= 0.5, dtype=np.int32)

blb = measurements.label(blb)[0]
blb = ndimage.label(blb)[0]
blb = remove_small_objects(blb, min_size=10)
blb[blb > 0] = 1 # background is 0 already

Expand Down Expand Up @@ -573,10 +572,10 @@ def _proc_np_hv(np_map: np.ndarray, hv_map: np.ndarray, scale_factor: float = 1)

marker = blb - overall
marker[marker < 0] = 0
marker = binary_fill_holes(marker).astype("uint8")
marker = ndimage.binary_fill_holes(marker).astype("uint8")
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel)
marker = measurements.label(marker)[0]
marker = ndimage.label(marker)[0]
marker = remove_small_objects(marker, min_size=obj_size)

return watershed(dist, markers=marker, mask=blb)
Expand Down
47 changes: 40 additions & 7 deletions tiatoolbox/models/architecture/hovernetplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from skimage import morphology

from tiatoolbox.models.architecture.hovernet import HoVerNet
from tiatoolbox.models.architecture.utils import UpSample2x
Expand Down Expand Up @@ -114,8 +115,9 @@ def __init__(
def _proc_ls(ls_map: np.ndarray):
"""Extract Layer Segmentation map with LS Map.

This function takes the layer segmentation map and applies a
gaussian blur to remove spurious segmentations.
This function takes the layer segmentation map and applies various morphological
operations remove spurious segmentations. Note, this processing is specific to
oral epithelium, where prioirty is given to certain tissue layers.

Args:
ls_map:
Expand All @@ -126,10 +128,41 @@ def _proc_ls(ls_map: np.ndarray):
The processed segmentation map.

"""
ls_map = np.squeeze(ls_map.astype("float32"))
ls_map = cv2.GaussianBlur(ls_map, (7, 7), 0)
ls_map = np.around(ls_map)
return ls_map.astype("int")
ls_map = np.squeeze(ls_map)
ls_map = np.around(ls_map).astype("uint8") # ensure all numbers are integers
min_size = 20000
kernel_size = 20

epith_all = np.where(ls_map >= 2, 1, 0).astype("uint8")
mask = np.where(ls_map >= 1, 1, 0).astype("uint8")
epith_all = epith_all > 0
epith_mask = morphology.remove_small_objects(
epith_all, min_size=min_size
).astype("uint8")
epith_edited = epith_mask * ls_map
epith_edited = epith_edited.astype("uint8")
epith_edited_open = np.zeros_like(epith_edited).astype("uint8")
for i in [3, 2, 4]:
tmp = np.where(epith_edited == i, 1, 0).astype("uint8")
ep_open = cv2.morphologyEx(
tmp, cv2.MORPH_CLOSE, np.ones((kernel_size, kernel_size))
)
ep_open = cv2.morphologyEx(
ep_open, cv2.MORPH_OPEN, np.ones((kernel_size, kernel_size))
)
epith_edited_open[ep_open == 1] = i

mask_open = cv2.morphologyEx(
mask, cv2.MORPH_CLOSE, np.ones((kernel_size, kernel_size))
)
mask_open = cv2.morphologyEx(
mask_open, cv2.MORPH_OPEN, np.ones((kernel_size, kernel_size))
).astype("uint8")
ls_map = mask_open.copy()
for i in range(2, 5):
ls_map[epith_edited_open == i] = i

return ls_map.astype("uint8")

@staticmethod
def _get_layer_info(pred_layer):
Expand Down Expand Up @@ -261,7 +294,7 @@ def postproc(raw_maps: List[np.ndarray]):
# fx=0.5 as nuclear processing is at 0.5 mpp instead of 0.25 mpp

pred_layer = HoVerNetPlus._proc_ls(ls_map)
pred_type = tp_map
pred_type = np.around(tp_map).astype("uint8")

nuc_inst_info_dict = HoVerNet.get_instance_info(pred_inst, pred_type)
layer_info_dict = HoVerNetPlus._get_layer_info(pred_layer)
Expand Down