Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ Engine

- :obj:`Patch Prediction <tiatoolbox.models.engine.patch_predictor.PatchPredictor>`
- :obj:`Semantic Segmentation <tiatoolbox.models.engine.semantic_segmentor.SemanticSegmentor>`
- :obj:`Feature Extraction <tiatoolbox.models.engine.semantic_segmentor.FeatureExtractor>`
- :obj:`Feature Extraction <tiatoolbox.models.engine.semantic_segmentor.DeepFeatureExtractor>`
- :obj:`Nucleus Instance Segmnetation <tiatoolbox.models.engine.nucleus_instance_segmentor.NucleusInstanceSegmentor>`

----------------------------
Expand Down
10 changes: 5 additions & 5 deletions examples/full-pipelines/slide-graph.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,8 @@
"metadata": {},
"outputs": [],
"source": [
"from tiatoolbox.models import FeatureExtractor, IOSegmentorConfig\n",
"from tiatoolbox.models.architecture import CNNExtractor\n",
"from tiatoolbox.models import DeepFeatureExtractor, IOSegmentorConfig\n",
"from tiatoolbox.models.architecture import CNNBackbone\n",
"\n",
"\n",
"def extract_deep_features(\n",
Expand All @@ -553,8 +553,8 @@
" stride_shape=[512, 512],\n",
" save_resolution={\"units\": \"mpp\", \"resolution\": 8.0},\n",
" )\n",
" model = CNNExtractor(\"resnet50\")\n",
" extractor = FeatureExtractor(\n",
" model = CNNBackbone(\"resnet50\")\n",
" extractor = DeepFeatureExtractor(\n",
" batch_size=16, model=model, num_loader_workers=4)\n",
" # Injecting customized preprocessing functions,\n",
" # check the document or sample code below for API.\n",
Expand Down Expand Up @@ -603,7 +603,7 @@
"and count the nuclei of each type in each patch.\n",
"We encapsulate this process in the function `get_composition_features`.\n",
"\n",
"Unlike the `FeatureExtractor` above, the `NucleusInstanceSegmentor` engine\n",
"Unlike the `DeepFeatureExtractor` above, the `NucleusInstanceSegmentor` engine\n",
"returns a single output file when given a single WSI input. Their corresponding\n",
"output files are named as `['*/0.dat', '*/1.dat', etc.]` and we need to rename\n",
"them accordingly. We generate the cell composition features from each\n",
Expand Down
12 changes: 6 additions & 6 deletions examples/inference-pipelines/slide-graph.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -443,8 +443,8 @@
"metadata": {},
"outputs": [],
"source": [
"from tiatoolbox.models import FeatureExtractor, IOSegmentorConfig\n",
"from tiatoolbox.models.architecture import CNNExtractor\n",
"from tiatoolbox.models import DeepFeatureExtractor, IOSegmentorConfig\n",
"from tiatoolbox.models.architecture import CNNBackbone\n",
"\n",
"\n",
"def extract_deep_features(\n",
Expand All @@ -464,8 +464,8 @@
" stride_shape=[512, 512],\n",
" save_resolution={\"units\": \"mpp\", \"resolution\": 8.0},\n",
" )\n",
" model = CNNExtractor(\"resnet50\")\n",
" extractor = FeatureExtractor(\n",
" model = CNNBackbone(\"resnet50\")\n",
" extractor = DeepFeatureExtractor(\n",
" batch_size=32, model=model, num_loader_workers=4)\n",
" # Injecting customized preprocessing functions,\n",
" # check the document or sample code below for API.\n",
Expand Down Expand Up @@ -514,10 +514,10 @@
"and count the nuclei of each type in each patch.\n",
"We encapsulate this process in the function `get_composition_features`.\n",
"\n",
"Unlike the `FeatureExtractor` above, the `NucleusInstanceSegmentor` engine\n",
"Unlike the `DeepFeatureExtractor` above, the `NucleusInstanceSegmentor` engine\n",
"returns a single output file given a single WSI input. The corresponding output\n",
"files are named as `['*/0.dat', '*/1.dat', etc.]`. Each of these `.dat` files is used to generate\n",
"two files named `*.features.npy` and `*.position.npy`. As in the case of `FeatureExtractor`,\n",
"two files named `*.features.npy` and `*.position.npy`. As in the case of `DeepFeatureExtractor`,\n",
"the wildcard _* is, by default, replaced by sequentially ordered names,\n",
"for easier management and to avoid inadvertent overwriting."
]
Expand Down
10 changes: 5 additions & 5 deletions tests/models/test_feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
import numpy as np
import torch

from tiatoolbox.models.architecture.vanilla import CNNExtractor
from tiatoolbox.models.architecture.vanilla import CNNBackbone
from tiatoolbox.models.engine.semantic_segmentor import (
FeatureExtractor,
DeepFeatureExtractor,
IOSegmentorConfig,
)
from tiatoolbox.wsicore.wsireader import get_wsireader
Expand Down Expand Up @@ -57,7 +57,7 @@ def test_functional(remote_sample, tmp_path):

# * test providing pretrained from torch vs pretrained_model.yaml
_rm_dir(save_dir) # default output dir test
extractor = FeatureExtractor(batch_size=1, pretrained_model="fcn-tissue_mask")
extractor = DeepFeatureExtractor(batch_size=1, pretrained_model="fcn-tissue_mask")
output_list = extractor.predict(
[mini_wsi_svs],
mode="wsi",
Expand Down Expand Up @@ -87,8 +87,8 @@ def test_functional(remote_sample, tmp_path):
save_resolution={"units": "mpp", "resolution": 8.0},
)

model = CNNExtractor("resnet50")
extractor = FeatureExtractor(batch_size=4, model=model)
model = CNNBackbone("resnet50")
extractor = DeepFeatureExtractor(batch_size=4, model=model)
# should still run because we skip exception
output_list = extractor.predict(
[mini_wsi_svs],
Expand Down
2 changes: 1 addition & 1 deletion tiatoolbox/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
WSIPatchDataset,
)
from tiatoolbox.models.engine.semantic_segmentor import (
FeatureExtractor,
DeepFeatureExtractor,
IOSegmentorConfig,
SemanticSegmentor,
WSIStreamDataset,
Expand Down
2 changes: 1 addition & 1 deletion tiatoolbox/models/architecture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import torch

from tiatoolbox import rcParam
from tiatoolbox.models.architecture.vanilla import CNNExtractor, CNNModel
from tiatoolbox.models.architecture.vanilla import CNNBackbone, CNNModel
from tiatoolbox.models.dataset.classification import predefined_preproc_func
from tiatoolbox.utils.misc import download_data

Expand Down
4 changes: 2 additions & 2 deletions tiatoolbox/models/architecture/vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def infer_batch(model, batch_data, on_gpu):
return output.cpu().numpy()


class CNNExtractor(ModelABC):
class CNNBackbone(ModelABC):
"""Retrieve the model backbone and strip the classification layer.

This is a wrapper for pretrained models within pytorch.
Expand Down Expand Up @@ -193,7 +193,7 @@ class CNNExtractor(ModelABC):
>>> # Creating resnet50 architecture from default pytorch
>>> # without the classification layer with its associated
>>> # weights loaded
>>> model = CNNExtractor(backbone="resnet50")
>>> model = CNNBackbone(backbone="resnet50")
>>> model.eval() # set to evaluation mode
>>> # dummy sample in NHWC form
>>> samples = torch.random.rand(4, 3, 512, 512)
Expand Down
16 changes: 8 additions & 8 deletions tiatoolbox/models/engine/semantic_segmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,7 +1116,7 @@ def predict(
return outputs


class FeatureExtractor(SemanticSegmentor):
class DeepFeatureExtractor(SemanticSegmentor):
"""Generic CNN Feature Extractor.

A engine for using any CNN model as a feature extractor.
Expand All @@ -1131,7 +1131,7 @@ class FeatureExtractor(SemanticSegmentor):
for processing the data. By default, the corresponding pretrained weights
will also be downloaded. However, you can override with your own set of
weights via the `pretrained_weights` argument. Argument is case insensitive.
Refer to :class:`tiatoolbox.models.architecture.vanilla.CNNExtractor`
Refer to :class:`tiatoolbox.models.architecture.vanilla.CNNBackbone`
for list of supported pretrained models.
pretrained_weights (str): Path to the weight of the corresponding
`pretrained_model`.
Expand All @@ -1148,11 +1148,11 @@ class FeatureExtractor(SemanticSegmentor):

Examples:
>>> # Sample output of a network
>>> from tiatoolbox.models.architecture.vanilla import CNNExtractor
>>> from tiatoolbox.models.architecture.vanilla import CNNBackbone
>>> wsis = ['A/wsi.svs', 'B/wsi.svs']
>>> # create resnet50 with pytorch pretrained weights
>>> model = CNNExtractor('resnet50')
>>> predictor = FeatureExtractor(model=model)
>>> model = CNNBackbone('resnet50')
>>> predictor = DeepFeatureExtractor(model=model)
>>> output = predictor.predict(wsis, mode='wsi')
>>> list(output.keys())
[('A/wsi.svs', 'output/0') , ('B/wsi.svs', 'output/1')]
Expand Down Expand Up @@ -1296,11 +1296,11 @@ def predict(

Examples:
>>> # Sample output of a network
>>> from tiatoolbox.models.architecture.vanilla import CNNExtractor
>>> from tiatoolbox.models.architecture.vanilla import CNNBackbone
>>> wsis = ['A/wsi.svs', 'B/wsi.svs']
>>> # create resnet50 with pytorch pretrained weights
>>> model = CNNExtractor('resnet50')
>>> predictor = FeatureExtractor(model=model)
>>> model = CNNBackbone('resnet50')
>>> predictor = DeepFeatureExtractor(model=model)
>>> output = predictor.predict(wsis, mode='wsi')
>>> list(output.keys())
[('A/wsi.svs', 'output/0') , ('B/wsi.svs', 'output/1')]
Expand Down