Skip to content

Commit 49a2baf

Browse files
simongrahamvqdangshaneahmed
authored
NEW: Add PCam Patch Classification Models (#225)
- Add patch classification models trained on the [Patch Camelyon dataset](https://github.com/basveeling/pcam). We include all models that were used for the Kather100K dataset. - As part of this PR, we : - Add new model links - Add tests for model output - Add links to 2 new sample patches for testing Co-authored-by: Simon Graham <[email protected]> Co-authored-by: Dang Vu <[email protected]> Co-authored-by: Shan E Ahmed Raza <[email protected]>
1 parent ce5f13d commit 49a2baf

File tree

6 files changed

+334
-37
lines changed

6 files changed

+334
-37
lines changed

tests/conftest.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ def patch_extr_svs_npy_read(remote_sample) -> pathlib.Path:
299299
@pytest.fixture(scope="session")
300300
def sample_patch1(remote_sample) -> pathlib.Path:
301301
"""Sample pytest fixture for sample patch 1.
302-
Download sample patch 1 for pytest.
302+
Download sample patch 1 (Kather100K) for pytest.
303303
304304
"""
305305
return remote_sample("sample-patch-1")
@@ -308,12 +308,30 @@ def sample_patch1(remote_sample) -> pathlib.Path:
308308
@pytest.fixture(scope="session")
309309
def sample_patch2(remote_sample) -> pathlib.Path:
310310
"""Sample pytest fixture for sample patch 2.
311-
Download sample patch 2 for pytest.
311+
Download sample patch 2 (Kather100K) for pytest.
312312
313313
"""
314314
return remote_sample("sample-patch-2")
315315

316316

317+
@pytest.fixture(scope="session")
318+
def sample_patch3(remote_sample) -> pathlib.Path:
319+
"""Sample pytest fixture for sample patch 3.
320+
Download sample patch 3 (PCam) for pytest.
321+
322+
"""
323+
return remote_sample("sample-patch-3")
324+
325+
326+
@pytest.fixture(scope="session")
327+
def sample_patch4(remote_sample) -> pathlib.Path:
328+
"""Sample pytest fixture for sample patch 4.
329+
Download sample patch 4 (PCam) for pytest.
330+
331+
"""
332+
return remote_sample("sample-patch-4")
333+
334+
317335
@pytest.fixture(scope="session")
318336
def dir_sample_patches(sample_patch1, sample_patch2, tmpdir_factory):
319337
"""Directory of sample image patches for testing."""

tests/models/test_patch_predictor.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -937,8 +937,8 @@ def _test_predictor_output(
937937
)
938938

939939

940-
def test_patch_predictor_output(sample_patch1, sample_patch2):
941-
"""Test the output of patch prediction models."""
940+
def test_patch_predictor_kather100k_output(sample_patch1, sample_patch2):
941+
"""Test the output of patch prediction models on Kather100K dataset."""
942942
inputs = [pathlib.Path(sample_patch1), pathlib.Path(sample_patch2)]
943943
pretrained_info = {
944944
"alexnet-kather100k": [1.0, 0.9999735355377197],
@@ -972,6 +972,41 @@ def test_patch_predictor_output(sample_patch1, sample_patch2):
972972
break
973973

974974

975+
def test_patch_predictor_pcam_output(sample_patch3, sample_patch4):
976+
"""Test the output of patch prediction models on PCam dataset."""
977+
inputs = [pathlib.Path(sample_patch3), pathlib.Path(sample_patch4)]
978+
pretrained_info = {
979+
"alexnet-pcam": [0.999980092048645, 0.9769067168235779],
980+
"resnet18-pcam": [0.999992847442627, 0.9466130137443542],
981+
"resnet34-pcam": [1.0, 0.9976525902748108],
982+
"resnet50-pcam": [0.9999270439147949, 0.9999996423721313],
983+
"resnet101-pcam": [1.0, 0.9997289776802063],
984+
"resnext50_32x4d-pcam": [0.9999996423721313, 0.9984435439109802],
985+
"resnext101_32x8d-pcam": [0.9997072815895081, 0.9969086050987244],
986+
"wide_resnet50_2-pcam": [0.9999837875366211, 0.9959040284156799],
987+
"wide_resnet101_2-pcam": [1.0, 0.9945427179336548],
988+
"densenet121-pcam": [0.9999251365661621, 0.9997479319572449],
989+
"densenet161-pcam": [0.9999969005584717, 0.9662821292877197],
990+
"densenet169-pcam": [0.9999998807907104, 0.9993504881858826],
991+
"densenet201-pcam": [0.9999942779541016, 0.9950824975967407],
992+
"mobilenet_v2-pcam": [0.9999876022338867, 0.9942564368247986],
993+
"mobilenet_v3_large-pcam": [0.9999922513961792, 0.9719613790512085],
994+
"mobilenet_v3_small-pcam": [0.9999963045120239, 0.9747149348258972],
995+
"googlenet-pcam": [0.9999929666519165, 0.8701475858688354],
996+
}
997+
for pretrained_model, expected_prob in pretrained_info.items():
998+
_test_predictor_output(
999+
inputs,
1000+
pretrained_model,
1001+
probabilities_check=expected_prob,
1002+
predictions_check=[1, 0],
1003+
on_gpu=ON_GPU,
1004+
)
1005+
# only test 1 on travis to limit runtime
1006+
if ON_TRAVIS:
1007+
break
1008+
1009+
9751010
# -------------------------------------------------------------------------------------
9761011
# Command Line Interface
9771012
# -------------------------------------------------------------------------------------

tiatoolbox/data/pretrained_model.yaml

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,245 @@ googlenet-kather100k:
237237
input_resolutions: [{"resolution": 0.5, "units": "mpp"}]
238238
dataset: kather100k
239239

240+
alexnet-pcam:
241+
url: https://tiatoolbox.dcs.warwick.ac.uk/models/pc/alexnet-pcam.pth
242+
architecture:
243+
class: vanilla.CNNModel
244+
kwargs:
245+
backbone: alexnet
246+
num_classes: 2
247+
ioconfig:
248+
class: patch_predictor.IOPatchPredictorConfig
249+
kwargs:
250+
patch_input_shape: [96, 96]
251+
stride_shape: [96, 96]
252+
input_resolutions: [{"resolution": 1.0, "units": "mpp"}]
253+
dataset: pcam
254+
resnet18-pcam:
255+
url: https://tiatoolbox.dcs.warwick.ac.uk/models/pc/resnet18-pcam.pth
256+
architecture:
257+
class: vanilla.CNNModel
258+
kwargs:
259+
backbone: resnet18
260+
num_classes: 2
261+
ioconfig:
262+
class: patch_predictor.IOPatchPredictorConfig
263+
kwargs:
264+
patch_input_shape: [96, 96]
265+
stride_shape: [96, 96]
266+
input_resolutions: [{"resolution": 1.0, "units": "mpp"}]
267+
dataset: pcam
268+
resnet34-pcam:
269+
url: https://tiatoolbox.dcs.warwick.ac.uk/models/pc/resnet34-pcam.pth
270+
architecture:
271+
class: vanilla.CNNModel
272+
kwargs:
273+
backbone: resnet34
274+
num_classes: 2
275+
ioconfig:
276+
class: patch_predictor.IOPatchPredictorConfig
277+
kwargs:
278+
patch_input_shape: [96, 96]
279+
stride_shape: [96, 96]
280+
input_resolutions: [{"resolution": 1.0, "units": "mpp"}]
281+
dataset: pcam
282+
resnet50-pcam:
283+
url: https://tiatoolbox.dcs.warwick.ac.uk/models/pc/resnet50-pcam.pth
284+
architecture:
285+
class: vanilla.CNNModel
286+
kwargs:
287+
backbone: resnet50
288+
num_classes: 2
289+
ioconfig:
290+
class: patch_predictor.IOPatchPredictorConfig
291+
kwargs:
292+
patch_input_shape: [96, 96]
293+
stride_shape: [96, 96]
294+
input_resolutions: [{"resolution": 1.0, "units": "mpp"}]
295+
dataset: pcam
296+
resnet101-pcam:
297+
url: https://tiatoolbox.dcs.warwick.ac.uk/models/pc/resnet101-pcam.pth
298+
architecture:
299+
class: vanilla.CNNModel
300+
kwargs:
301+
backbone: resnet101
302+
num_classes: 2
303+
ioconfig:
304+
class: patch_predictor.IOPatchPredictorConfig
305+
kwargs:
306+
patch_input_shape: [96, 96]
307+
stride_shape: [96, 96]
308+
input_resolutions: [{"resolution": 1.0, "units": "mpp"}]
309+
dataset: pcam
310+
resnext50_32x4d-pcam:
311+
url: https://tiatoolbox.dcs.warwick.ac.uk/models/pc/resnext50_32x4d-pcam.pth
312+
architecture:
313+
class: vanilla.CNNModel
314+
kwargs:
315+
backbone: resnext50_32x4d
316+
num_classes: 2
317+
ioconfig:
318+
class: patch_predictor.IOPatchPredictorConfig
319+
kwargs:
320+
patch_input_shape: [96, 96]
321+
stride_shape: [96, 96]
322+
input_resolutions: [{"resolution": 1.0, "units": "mpp"}]
323+
dataset: pcam
324+
resnext101_32x8d-pcam:
325+
url: https://tiatoolbox.dcs.warwick.ac.uk/models/pc/resnext101_32x8d-pcam.pth
326+
architecture:
327+
class: vanilla.CNNModel
328+
kwargs:
329+
backbone: resnext101_32x8d
330+
num_classes: 2
331+
ioconfig:
332+
class: patch_predictor.IOPatchPredictorConfig
333+
kwargs:
334+
patch_input_shape: [96, 96]
335+
stride_shape: [96, 96]
336+
input_resolutions: [{"resolution": 1.0, "units": "mpp"}]
337+
dataset: pcam
338+
wide_resnet50_2-pcam:
339+
url: https://tiatoolbox.dcs.warwick.ac.uk/models/pc/wide_resnet50_2-pcam.pth
340+
architecture:
341+
class: vanilla.CNNModel
342+
kwargs:
343+
backbone: wide_resnet50_2
344+
num_classes: 2
345+
ioconfig:
346+
class: patch_predictor.IOPatchPredictorConfig
347+
kwargs:
348+
patch_input_shape: [96, 96]
349+
stride_shape: [96, 96]
350+
input_resolutions: [{"resolution": 1.0, "units": "mpp"}]
351+
dataset: pcam
352+
wide_resnet101_2-pcam:
353+
url: https://tiatoolbox.dcs.warwick.ac.uk/models/pc/wide_resnet101_2-pcam.pth
354+
architecture:
355+
class: vanilla.CNNModel
356+
kwargs:
357+
backbone: wide_resnet101_2
358+
num_classes: 2
359+
ioconfig:
360+
class: patch_predictor.IOPatchPredictorConfig
361+
kwargs:
362+
patch_input_shape: [96, 96]
363+
stride_shape: [96, 96]
364+
input_resolutions: [{"resolution": 1.0, "units": "mpp"}]
365+
dataset: pcam
366+
densenet121-pcam:
367+
url: https://tiatoolbox.dcs.warwick.ac.uk/models/pc/densenet121-pcam.pth
368+
architecture:
369+
class: vanilla.CNNModel
370+
kwargs:
371+
backbone: densenet121
372+
num_classes: 2
373+
ioconfig:
374+
class: patch_predictor.IOPatchPredictorConfig
375+
kwargs:
376+
patch_input_shape: [96, 96]
377+
stride_shape: [96, 96]
378+
input_resolutions: [{"resolution": 1.0, "units": "mpp"}]
379+
dataset: pcam
380+
densenet161-pcam:
381+
url: https://tiatoolbox.dcs.warwick.ac.uk/models/pc/densenet161-pcam.pth
382+
architecture:
383+
class: vanilla.CNNModel
384+
kwargs:
385+
backbone: densenet161
386+
num_classes: 2
387+
ioconfig:
388+
class: patch_predictor.IOPatchPredictorConfig
389+
kwargs:
390+
patch_input_shape: [96, 96]
391+
stride_shape: [96, 96]
392+
input_resolutions: [{"resolution": 1.0, "units": "mpp"}]
393+
dataset: pcam
394+
densenet169-pcam:
395+
url: https://tiatoolbox.dcs.warwick.ac.uk/models/pc/densenet169-pcam.pth
396+
architecture:
397+
class: vanilla.CNNModel
398+
kwargs:
399+
backbone: densenet169
400+
num_classes: 2
401+
ioconfig:
402+
class: patch_predictor.IOPatchPredictorConfig
403+
kwargs:
404+
patch_input_shape: [96, 96]
405+
stride_shape: [96, 96]
406+
input_resolutions: [{"resolution": 1.0, "units": "mpp"}]
407+
dataset: pcam
408+
densenet201-pcam:
409+
url: https://tiatoolbox.dcs.warwick.ac.uk/models/pc/densenet201-pcam.pth
410+
architecture:
411+
class: vanilla.CNNModel
412+
kwargs:
413+
backbone: densenet201
414+
num_classes: 2
415+
ioconfig:
416+
class: patch_predictor.IOPatchPredictorConfig
417+
kwargs:
418+
patch_input_shape: [96, 96]
419+
stride_shape: [96, 96]
420+
input_resolutions: [{"resolution": 1.0, "units": "mpp"}]
421+
dataset: pcam
422+
mobilenet_v2-pcam:
423+
url: https://tiatoolbox.dcs.warwick.ac.uk/models/pc/mobilenet_v2-pcam.pth
424+
architecture:
425+
class: vanilla.CNNModel
426+
kwargs:
427+
backbone: mobilenet_v2
428+
num_classes: 2
429+
ioconfig:
430+
class: patch_predictor.IOPatchPredictorConfig
431+
kwargs:
432+
patch_input_shape: [96, 96]
433+
stride_shape: [96, 96]
434+
input_resolutions: [{"resolution": 1.0, "units": "mpp"}]
435+
dataset: pcam
436+
mobilenet_v3_large-pcam:
437+
url: https://tiatoolbox.dcs.warwick.ac.uk/models/pc/mobilenet_v3_large-pcam.pth
438+
architecture:
439+
class: vanilla.CNNModel
440+
kwargs:
441+
backbone: mobilenet_v3_large
442+
num_classes: 2
443+
ioconfig:
444+
class: patch_predictor.IOPatchPredictorConfig
445+
kwargs:
446+
patch_input_shape: [96, 96]
447+
stride_shape: [96, 96]
448+
input_resolutions: [{"resolution": 1.0, "units": "mpp"}]
449+
dataset: pcam
450+
mobilenet_v3_small-pcam:
451+
url: https://tiatoolbox.dcs.warwick.ac.uk/models/pc/mobilenet_v3_small-pcam.pth
452+
architecture:
453+
class: vanilla.CNNModel
454+
kwargs:
455+
backbone: mobilenet_v3_small
456+
num_classes: 2
457+
ioconfig:
458+
class: patch_predictor.IOPatchPredictorConfig
459+
kwargs:
460+
patch_input_shape: [96, 96]
461+
stride_shape: [96, 96]
462+
input_resolutions: [{"resolution": 1.0, "units": "mpp"}]
463+
dataset: pcam
464+
googlenet-pcam:
465+
url: https://tiatoolbox.dcs.warwick.ac.uk/models/pc/googlenet-pcam.pth
466+
architecture:
467+
class: vanilla.CNNModel
468+
kwargs:
469+
backbone: googlenet
470+
num_classes: 2
471+
ioconfig:
472+
class: patch_predictor.IOPatchPredictorConfig
473+
kwargs:
474+
patch_input_shape: [96, 96]
475+
stride_shape: [96, 96]
476+
input_resolutions: [{"resolution": 1.0, "units": "mpp"}]
477+
dataset: pcam
478+
240479
resnet18-idars-tumour:
241480
url: https://tiatoolbox.dcs.warwick.ac.uk/models/idars/resnet18-idars-tumour.pth
242481
architecture:

tiatoolbox/data/remote_samples.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ files:
5959
url: [*modelroot, "samples/kather_patch1.tif"]
6060
sample-patch-2:
6161
url: [*modelroot, "samples/kather_patch2.tif"]
62+
sample-patch-3:
63+
url: [*modelroot, "samples/pcam_patch1.png"]
64+
sample-patch-4:
65+
url: [*modelroot, "samples/pcam_patch2.png"]
6266
wsi1_2k_2k_svs:
6367
url: [*modelroot, "samples/wsi1_2k_2k.svs"]
6468
wsi1_8k_8k_svs:

0 commit comments

Comments
 (0)