| 
5 | 5 | import cv2  | 
6 | 6 | import numpy as np  | 
7 | 7 | import pytest  | 
 | 8 | +import torch  | 
8 | 9 | 
 
  | 
 | 10 | +from tests.conftest import timed  | 
 | 11 | +from tiatoolbox import logger, rcParam  | 
9 | 12 | from tiatoolbox.tools.registration.wsi_registration import (  | 
10 | 13 |     AffineWSITransformer,  | 
11 | 14 |     DFBRegister,  | 
@@ -576,3 +579,70 @@ def test_affine_wsi_transformer(sample_ome_tiff: Path) -> None:  | 
576 | 579 |         expected = cv2.rotate(expected, cv2.ROTATE_90_CLOCKWISE)  | 
577 | 580 | 
 
  | 
578 | 581 |         assert np.sum(expected - output) == 0  | 
 | 582 | + | 
 | 583 | + | 
 | 584 | +def test_dfbr_feature_extractor_torch_compile(dfbr_features: Path) -> None:  | 
 | 585 | +    """Test DFBRFeatureExtractor with torch.compile functionality.  | 
 | 586 | +
  | 
 | 587 | +    Args:  | 
 | 588 | +        dfbr_features (Path): Path to the expected features.  | 
 | 589 | +
  | 
 | 590 | +    """  | 
 | 591 | + | 
 | 592 | +    def _extract_features() -> tuple:  | 
 | 593 | +        dfbr = DFBRegister()  | 
 | 594 | +        fixed_img = np.repeat(  | 
 | 595 | +            np.expand_dims(  | 
 | 596 | +                np.repeat(  | 
 | 597 | +                    np.expand_dims(np.arange(0, 64, 1, dtype=np.uint8), axis=1),  | 
 | 598 | +                    64,  | 
 | 599 | +                    axis=1,  | 
 | 600 | +                ),  | 
 | 601 | +                axis=2,  | 
 | 602 | +            ),  | 
 | 603 | +            3,  | 
 | 604 | +            axis=2,  | 
 | 605 | +        )  | 
 | 606 | +        output = dfbr.extract_features(fixed_img, fixed_img)  | 
 | 607 | +        pool3_feat = output["block3_pool"][0, :].detach().numpy()  | 
 | 608 | +        pool4_feat = output["block4_pool"][0, :].detach().numpy()  | 
 | 609 | +        pool5_feat = output["block5_pool"][0, :].detach().numpy()  | 
 | 610 | + | 
 | 611 | +        return pool3_feat, pool4_feat, pool5_feat  | 
 | 612 | + | 
 | 613 | +    torch_compile_mode = rcParam["torch_compile_mode"]  | 
 | 614 | +    torch._dynamo.reset()  | 
 | 615 | +    rcParam["torch_compile_mode"] = "default"  | 
 | 616 | +    (pool3_feat, pool4_feat, pool5_feat), compile_time = timed(_extract_features)  | 
 | 617 | +    _pool3_feat, _pool4_feat, _pool5_feat = np.load(  | 
 | 618 | +        str(dfbr_features),  | 
 | 619 | +        allow_pickle=True,  | 
 | 620 | +    )  | 
 | 621 | +    assert np.mean(np.abs(pool3_feat - _pool3_feat)) < 1.0e-4  | 
 | 622 | +    assert np.mean(np.abs(pool4_feat - _pool4_feat)) < 1.0e-4  | 
 | 623 | +    assert np.mean(np.abs(pool5_feat - _pool5_feat)) < 1.0e-4  | 
 | 624 | +    logger.info("torch.compile default mode: %s", compile_time)  | 
 | 625 | +    torch._dynamo.reset()  | 
 | 626 | +    rcParam["torch_compile_mode"] = "reduce-overhead"  | 
 | 627 | +    (pool3_feat, pool4_feat, pool5_feat), compile_time = timed(_extract_features)  | 
 | 628 | +    _pool3_feat, _pool4_feat, _pool5_feat = np.load(  | 
 | 629 | +        str(dfbr_features),  | 
 | 630 | +        allow_pickle=True,  | 
 | 631 | +    )  | 
 | 632 | +    assert np.mean(np.abs(pool3_feat - _pool3_feat)) < 1.0e-4  | 
 | 633 | +    assert np.mean(np.abs(pool4_feat - _pool4_feat)) < 1.0e-4  | 
 | 634 | +    assert np.mean(np.abs(pool5_feat - _pool5_feat)) < 1.0e-4  | 
 | 635 | +    logger.info("torch.compile reduce-overhead mode: %s", compile_time)  | 
 | 636 | +    torch._dynamo.reset()  | 
 | 637 | +    rcParam["torch_compile_mode"] = "max-autotune"  | 
 | 638 | +    (pool3_feat, pool4_feat, pool5_feat), compile_time = timed(_extract_features)  | 
 | 639 | +    _pool3_feat, _pool4_feat, _pool5_feat = np.load(  | 
 | 640 | +        str(dfbr_features),  | 
 | 641 | +        allow_pickle=True,  | 
 | 642 | +    )  | 
 | 643 | +    assert np.mean(np.abs(pool3_feat - _pool3_feat)) < 1.0e-4  | 
 | 644 | +    assert np.mean(np.abs(pool4_feat - _pool4_feat)) < 1.0e-4  | 
 | 645 | +    assert np.mean(np.abs(pool5_feat - _pool5_feat)) < 1.0e-4  | 
 | 646 | +    logger.info("torch.compile max-autotune mode: %s", compile_time)  | 
 | 647 | +    torch._dynamo.reset()  | 
 | 648 | +    rcParam["torch_compile_mode"] = torch_compile_mode  | 
0 commit comments