Skip to content

Improve program checks #394

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Feb 21, 2025
15 changes: 15 additions & 0 deletions cuda_core/cuda/core/experimental/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
import weakref
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from warnings import warn

from cuda.core.experimental._device import Device
from cuda.core.experimental._linker import Linker, LinkerOptions
from cuda.core.experimental._module import ObjectCode
from cuda.core.experimental._utils import (
_handle_boolean_option,
check_or_create_options,
driver,
handle_return,
is_nested_sequence,
is_sequence,
Expand Down Expand Up @@ -378,6 +380,7 @@ def __init__(self, code, code_type, options: ProgramOptions = None):
raise TypeError("c++ Program expects code argument to be a string")
# TODO: support pre-loaded headers & include names
# TODO: allow tuples once NVIDIA/cuda-python#72 is resolved

self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), b"", 0, [], []))
self._backend = "nvrtc"
self._linker = None
Expand Down Expand Up @@ -414,6 +417,11 @@ def close(self):
self._linker.close()
self._mnff.close()

def _can_load_generated_ptx(self):
driver_ver = handle_return(driver.cuDriverGetVersion())
nvrtc_major, nvrtc_minor = handle_return(nvrtc.nvrtcVersion())
return nvrtc_major * 1000 + nvrtc_minor * 10 <= driver_ver

def compile(self, target_type, name_expressions=(), logs=None):
"""Compile the program with a specific compilation type.

Expand All @@ -440,6 +448,13 @@ def compile(self, target_type, name_expressions=(), logs=None):
raise NotImplementedError

if self._backend == "nvrtc":
if target_type == "ptx" and not self._can_load_generated_ptx():
warn(
"The CUDA driver version is older than the backend version. "
"The generated ptx will not be loadable by the current driver.",
stacklevel=1,
category=RuntimeWarning,
)
if name_expressions:
for n in name_expressions:
handle_return(
Expand Down
9 changes: 1 addition & 8 deletions cuda_core/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@
import sys

try:
from cuda.bindings import driver, nvrtc
from cuda.bindings import driver
except ImportError:
from cuda import cuda as driver
from cuda import nvrtc
import pytest

from cuda.core.experimental import Device, _device
Expand Down Expand Up @@ -66,9 +65,3 @@ def clean_up_cffi_files():
os.remove(f)
except FileNotFoundError:
pass # noqa: SIM105


def can_load_generated_ptx():
_, driver_ver = driver.cuDriverGetVersion()
_, nvrtc_major, nvrtc_minor = nvrtc.nvrtcVersion()
return nvrtc_major * 1000 + nvrtc_minor * 10 <= driver_ver
12 changes: 9 additions & 3 deletions cuda_core/tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
# is strictly prohibited.


import warnings

import pytest
from conftest import can_load_generated_ptx

from cuda.core.experimental import ObjectCode, Program, ProgramOptions, system

Expand Down Expand Up @@ -40,10 +41,15 @@ def get_saxpy_kernel(init_cuda):
return mod.get_kernel("saxpy<float>"), mod


@pytest.mark.xfail(not can_load_generated_ptx(), reason="PTX version too new")
def test_get_kernel(init_cuda):
kernel = """extern "C" __global__ void ABC() { }"""
object_code = Program(kernel, "c++", options=ProgramOptions(relocatable_device_code=True)).compile("ptx")

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
object_code = Program(kernel, "c++", options=ProgramOptions(relocatable_device_code=True)).compile("ptx")
if any("The CUDA driver version is older than the backend version" in str(warning.message) for warning in w):
pytest.skip("PTX version too new for current driver")

assert object_code._handle is None
kernel = object_code.get_kernel("ABC")
assert object_code._handle is not None
Expand Down
13 changes: 9 additions & 4 deletions cuda_core/tests/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
# this software and related documentation outside the terms of the EULA
# is strictly prohibited.

import warnings

import pytest
from conftest import can_load_generated_ptx

from cuda.core.experimental import _linker
from cuda.core.experimental._module import Kernel, ObjectCode
Expand Down Expand Up @@ -100,13 +101,17 @@ def test_program_init_invalid_code_format():
Program(code, "c++")


# TODO: incorporate this check in Program
# This is tested against the current device's arch
@pytest.mark.xfail(not can_load_generated_ptx(), reason="PTX version too new")
def test_program_compile_valid_target_type(init_cuda):
code = 'extern "C" __global__ void my_kernel() {}'
program = Program(code, "c++")
ptx_object_code = program.compile("ptx")

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
ptx_object_code = program.compile("ptx")
if any("The CUDA driver version is older than the backend version" in str(warning.message) for warning in w):
pytest.skip("PTX version too new for current driver")

program = Program(ptx_object_code._module.decode(), "ptx")
cubin_object_code = program.compile("cubin")
ptx_kernel = ptx_object_code.get_kernel("my_kernel")
Expand Down
Loading