diff --git a/cuda_core/cuda/core/experimental/_program.py b/cuda_core/cuda/core/experimental/_program.py index b1fb0d90f..1bb48afc9 100644 --- a/cuda_core/cuda/core/experimental/_program.py +++ b/cuda_core/cuda/core/experimental/_program.py @@ -5,6 +5,7 @@ import weakref from dataclasses import dataclass from typing import List, Optional, Tuple, Union +from warnings import warn from cuda.core.experimental._device import Device from cuda.core.experimental._linker import Linker, LinkerOptions @@ -12,6 +13,7 @@ from cuda.core.experimental._utils import ( _handle_boolean_option, check_or_create_options, + driver, handle_return, is_nested_sequence, is_sequence, @@ -378,6 +380,7 @@ def __init__(self, code, code_type, options: ProgramOptions = None): raise TypeError("c++ Program expects code argument to be a string") # TODO: support pre-loaded headers & include names # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved + self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), b"", 0, [], [])) self._backend = "nvrtc" self._linker = None @@ -414,6 +417,11 @@ def close(self): self._linker.close() self._mnff.close() + def _can_load_generated_ptx(self): + driver_ver = handle_return(driver.cuDriverGetVersion()) + nvrtc_major, nvrtc_minor = handle_return(nvrtc.nvrtcVersion()) + return nvrtc_major * 1000 + nvrtc_minor * 10 <= driver_ver + def compile(self, target_type, name_expressions=(), logs=None): """Compile the program with a specific compilation type. @@ -440,6 +448,13 @@ def compile(self, target_type, name_expressions=(), logs=None): raise NotImplementedError if self._backend == "nvrtc": + if target_type == "ptx" and not self._can_load_generated_ptx(): + warn( + "The CUDA driver version is older than the backend version. 
" + "The generated ptx will not be loadable by the current driver.", + stacklevel=2, + category=RuntimeWarning, + ) if name_expressions: for n in name_expressions: handle_return( diff --git a/cuda_core/tests/conftest.py b/cuda_core/tests/conftest.py index dc50585ab..44510ea15 100644 --- a/cuda_core/tests/conftest.py +++ b/cuda_core/tests/conftest.py @@ -11,10 +11,9 @@ import sys try: - from cuda.bindings import driver, nvrtc + from cuda.bindings import driver except ImportError: from cuda import cuda as driver - from cuda import nvrtc import pytest from cuda.core.experimental import Device, _device @@ -66,9 +65,3 @@ def clean_up_cffi_files(): os.remove(f) except FileNotFoundError: pass # noqa: SIM105 - - -def can_load_generated_ptx(): - _, driver_ver = driver.cuDriverGetVersion() - _, nvrtc_major, nvrtc_minor = nvrtc.nvrtcVersion() - return nvrtc_major * 1000 + nvrtc_minor * 10 <= driver_ver diff --git a/cuda_core/tests/test_module.py b/cuda_core/tests/test_module.py index f859142c9..355a0c49a 100644 --- a/cuda_core/tests/test_module.py +++ b/cuda_core/tests/test_module.py @@ -7,8 +7,9 @@ # is strictly prohibited. 
+import warnings + import pytest -from conftest import can_load_generated_ptx from cuda.core.experimental import ObjectCode, Program, ProgramOptions, system @@ -40,10 +41,15 @@ def get_saxpy_kernel(init_cuda): return mod.get_kernel("saxpy"), mod -@pytest.mark.xfail(not can_load_generated_ptx(), reason="PTX version too new") def test_get_kernel(init_cuda): kernel = """extern "C" __global__ void ABC() { }""" - object_code = Program(kernel, "c++", options=ProgramOptions(relocatable_device_code=True)).compile("ptx") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + object_code = Program(kernel, "c++", options=ProgramOptions(relocatable_device_code=True)).compile("ptx") + if any("The CUDA driver version is older than the backend version" in str(warning.message) for warning in w): + pytest.skip("PTX version too new for current driver") + assert object_code._handle is None kernel = object_code.get_kernel("ABC") assert object_code._handle is not None diff --git a/cuda_core/tests/test_program.py b/cuda_core/tests/test_program.py index d9c5cbde1..8d2ecd1ab 100644 --- a/cuda_core/tests/test_program.py +++ b/cuda_core/tests/test_program.py @@ -6,8 +6,9 @@ # this software and related documentation outside the terms of the EULA # is strictly prohibited. 
+import warnings + import pytest -from conftest import can_load_generated_ptx from cuda.core.experimental import _linker from cuda.core.experimental._module import Kernel, ObjectCode @@ -100,13 +101,17 @@ def test_program_init_invalid_code_format(): Program(code, "c++") -# TODO: incorporate this check in Program # This is tested against the current device's arch -@pytest.mark.xfail(not can_load_generated_ptx(), reason="PTX version too new") def test_program_compile_valid_target_type(init_cuda): code = 'extern "C" __global__ void my_kernel() {}' program = Program(code, "c++") - ptx_object_code = program.compile("ptx") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + ptx_object_code = program.compile("ptx") + if any("The CUDA driver version is older than the backend version" in str(warning.message) for warning in w): + pytest.skip("PTX version too new for current driver") + program = Program(ptx_object_code._module.decode(), "ptx") cubin_object_code = program.compile("cubin") ptx_kernel = ptx_object_code.get_kernel("my_kernel")