
Commit 035ab74

🔀 Merge changes from #716 (commit: 32cae0b)
1 parent 0e48f6e commit 035ab74

File tree

3 files changed (+62 lines, -6 lines)

tests/engines/test_patch_predictor.py

Lines changed: 50 additions & 1 deletion

@@ -10,10 +10,12 @@
 from typing import TYPE_CHECKING, Callable
 
 import numpy as np
+import torch
 import zarr
 from click.testing import CliRunner
 
-from tiatoolbox import cli
+from tests.conftest import timed
+from tiatoolbox import cli, logger, rcParam
 from tiatoolbox.models import IOPatchPredictorConfig
 from tiatoolbox.models.architecture.vanilla import CNNModel
 from tiatoolbox.models.engine.patch_predictor import PatchPredictor
@@ -555,6 +557,53 @@ def test_engine_run_wsi_annotation_store(
 
     shutil.rmtree(save_dir)
 
+# ----------------------------------------------------------------------------------
+# torch.compile
+# ----------------------------------------------------------------------------------
+def test_patch_predictor_torch_compile(
+    sample_patch1: Path,
+    sample_patch2: Path,
+    tmp_path: Path,
+) -> None:
+    """Test PatchPredictor with torch.compile functionality.
+
+    Args:
+        sample_patch1 (Path): Path to sample patch 1.
+        sample_patch2 (Path): Path to sample patch 2.
+        tmp_path (Path): Path to temporary directory.
+
+    """
+    torch_compile_mode = rcParam["torch_compile_mode"]
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = "default"
+    _, compile_time = timed(
+        test_patch_predictor_api,
+        sample_patch1,
+        sample_patch2,
+        tmp_path,
+    )
+    logger.info("torch.compile default mode: %s", compile_time)
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = "reduce-overhead"
+    _, compile_time = timed(
+        test_patch_predictor_api,
+        sample_patch1,
+        sample_patch2,
+        tmp_path,
+    )
+    logger.info("torch.compile reduce-overhead mode: %s", compile_time)
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = "max-autotune"
+    _, compile_time = timed(
+        test_patch_predictor_api,
+        sample_patch1,
+        sample_patch2,
+        tmp_path,
+    )
+    logger.info("torch.compile max-autotune mode: %s", compile_time)
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = torch_compile_mode
+
 
 # -------------------------------------------------------------------------------------
 # Command Line Interface
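The new test relies on a timed helper imported from tests/conftest.py, which is outside this diff. A minimal sketch of what such a helper could look like, inferred purely from the call sites above (the real implementation may differ):

import time
from typing import Any, Callable


def timed(fn: Callable, *args: Any) -> tuple[Any, float]:
    # Sketch only: run fn(*args) once and return (result, elapsed wall-clock seconds),
    # matching the `_, compile_time = timed(...)` call pattern in the test above.
    start = time.perf_counter()
    result = fn(*args)
    return result, time.perf_counter() - start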

tiatoolbox/models/architecture/utils.py

Lines changed: 4 additions & 4 deletions

@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import sys
-from typing import Callable, NoReturn
+from typing import NoReturn
 
 import numpy as np
 import torch
@@ -41,7 +41,7 @@ def compile_model(
     model: nn.Module | None = None,
     *,
     mode: str = "default",
-) -> Callable:
+) -> nn.Module:
     """A decorator to compile a model using torch-compile.
 
     Args:
@@ -60,7 +60,7 @@ def compile_model(
            CUDA graphs
 
    Returns:
-        Callable:
+        torch.nn.Module:
            Compiled model.
 
    """
@@ -71,7 +71,7 @@ def compile_model(
    is_torch_compile_compatible()
 
    # This check will be removed when torch.compile is supported in Python 3.12+
-    if sys.version_info >= (3, 12):  # pragma: no cover
+    if sys.version_info > (3, 12):  # pragma: no cover
        logger.warning(
            ("torch-compile is currently not supported in Python 3.12+. ",),
        )
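Only the signature, docstring, and the Python-version guard of compile_model appear in these hunks. A hedged sketch of a wrapper consistent with the changed lines; the fall-back behaviour on unsupported Python versions is an assumption rather than something shown in the diff:

import sys

import torch
from torch import nn

from tiatoolbox import logger


def compile_model(model: nn.Module | None = None, *, mode: str = "default") -> nn.Module:
    # Sketch: warn and return the uncompiled model on unsupported Python versions
    # (the actual fall-back in tiatoolbox is not visible in this diff).
    if sys.version_info > (3, 12):
        logger.warning("torch-compile is currently not supported in Python 3.12+.")
        return model
    # torch.compile accepts mode="default", "reduce-overhead", or "max-autotune".
    return torch.compile(model=model, mode=mode)

The return-type change from Callable to nn.Module is consistent with torch.compile returning an optimized module when it is given an nn.Module instance.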

tiatoolbox/models/engine/engine_abc.py

Lines changed: 8 additions & 1 deletion

@@ -15,8 +15,9 @@
 from torch import nn
 from typing_extensions import Unpack
 
-from tiatoolbox import DuplicateFilter, logger
+from tiatoolbox import DuplicateFilter, logger, rcParam
 from tiatoolbox.models.architecture import get_pretrained_model
+from tiatoolbox.models.architecture.utils import compile_model
 from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset
 from tiatoolbox.models.models_abc import load_torch_model
 from tiatoolbox.utils.misc import (
@@ -355,6 +356,12 @@ def __init__(
             weights=weights,
         )
         self.model.to(device=self.device)
+        self.model = (
+            compile_model(  # for runtime, such as after wrapping with nn.DataParallel
+                model,
+                mode=rcParam["torch_compile_mode"],
+            )
+        )
         self._ioconfig = self.ioconfig  # runtime ioconfig
 
         self.batch_size = batch_size
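With this change the engine wraps its model with compile_model at construction time, using the global rcParam["torch_compile_mode"] setting. A short usage sketch; the CNNModel arguments below are illustrative and not taken from this diff:

from tiatoolbox import rcParam
from tiatoolbox.models.architecture.utils import compile_model
from tiatoolbox.models.architecture.vanilla import CNNModel

# Select the torch.compile mode globally; EngineABC.__init__ reads this value
# when it wraps its model, as shown in the hunk above.
rcParam["torch_compile_mode"] = "reduce-overhead"

# The same wrapper can also be applied to any nn.Module directly.
model = CNNModel(backbone="resnet18", num_classes=9)  # illustrative arguments
compiled = compile_model(model, mode=rcParam["torch_compile_mode"])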
