|
10 | 10 | from typing import TYPE_CHECKING, Callable |
11 | 11 |
|
12 | 12 | import numpy as np |
| 13 | +import torch |
13 | 14 | import zarr |
14 | 15 | from click.testing import CliRunner |
15 | 16 |
|
16 | | -from tiatoolbox import cli |
| 17 | +from tests.conftest import timed |
| 18 | +from tiatoolbox import cli, logger, rcParam |
17 | 19 | from tiatoolbox.models import IOPatchPredictorConfig |
18 | 20 | from tiatoolbox.models.architecture.vanilla import CNNModel |
19 | 21 | from tiatoolbox.models.engine.patch_predictor import PatchPredictor |
@@ -555,6 +557,53 @@ def test_engine_run_wsi_annotation_store( |
555 | 557 |
|
556 | 558 | shutil.rmtree(save_dir) |
557 | 559 |
|
| 560 | + # ---------------------------------------------------------------------------------- |
| 561 | + # torch.compile |
| 562 | + # ---------------------------------------------------------------------------------- |
| 563 | + def test_patch_predictor_torch_compile( |
| 564 | + sample_patch1: Path, |
| 565 | + sample_patch2: Path, |
| 566 | + tmp_path: Path, |
| 567 | + ) -> None: |
| 568 | + """Test PatchPredictor with torch.compile functionality. |
| 569 | +
|
| 570 | + Args: |
| 571 | + sample_patch1 (Path): Path to sample patch 1. |
| 572 | + sample_patch2 (Path): Path to sample patch 2. |
| 573 | + tmp_path (Path): Path to temporary directory. |
| 574 | +
|
| 575 | + """ |
| 576 | + torch_compile_mode = rcParam["torch_compile_mode"] |
| 577 | + torch._dynamo.reset() |
| 578 | + rcParam["torch_compile_mode"] = "default" |
| 579 | + _, compile_time = timed( |
| 580 | + test_patch_predictor_api, |
| 581 | + sample_patch1, |
| 582 | + sample_patch2, |
| 583 | + tmp_path, |
| 584 | + ) |
| 585 | + logger.info("torch.compile default mode: %s", compile_time) |
| 586 | + torch._dynamo.reset() |
| 587 | + rcParam["torch_compile_mode"] = "reduce-overhead" |
| 588 | + _, compile_time = timed( |
| 589 | + test_patch_predictor_api, |
| 590 | + sample_patch1, |
| 591 | + sample_patch2, |
| 592 | + tmp_path, |
| 593 | + ) |
| 594 | + logger.info("torch.compile reduce-overhead mode: %s", compile_time) |
| 595 | + torch._dynamo.reset() |
| 596 | + rcParam["torch_compile_mode"] = "max-autotune" |
| 597 | + _, compile_time = timed( |
| 598 | + test_patch_predictor_api, |
| 599 | + sample_patch1, |
| 600 | + sample_patch2, |
| 601 | + tmp_path, |
| 602 | + ) |
| 603 | + logger.info("torch.compile max-autotune mode: %s", compile_time) |
| 604 | + torch._dynamo.reset() |
| 605 | + rcParam["torch_compile_mode"] = torch_compile_mode |
| 606 | + |
558 | 607 |
|
559 | 608 | # ------------------------------------------------------------------------------------- |
560 | 609 | # Command Line Interface |
|
0 commit comments