Skip to content
Merged
Show file tree
Hide file tree
Changes from 95 commits
Commits
Show all changes
103 commits
Select commit Hold shift + click to select a range
58dbdc2
⚡️ Add torch.compile decorators
Sep 25, 2023
29fd380
✅ Add simple compute time test
Sep 28, 2023
85d47b8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2023
3b0667d
♻️ refactor test and add disable `torch.compile`
Sep 29, 2023
6ebaec1
Fix conflicts
Sep 29, 2023
168f8ef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2023
c55f00b
💚 Fix CI `no_gpu` error and move timed
Sep 29, 2023
5318a37
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2023
b3a8cb8
🎨 Minor improvements
Oct 2, 2023
6112670
Merge branch 'develop' into enhance-torch-compile
shaneahmed Oct 5, 2023
e97f4a5
Merge branch 'develop' into enhance-torch-compile
shaneahmed Oct 6, 2023
cdaade2
Merge branch 'enhance-torch-compile' of https://github.com/TissueImag…
Oct 6, 2023
dc806c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 6, 2023
f302edb
🔥 Remove `torch.compile` test for now
Oct 6, 2023
d2cf661
🔀 merge changes
Oct 6, 2023
6ab0d8f
⚡️ Add `torch.compile` to SemanticSegmentor
Oct 19, 2023
cd554fb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 19, 2023
d541058
⚡️ Add `torch.compiled` to PatchPredictor
Oct 20, 2023
b660edc
Merge branch 'enhance-torch-compile'
Oct 20, 2023
3c3305f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 20, 2023
1694d17
🔥 Temp disable `torch.compile` SemanticSegmentor
Oct 20, 2023
f50b083
Merge branch 'enhance-torch-compile' of https://github.com/TissueImag…
Oct 20, 2023
7910039
💡 Remove chanage to `__init__.py`
Oct 20, 2023
5b513e2
🔥 Temp remove `torch.compile` SemanticSegmentor
Oct 20, 2023
aaf076e
🚨 Fix `ruff` linter errors
Oct 20, 2023
2d736c0
🚨 Temp disable cyclomatic complexity check
Oct 27, 2023
65c2c53
🚨 Cont. temp disable cyclomatic complexity check
Oct 27, 2023
9cc0168
🚨 Revert `max-args` back to 10
Nov 3, 2023
d66a7bd
✏️ Add text to notebook
Nov 8, 2023
820c7b9
⏪ Remove unnecessary line in example notebook
Nov 10, 2023
afa81f6
⚡️ Add `torch-compile` to `SemanticSegmentor`
Nov 16, 2023
e5eae50
🚧 Add 'rcParam` as config
Nov 16, 2023
6506799
🚧 Add `torch.compile` mode to `rcParam`
Nov 17, 2023
fc7120e
🐛 Fix argument mishap
Nov 17, 2023
08a3cf8
🚨 Fix linter errors
Nov 17, 2023
a13ee34
Merge branch 'develop' into enhance-torch-compile
Nov 17, 2023
1c0173d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2023
b0a276e
🐛 Fix `rcParam` definition
Nov 17, 2023
13eafe2
Merge branch 'enhance-torch-compile' of https://github.com/TissueImag…
Nov 17, 2023
a138bc9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2023
32b002e
🐛 Fix `rcParam` definition
Nov 17, 2023
33df9a9
Merge branch 'enhance-torch-compile' of https://github.com/TissueImag…
Nov 17, 2023
a003742
🚨 Fix `ruff` lint errors
Nov 23, 2023
df62c5d
Merge branch 'develop' into enhance-torch-compile
shaneahmed Nov 24, 2023
54d4b06
Merge branch 'develop' into enhance-torch-compile
shaneahmed Jan 19, 2024
8bc5328
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2024
ddfc15c
Merge branch 'develop' into enhance-torch-compile
Abdol Jan 25, 2024
c2c0e89
🚧 Supress `TorchDynamo` errors and disable `torch.compile` by default
Jan 25, 2024
957f847
🚧 Remove a problematic `torch.compile` defintion
Jan 26, 2024
f898bc5
🐛 Fix linter error about importing protected members
Jan 26, 2024
6b1520a
🚨 Further linter error fix
Jan 26, 2024
5b32712
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 26, 2024
6c6ce66
🚨 Further linter error fix
Jan 26, 2024
1574859
Merge branch 'enhance-torch-compile' of https://github.com/TissueImag…
Jan 26, 2024
7be70e5
Merge branch 'develop' into enhance-torch-compile
Abdol Jan 26, 2024
5d8cc6a
🚧 Remove `torch.compile` definitions from main PR
Jan 26, 2024
90d5648
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 26, 2024
9377b8d
🐛 Fix missing attribute in WSI registration from previous commit
Jan 26, 2024
b5fd57f
Merge branch 'develop' into enhance-torch-compile
shaneahmed Feb 2, 2024
8d6b788
Merge branch 'develop' into enhance-torch-compile
shaneahmed Feb 21, 2024
fb032d4
Merge branch 'develop' into enhance-torch-compile
shaneahmed Mar 15, 2024
b2f57ee
Merge branch 'develop' into enhance-torch-compile
shaneahmed Mar 19, 2024
252c7f9
⚡️ Add `torch.compile` to `PatchPredictor` (#776)
Abdol Mar 19, 2024
cf22502
Merge branch 'develop' into enhance-torch-compile
shaneahmed Mar 22, 2024
1d40585
📝 Fix docstrings
Abdol Mar 22, 2024
7d08c34
Merge branch 'develop' into enhance-torch-compile
shaneahmed Apr 23, 2024
6ee4353
Merge branch 'develop' into enhance-torch-compile
shaneahmed Apr 29, 2024
76d8e7e
Merge branch 'develop' into enhance-torch-compile
Abdol May 10, 2024
a767843
Merge branch 'develop' into enhance-torch-compile
Abdol May 14, 2024
e533b85
⚡️Refine `torch.compile` and Add to WSI Registration (#800)
Abdol Jun 14, 2024
7df7c62
Merge branch 'develop' into enhance-torch-compile
Abdol Jun 14, 2024
d97501c
Merge branch 'develop' into enhance-torch-compile
Abdol Jun 24, 2024
5667e54
Merge branch 'develop' into enhance-torch-compile
Abdol Jun 25, 2024
02b8771
Merge branch 'develop' into enhance-torch-compile
Abdol Jun 28, 2024
2729a06
Merge branch 'develop' into enhance-torch-compile
Abdol Jul 9, 2024
92b75df
Merge branch 'develop' into enhance-torch-compile
Abdol Jul 29, 2024
7cf2714
Merge branch 'develop' into enhance-torch-compile
shaneahmed Aug 9, 2024
14c9409
Merge branch 'develop' into enhance-torch-compile
shaneahmed Aug 29, 2024
a086195
Merge branch 'develop' into enhance-torch-compile
shaneahmed Sep 19, 2024
8cc2fb4
Merge branch 'develop' into enhance-torch-compile
Abdol Sep 26, 2024
ba1776e
🚧 Add `torch.compile` to SemanticSegmentor
Abdol Sep 26, 2024
dd02dd3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2024
36d3f81
Merge branch 'develop' into enhance-torch-compile
shaneahmed Oct 2, 2024
6a5cc1d
🐛 Fix `torch.compile` assertion error
Abdol Oct 10, 2024
67fef3f
✅ Add test for SemanticSegmentor with `torch.compile`
Abdol Oct 10, 2024
24bd96b
Merge branch 'develop' into enhance-torch-compile
shaneahmed Oct 18, 2024
7d1850b
Merge branch 'develop' into enhance-torch-compile
Abdol Oct 20, 2024
fd97f07
Merge branch 'develop' into enhance-torch-compile
Abdol Oct 25, 2024
e5be778
fix DeepSource error
Jiaqi-Lv Nov 1, 2024
62a9009
fix deepsource error
Jiaqi-Lv Nov 1, 2024
2d15229
Apply suggestions from code review
Abdol Nov 4, 2024
8cd748a
Update semantic_segmentor.py as per code review
Abdol Nov 11, 2024
14450f2
Apply suggestions from code review
Abdol Nov 11, 2024
dcfb18a
Merge branch 'develop' into enhance-torch-compile
Abdol Nov 11, 2024
6580a01
try fixing testcov
Jiaqi-Lv Nov 13, 2024
152d698
Adjust spacing in github workflows
shaneahmed Nov 13, 2024
11fc009
Apply suggestions from code review
Abdol Nov 15, 2024
8790efa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 15, 2024
6545920
Update utils.py to address review comments
Abdol Nov 15, 2024
5c6928f
:memo: Add `torch.compile` mode descriptions
Abdol Nov 15, 2024
77830ae
:bug: Fix E501 Line too long
shaneahmed Nov 15, 2024
8f07b0a
Merge branch 'refs/heads/develop' into enhance-torch-compile
shaneahmed Nov 15, 2024
020d9ef
:bug: Fix `test_torch_compile_compatibility`
shaneahmed Nov 15, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/mypy-type-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ on:
push:
branches: [ develop, pre-release, master, main ]
pull_request:
branches: [ develop, pre-release, master, main ]
branches: [ develop, pre-release, master, main]

jobs:

Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ on:
branches: [ develop, pre-release, master, main ]
tags: v*
pull_request:
branches: [ develop, pre-release, master, main ]
branches: [ develop, pre-release, master, main]

jobs:
build:
Expand Down Expand Up @@ -58,7 +58,7 @@ jobs:
- name: Test with pytest
run: |
pytest --basetemp={envtmpdir} \
--cov=tiatoolbox --cov-report=term --cov-report=xml \
--cov=tiatoolbox --cov-report=term --cov-report=xml --cov-config=pyproject.toml \
--capture=sys \
--durations=10 --durations-min=1.0 \
--maxfail=1
Expand Down
38 changes: 37 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@

import os
import shutil
import time
from pathlib import Path
from typing import Callable

import pytest
import torch

import tiatoolbox
from tiatoolbox import logger
from tiatoolbox.data import _fetch_remote_sample
from tiatoolbox.utils.env_detection import running_on_ci
from tiatoolbox.utils.env_detection import has_gpu, running_on_ci

# -------------------------------------------------------------------------------------
# Generate Parameterized Tests
Expand Down Expand Up @@ -608,3 +610,37 @@ def data_path(tmp_path_factory: pytest.TempPathFactory) -> dict[str, object]:
(tmp_path / "slides").mkdir()
(tmp_path / "overlays").mkdir()
return {"base_path": tmp_path}


# -------------------------------------------------------------------------------------
# Utility functions
# -------------------------------------------------------------------------------------


def timed(fn: Callable, *args: object) -> (Callable, float):
"""A decorator that times the execution of a function.

Args:
fn (Callable): The function to be timed.
args (object): Arguments to be passed to the function.

Returns:
A tuple containing the result of the function
and the time taken to execute it in seconds.

"""
compile_time = 0.0
if has_gpu():
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
result = fn(*args)
end.record()
torch.cuda.synchronize()
compile_time = start.elapsed_time(end) / 1000
else:
start = time.time()
result = fn(*args)
end = time.time()
compile_time = end - start
return result, compile_time
8 changes: 7 additions & 1 deletion tests/models/test_nucleus_instance_segmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
import joblib
import numpy as np
import pytest
import torch
import yaml
from click.testing import CliRunner

from tiatoolbox import cli
from tiatoolbox import cli, rcParam
from tiatoolbox.models import (
IOSegmentorConfig,
NucleusInstanceSegmentor,
Expand Down Expand Up @@ -44,7 +45,12 @@ def _crash_func(_x: object) -> None:

def helper_tile_info() -> list:
"""Helper function for tile information."""
torch._dynamo.reset()
current_torch_compile_mode = rcParam["torch_compile_mode"]
rcParam["torch_compile_mode"] = "disable"
predictor = NucleusInstanceSegmentor(model="A")
torch._dynamo.reset()
rcParam["torch_compile_mode"] = current_torch_compile_mode
# ! assuming the tiles organized as follows (coming out from
# ! PatchExtractor). If this is broken, need to check back
# ! PatchExtractor output ordering first
Expand Down
53 changes: 52 additions & 1 deletion tests/models/test_patch_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
import torch
from click.testing import CliRunner

from tiatoolbox import cli
from tests.conftest import timed
from tiatoolbox import cli, logger, rcParam
from tiatoolbox.models import IOPatchPredictorConfig, PatchPredictor
from tiatoolbox.models.architecture.vanilla import CNNModel
from tiatoolbox.models.dataset import (
Expand Down Expand Up @@ -1226,3 +1227,53 @@ def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -
assert tmp_path.joinpath("2.merged.npy").exists()
assert tmp_path.joinpath("2.raw.json").exists()
assert tmp_path.joinpath("results.json").exists()


# -------------------------------------------------------------------------------------
# torch.compile
# -------------------------------------------------------------------------------------


def test_patch_predictor_torch_compile(
sample_patch1: Path,
sample_patch2: Path,
tmp_path: Path,
) -> None:
"""Test PatchPredictor with with torch.compile functionality.

Args:
sample_patch1 (Path): Path to sample patch 1.
sample_patch2 (Path): Path to sample patch 2.
tmp_path (Path): Path to temporary directory.

"""
torch_compile_mode = rcParam["torch_compile_mode"]
torch._dynamo.reset()
rcParam["torch_compile_mode"] = "default"
_, compile_time = timed(
test_patch_predictor_api,
sample_patch1,
sample_patch2,
tmp_path,
)
logger.info("torch.compile default mode: %s", compile_time)
torch._dynamo.reset()
rcParam["torch_compile_mode"] = "reduce-overhead"
_, compile_time = timed(
test_patch_predictor_api,
sample_patch1,
sample_patch2,
tmp_path,
)
logger.info("torch.compile reduce-overhead mode: %s", compile_time)
torch._dynamo.reset()
rcParam["torch_compile_mode"] = "max-autotune"
_, compile_time = timed(
test_patch_predictor_api,
sample_patch1,
sample_patch2,
tmp_path,
)
logger.info("torch.compile max-autotune mode: %s", compile_time)
torch._dynamo.reset()
rcParam["torch_compile_mode"] = torch_compile_mode
48 changes: 47 additions & 1 deletion tests/models/test_semantic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from click.testing import CliRunner
from torch import nn

from tiatoolbox import cli
from tests.conftest import timed
from tiatoolbox import cli, logger, rcParam
from tiatoolbox.models import SemanticSegmentor
from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.models.architecture.utils import centre_crop
Expand Down Expand Up @@ -897,3 +898,48 @@ def test_cli_semantic_segmentation_multi_file(
_test_pred = (_test_pred[..., 1] > 0.50) * 255

assert np.mean(np.abs(_cache_pred - _test_pred) / 255) < 1e-3


# -------------------------------------------------------------------------------------
# torch.compile
# -------------------------------------------------------------------------------------


def test_semantic_segmentor_torch_compile(
remote_sample: Callable,
tmp_path: Path,
) -> None:
"""Test SemanticSegmentor using pretrained model with torch.compile functionality.

Args:
remote_sample (Callable): Callable object used to extract remote sample.
tmp_path (Path): Path to temporary directory.

"""
torch_compile_mode = rcParam["torch_compile_mode"]
torch._dynamo.reset()
rcParam["torch_compile_mode"] = "default"
_, compile_time = timed(
test_functional_pretrained,
remote_sample,
tmp_path,
)
logger.info("torch.compile default mode: %s", compile_time)
torch._dynamo.reset()
rcParam["torch_compile_mode"] = "reduce-overhead"
_, compile_time = timed(
test_functional_pretrained,
remote_sample,
tmp_path,
)
logger.info("torch.compile reduce-overhead mode: %s", compile_time)
torch._dynamo.reset()
rcParam["torch_compile_mode"] = "max-autotune"
_, compile_time = timed(
test_functional_pretrained,
remote_sample,
tmp_path,
)
logger.info("torch.compile max-autotune mode: %s", compile_time)
torch._dynamo.reset()
rcParam["torch_compile_mode"] = torch_compile_mode
40 changes: 39 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@
import numpy as np
import pandas as pd
import pytest
import torch
from PIL import Image
from requests import HTTPError
from shapely.geometry import Polygon

from tests.test_annotation_stores import cell_polygon
from tiatoolbox import utils
from tiatoolbox import rcParam, utils
from tiatoolbox.annotation.storage import DictionaryStore, SQLiteStore
from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.models.architecture.utils import compile_model
from tiatoolbox.utils import misc
from tiatoolbox.utils.exceptions import FileNotSupportedError
from tiatoolbox.utils.transforms import locsize2bounds
Expand Down Expand Up @@ -1827,3 +1829,39 @@ def test_patch_pred_store_persist_ext(tmp_path: pytest.TempPathFactory) -> None:
# check correct error is raised if coordinates are missing
with pytest.raises(ValueError, match="coordinates"):
misc.dict_to_store(patch_output, (1.0, 1.0))


def test_torch_compile_already_compiled() -> None:
"""Test that torch_compile does not recompile a model that is already compiled."""
torch_compile_modes = [
"default",
"reduce-overhead",
"max-autotune",
"max-autotune-no-cudagraphs",
]
current_torch_compile_mode = rcParam["torch_compile_mode"]
model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 10))

for mode in torch_compile_modes:
torch._dynamo.reset()
rcParam["torch_compile_mode"] = mode
compiled_model = compile_model(model, mode=mode)
recompiled_model = compile_model(compiled_model, mode=mode)
assert compiled_model == recompiled_model

torch._dynamo.reset()
rcParam["torch_compile_mode"] = current_torch_compile_mode


def test_torch_compile_disable() -> None:
"""Test torch_compile's disable mode."""
model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 10))
compiled_model = compile_model(model, mode="disable")
assert model == compiled_model


def test_torch_compile_compatibility() -> None:
"""Test if torch-compile compatibility is checked correctly."""
from tiatoolbox.models.architecture.utils import is_torch_compile_compatible

assert isinstance(is_torch_compile_compatible(), bool)
70 changes: 70 additions & 0 deletions tests/test_wsi_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import cv2
import numpy as np
import pytest
import torch

from tests.conftest import timed
from tiatoolbox import logger, rcParam
from tiatoolbox.tools.registration.wsi_registration import (
AffineWSITransformer,
DFBRegister,
Expand Down Expand Up @@ -576,3 +579,70 @@ def test_affine_wsi_transformer(sample_ome_tiff: Path) -> None:
expected = cv2.rotate(expected, cv2.ROTATE_90_CLOCKWISE)

assert np.sum(expected - output) == 0


def test_dfbr_feature_extractor_torch_compile(dfbr_features: Path) -> None:
"""Test DFBRFeatureExtractor with torch.compile functionality.

Args:
dfbr_features (Path): Path to the expected features.

"""

def _extract_features() -> tuple:
dfbr = DFBRegister()
fixed_img = np.repeat(
np.expand_dims(
np.repeat(
np.expand_dims(np.arange(0, 64, 1, dtype=np.uint8), axis=1),
64,
axis=1,
),
axis=2,
),
3,
axis=2,
)
output = dfbr.extract_features(fixed_img, fixed_img)
pool3_feat = output["block3_pool"][0, :].detach().numpy()
pool4_feat = output["block4_pool"][0, :].detach().numpy()
pool5_feat = output["block5_pool"][0, :].detach().numpy()

return pool3_feat, pool4_feat, pool5_feat

torch_compile_mode = rcParam["torch_compile_mode"]
torch._dynamo.reset()
rcParam["torch_compile_mode"] = "default"
(pool3_feat, pool4_feat, pool5_feat), compile_time = timed(_extract_features)
_pool3_feat, _pool4_feat, _pool5_feat = np.load(
str(dfbr_features),
allow_pickle=True,
)
assert np.mean(np.abs(pool3_feat - _pool3_feat)) < 1.0e-4
assert np.mean(np.abs(pool4_feat - _pool4_feat)) < 1.0e-4
assert np.mean(np.abs(pool5_feat - _pool5_feat)) < 1.0e-4
logger.info("torch.compile default mode: %s", compile_time)
torch._dynamo.reset()
rcParam["torch_compile_mode"] = "reduce-overhead"
(pool3_feat, pool4_feat, pool5_feat), compile_time = timed(_extract_features)
_pool3_feat, _pool4_feat, _pool5_feat = np.load(
str(dfbr_features),
allow_pickle=True,
)
assert np.mean(np.abs(pool3_feat - _pool3_feat)) < 1.0e-4
assert np.mean(np.abs(pool4_feat - _pool4_feat)) < 1.0e-4
assert np.mean(np.abs(pool5_feat - _pool5_feat)) < 1.0e-4
logger.info("torch.compile reduce-overhead mode: %s", compile_time)
torch._dynamo.reset()
rcParam["torch_compile_mode"] = "max-autotune"
(pool3_feat, pool4_feat, pool5_feat), compile_time = timed(_extract_features)
_pool3_feat, _pool4_feat, _pool5_feat = np.load(
str(dfbr_features),
allow_pickle=True,
)
assert np.mean(np.abs(pool3_feat - _pool3_feat)) < 1.0e-4
assert np.mean(np.abs(pool4_feat - _pool4_feat)) < 1.0e-4
assert np.mean(np.abs(pool5_feat - _pool5_feat)) < 1.0e-4
logger.info("torch.compile max-autotune mode: %s", compile_time)
torch._dynamo.reset()
rcParam["torch_compile_mode"] = torch_compile_mode
5 changes: 5 additions & 0 deletions tiatoolbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class _RcParam(TypedDict):

TIATOOLBOX_HOME: Path
pretrained_model_info: dict[str, dict]
torch_compile_mode: str


def read_registry_files(path_to_registry: str | Path) -> dict:
Expand Down Expand Up @@ -102,6 +103,10 @@ def read_registry_files(path_to_registry: str | Path) -> dict:
"pretrained_model_info": read_registry_files(
"data/pretrained_model.yaml",
), # Load a dictionary of sample files data (names and urls)
"torch_compile_mode": "default",
# Set `torch-compile` mode to `default`
# Options: `disable`, `default`, `reduce-overhead`, `max-autotune`
# or “max-autotune-no-cudagraphs”
}


Expand Down
1 change: 1 addition & 0 deletions tiatoolbox/models/architecture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def get_pretrained_model(
model.load_state_dict(saved_state_dict, strict=True)

# !

io_info = info["ioconfig"]
creator = locate(f"tiatoolbox.models.engine.{io_info['class']}")

Expand Down
Loading
Loading