diff --git a/tiatoolbox/models/architecture/utils.py b/tiatoolbox/models/architecture/utils.py
index 2ec47d99d..e918a70d8 100644
--- a/tiatoolbox/models/architecture/utils.py
+++ b/tiatoolbox/models/architecture/utils.py
@@ -3,7 +3,6 @@
 from __future__ import annotations
 
 import sys
-from typing import NoReturn
 
 import numpy as np
 import torch
@@ -12,13 +11,17 @@
 from tiatoolbox import logger
 
 
-def is_torch_compile_compatible() -> NoReturn:
+def is_torch_compile_compatible() -> bool:
     """Check if the current GPU is compatible with torch-compile.
 
+    Returns:
+        True if current GPU is compatible with torch-compile, False otherwise.
+
     Raises:
         Warning if GPU is not compatible with `torch.compile`.
 
     """
+    gpu_compatibility = True
     if torch.cuda.is_available():  # pragma: no cover
         device_cap = torch.cuda.get_device_capability()
         if device_cap not in ((7, 0), (8, 0), (9, 0)):
@@ -28,6 +31,7 @@ def is_torch_compile_compatible() -> NoReturn:
                 "Speedup numbers may be lower than expected.",
                 stacklevel=2,
             )
+            gpu_compatibility = False
     else:
         logger.warning(
             "No GPU detected or cuda not installed, "
@@ -35,6 +39,9 @@ def is_torch_compile_compatible() -> NoReturn:
             "Speedup numbers may be lower than expected.",
             stacklevel=2,
         )
+        gpu_compatibility = False
+
+    return gpu_compatibility
 
 
 def compile_model(
@@ -68,12 +75,24 @@ def compile_model(
         return model
 
     # Check if GPU is compatible with torch.compile
-    is_torch_compile_compatible()
+    gpu_compatibility = is_torch_compile_compatible()
+
+    if not gpu_compatibility:
+        return model
+
+    if sys.platform == "win32":  # pragma: no cover
+        msg = (
+            "`torch.compile` is not supported on Windows. Please see "
+            "https://github.com/pytorch/pytorch/issues/122094."
+        )
+        logger.warning(msg=msg)
+        return model
 
     # This check will be removed when torch.compile is supported in Python 3.12+
     if sys.version_info > (3, 12):  # pragma: no cover
+        msg = "torch-compile is currently not supported in Python 3.12+."
         logger.warning(
-            ("torch-compile is currently not supported in Python 3.12+. ",),
+            msg=msg,
         )
         return model
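
For reviewers, a minimal usage sketch of the behaviour this patch introduces. It assumes `compile_model` accepts the model as its first positional argument and that its remaining parameters have defaults; the full signature sits outside this hunk's context, so treat the call below as illustrative rather than definitive.

```python
# Sketch of the patched behaviour. compile_model(model) with default keyword
# arguments is an assumption; the full signature is not visible in the diff.
import torch
from torch import nn

from tiatoolbox.models.architecture.utils import (
    compile_model,
    is_torch_compile_compatible,
)

# Previously annotated NoReturn; after this patch it returns a bool:
# False on CPU-only hosts or on GPUs outside compute capability 7.0/8.0/9.0.
print(is_torch_compile_compatible())

model = nn.Linear(8, 2)

# On an incompatible GPU, on Windows, or on Python 3.12+, compile_model now
# logs a warning and returns the model unchanged instead of compiling it.
maybe_compiled = compile_model(model)
out = maybe_compiled(torch.randn(4, 8))
```

The sketch exercises the three early-return paths added above: incompatible or absent GPU, `win32`, and Python 3.12+.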