Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions tiatoolbox/models/architecture/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import sys
from typing import NoReturn

import numpy as np
import torch
Expand All @@ -12,13 +11,17 @@
from tiatoolbox import logger


def is_torch_compile_compatible() -> NoReturn:
def is_torch_compile_compatible() -> bool:
"""Check if the current GPU is compatible with torch-compile.

Returns:
True if current GPU is compatible with torch-compile, False otherwise.

Raises:
Warning if GPU is not compatible with `torch.compile`.

"""
gpu_compatibility = True
if torch.cuda.is_available(): # pragma: no cover
device_cap = torch.cuda.get_device_capability()
if device_cap not in ((7, 0), (8, 0), (9, 0)):
Expand All @@ -28,13 +31,17 @@ def is_torch_compile_compatible() -> NoReturn:
"Speedup numbers may be lower than expected.",
stacklevel=2,
)
gpu_compatibility = False
else:
logger.warning(
"No GPU detected or cuda not installed, "
"torch.compile is only supported on selected NVIDIA GPUs. "
"Speedup numbers may be lower than expected.",
stacklevel=2,
)
gpu_compatibility = False

return gpu_compatibility


def compile_model(
Expand Down Expand Up @@ -68,12 +75,24 @@ def compile_model(
return model

# Check if GPU is compatible with torch.compile
is_torch_compile_compatible()
gpu_compatibility = is_torch_compile_compatible()

if not gpu_compatibility:
return model

if sys.platform == "win32": # pragma: no cover
msg = (
"`torch.compile` is not supported on Windows. Please see "
"https://github.com/pytorch/pytorch/issues/122094."
)
logger.warning(msg=msg)
return model

# This check will be removed when torch.compile is supported in Python 3.12+
if sys.version_info > (3, 12): # pragma: no cover
msg = "torch-compile is currently not supported in Python 3.12+."
logger.warning(
("torch-compile is currently not supported in Python 3.12+. ",),
msg=msg,
)
return model

Expand Down