Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
dc3222d
nvvm ir integration
abhilash1910 Aug 26, 2025
028a294
add test
abhilash1910 Aug 27, 2025
a418f4b
remove nvvm error handling from utils
abhilash1910 Sep 1, 2025
92af3dd
use version dependent nvvm inclusion
abhilash1910 Sep 1, 2025
bdd1671
fix nvvm compilation flow and test
abhilash1910 Sep 1, 2025
f4151bb
Merge branch 'main' into nvvm
abhilash1910 Sep 1, 2025
dc9a4e3
refactor
abhilash1910 Sep 3, 2025
22d18b9
fix unwanted rebase
abhilash1910 Sep 3, 2025
0f7fda4
fix core linter errors
abhilash1910 Sep 3, 2025
9ed8051
refactor tests
abhilash1910 Sep 3, 2025
bccda47
refactor
abhilash1910 Sep 3, 2025
64436c1
refactor
abhilash1910 Sep 3, 2025
58317d0
ruff format
abhilash1910 Sep 4, 2025
caf4f22
ruff format
abhilash1910 Sep 4, 2025
88237bc
revert changes to cuda_utils
abhilash1910 Sep 4, 2025
5e2e137
new line
abhilash1910 Sep 4, 2025
0301a4d
fix CI rm list import
abhilash1910 Sep 5, 2025
28e2d4b
use noqa
abhilash1910 Sep 7, 2025
af1008f
format
abhilash1910 Sep 7, 2025
a85a44f
verify and skip 110
abhilash1910 Sep 8, 2025
0be06aa
add flags and lto
abhilash1910 Sep 8, 2025
1c63a11
rename gpu-arch to arch
abhilash1910 Sep 15, 2025
cab6db0
change libnvvm version check
abhilash1910 Sep 15, 2025
3abcd38
format
abhilash1910 Sep 15, 2025
caf634c
compute 90
abhilash1910 Sep 15, 2025
64d19ab
Apply suggestions from code review
leofang Sep 15, 2025
d5d216c
Merge branch 'main' into nvvm
leofang Sep 15, 2025
c5993dc
update test
abhilash1910 Sep 16, 2025
5d5b1d3
use exception manager
abhilash1910 Sep 16, 2025
f6b5528
format
abhilash1910 Sep 16, 2025
c7fad0a
format ruff
abhilash1910 Sep 16, 2025
2e6e02b
[pre-commit.ci] auto code formatting
pre-commit-ci[bot] Sep 16, 2025
680d790
add release notes
abhilash1910 Sep 16, 2025
c55fa59
[pre-commit.ci] auto code formatting
pre-commit-ci[bot] Sep 16, 2025
2cbee7f
rectify quotes
abhilash1910 Sep 16, 2025
63e8d57
refix format
abhilash1910 Sep 16, 2025
94c2e56
refresh
abhilash1910 Sep 17, 2025
34bf2cc
[pre-commit.ci] auto code formatting
pre-commit-ci[bot] Sep 17, 2025
6b130bb
user major minor
abhilash1910 Sep 17, 2025
d96c848
fix test
leofang Sep 17, 2025
fcd7c0c
Merge branch 'main' into nvvm
leofang Sep 17, 2025
8331ecf
fix IR - again
leofang Sep 17, 2025
2fa944e
fix nvvm option handling
leofang Sep 17, 2025
4d32276
remove redundant IR & fix linter
leofang Sep 17, 2025
e5b5ea4
avoid extra copy + ensure compiled objcode loadable
leofang Sep 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 119 additions & 6 deletions cuda_core/cuda/core/experimental/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from typing import TYPE_CHECKING, List, Tuple, Union
from warnings import warn

import list

if TYPE_CHECKING:
import cuda.bindings

Expand All @@ -20,12 +22,60 @@
_handle_boolean_option,
check_or_create_options,
driver,
get_binding_version,
handle_return,
is_nested_sequence,
is_sequence,
nvrtc,
)

_nvvm_module = None
_nvvm_import_attempted = False


def _get_nvvm_module():
"""
Handles the import of NVVM module with version and availability checks.
NVVM bindings were added in CUDA 12.9.0, so we need to handle cases where:
1. cuda.bindings is not new enough (< 12.9.0)
2. libnvvm is not found in the Python environment

Returns:
The nvvm module if available and working

Raises:
ImportError: If NVVM is not available due to version or library issues
"""
global _nvvm_module, _nvvm_import_attempted

if _nvvm_import_attempted:
if _nvvm_module is None:
raise ImportError("NVVM module is not available (previous import attempt failed)")
return _nvvm_module

_nvvm_import_attempted = True

try:
version = get_binding_version()
if version < (12, 9):
raise ImportError(
f"NVVM bindings require cuda-bindings >= 12.9.0, but found {version[0]}.{version[1]}.x. "
"Please update cuda-bindings to use NVVM features."
)

from cuda.bindings import nvvm
from cuda.bindings._internal.nvvm import _inspect_function_pointer

if _inspect_function_pointer("__nvvmCreateProgram") == 0:
raise ImportError("NVVM library (libnvvm) is not available in this Python environment. ")

_nvvm_module = nvvm
return _nvvm_module

except ImportError as e:
_nvvm_module = None
raise e


def _process_define_macro_inner(formatted_options, macro):
if isinstance(macro, str):
Expand Down Expand Up @@ -370,28 +420,33 @@ class Program:
code : Any
String of the CUDA Runtime Compilation program.
code_type : Any
String of the code type. Currently ``"ptx"`` and ``"c++"`` are supported.
String of the code type. Currently ``"ptx"``, ``"c++"``, and ``"nvvm"`` are supported.
options : ProgramOptions, optional
A ProgramOptions object to customize the compilation process.
See :obj:`ProgramOptions` for more information.
"""

class _MembersNeededForFinalize:
__slots__ = "handle"
__slots__ = "handle", "backend"

def __init__(self, program_obj, handle):
def __init__(self, program_obj, handle, backend):
self.handle = handle
self.backend = backend
weakref.finalize(program_obj, self.close)

def close(self):
if self.handle is not None:
handle_return(nvrtc.nvrtcDestroyProgram(self.handle))
if self.backend == "NVRTC":
handle_return(nvrtc.nvrtcDestroyProgram(self.handle))
elif self.backend == "NVVM":
nvvm = _get_nvvm_module()
nvvm.destroy_program(self.handle)
self.handle = None

__slots__ = ("__weakref__", "_mnff", "_backend", "_linker", "_options")

def __init__(self, code, code_type, options: ProgramOptions = None):
self._mnff = Program._MembersNeededForFinalize(self, None)
self._mnff = Program._MembersNeededForFinalize(self, None, None)

self._options = options = check_or_create_options(ProgramOptions, options, "Program options")
code_type = code_type.lower()
Expand All @@ -402,6 +457,7 @@ def __init__(self, code, code_type, options: ProgramOptions = None):
# TODO: allow tuples once NVIDIA/cuda-python#72 is resolved

self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), options._name, 0, [], []))
self._mnff.backend = "NVRTC"
self._backend = "NVRTC"
self._linker = None

Expand All @@ -411,8 +467,22 @@ def __init__(self, code, code_type, options: ProgramOptions = None):
ObjectCode._init(code.encode(), code_type), options=self._translate_program_options(options)
)
self._backend = self._linker.backend

elif code_type == "nvvm":
if isinstance(code, str):
code = code.encode("utf-8")
elif not isinstance(code, (bytes, bytearray)):
raise TypeError("NVVM IR code must be provided as str, bytes, or bytearray")

nvvm = _get_nvvm_module()
self._mnff.handle = nvvm.create_program()
self._mnff.backend = "NVVM"
nvvm.add_module_to_program(self._mnff.handle, code, len(code), options._name.decode())
self._backend = "NVVM"
self._linker = None

else:
supported_code_types = ("c++", "ptx")
supported_code_types = ("c++", "ptx", "nvvm")
assert code_type not in supported_code_types, f"{code_type=}"
raise RuntimeError(f"Unsupported {code_type=} ({supported_code_types=})")

Expand All @@ -433,6 +503,27 @@ def _translate_program_options(self, options: ProgramOptions) -> LinkerOptions:
ptxas_options=options.ptxas_options,
)

def _translate_program_options_to_nvvm(self, options: ProgramOptions) -> list[str]:
"""Translate ProgramOptions to NVVM-specific compilation options."""
nvvm_options = []

if options.arch is not None:
arch = options.arch
if arch.startswith("sm_"):
arch = f"compute_{arch[3:]}"
nvvm_options.append(f"-arch={arch}")
else:
major, minor = Device().compute_capability
nvvm_options.append(f"-arch=compute_{major}{minor}")
if options.debug:
nvvm_options.append("-g")
if options.device_code_optimize is False:
nvvm_options.append("-opt=0")
elif options.device_code_optimize is True:
nvvm_options.append("-opt=3")

return nvvm_options

def close(self):
"""Destroy this program."""
if self._linker:
Expand Down Expand Up @@ -513,6 +604,28 @@ def compile(self, target_type, name_expressions=(), logs=None):

return ObjectCode._init(data, target_type, symbol_mapping=symbol_mapping, name=self._options.name)

elif self._backend == "NVVM":
if target_type != "ptx":
raise ValueError(f'NVVM backend only supports target_type="ptx", got "{target_type}"')

nvvm_options = self._translate_program_options_to_nvvm(self._options)
nvvm = _get_nvvm_module()
nvvm.compile_program(self._mnff.handle, len(nvvm_options), nvvm_options)

size = nvvm.get_compiled_result_size(self._mnff.handle)
data = bytearray(size)
nvvm.get_compiled_result(self._mnff.handle, data)

if logs is not None:
logsize = nvvm.get_program_log_size(self._mnff.handle)
if logsize > 1:
log = bytearray(logsize)
nvvm.get_program_log(self._mnff.handle, log)
logs.write(log.decode("utf-8", errors="backslashreplace"))

data_bytes = bytes(data)
return ObjectCode._init(data_bytes, target_type, name=self._options.name)

supported_backends = ("nvJitLink", "driver")
if self._backend not in supported_backends:
raise ValueError(f'Unsupported backend="{self._backend}" ({supported_backends=})')
Expand Down
154 changes: 153 additions & 1 deletion cuda_core/tests/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,106 @@
is_culink_backend = _linker._decide_nvjitlink_or_driver()


def _is_nvvm_available():
"""Check if NVVM is available."""
try:
from cuda.core.experimental._program import _get_nvvm_module

_get_nvvm_module()
return True
except ImportError:
return False


nvvm_available = pytest.mark.skipif(
not _is_nvvm_available(), reason="NVVM not available (libNVVM not found or cuda-bindings < 12.9.0)"
)


@pytest.fixture(scope="session")
def nvvm_ir():
"""Generate working NVVM IR with proper version metadata.
The try clause here is used for older nvvm modules which
might not have an ir_version() method. In which case the
fallback assumes no version metadata will be present in
the input nvvm ir
"""
try:
from cuda.core.experimental._program import _get_nvvm_module

nvvm = _get_nvvm_module()
major, minor, debug_major, debug_minor = nvvm.ir_version()

nvvm_ir_template = """target triple = "nvptx64-unknown-cuda"
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-" \
"f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"

define i32 @ave(i32 %a, i32 %b) {{
entry:
%add = add nsw i32 %a, %b
%div = sdiv i32 %add, 2
ret i32 %div
}}

define void @simple(i32* %data) {{
entry:
%0 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%1 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%mul = mul i32 %0, %1
%2 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%add = add i32 %mul, %2
%call = call i32 @ave(i32 %add, i32 %add)
%idxprom = sext i32 %add to i64
store i32 %call, i32* %data, align 4
ret void
}}

declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() nounwind readnone
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x() nounwind readnone
declare i32 @llvm.nvvm.read.ptx.sreg.tid.x() nounwind readnone

!nvvm.annotations = !{{!0}}
!0 = !{{void (i32*)* @simple, !"kernel", i32 1}}

!nvvmir.version = !{{!1}}
!1 = !{{i32 {major}, i32 0, i32 {debug_major}, i32 0}}
"""

return nvvm_ir_template.format(major=major, debug_major=debug_major)
except Exception:
return """target triple = "nvptx64-unknown-cuda"
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-" \
"f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"

define i32 @ave(i32 %a, i32 %b) {
entry:
%add = add nsw i32 %a, %b
%div = sdiv i32 %add, 2
ret i32 %div
}

define void @simple(i32* %data) {
entry:
%0 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%1 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%mul = mul i32 %0, %1
%2 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%add = add i32 %mul, %2
%call = call i32 @ave(i32 %add, i32 %add)
%idxprom = sext i32 %add to i64
store i32 %call, i32* %data, align 4
ret void
}

declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() nounwind readnone
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x() nounwind readnone
declare i32 @llvm.nvvm.read.ptx.sreg.tid.x() nounwind readnone

!nvvm.annotations = !{!0}
!0 = !{void (i32*)* @simple, !"kernel", i32 1}
"""


@pytest.fixture(scope="module")
def ptx_code_object():
code = 'extern "C" __global__ void my_kernel() {}'
Expand Down Expand Up @@ -92,7 +192,7 @@ def test_program_init_valid_code_type():
def test_program_init_invalid_code_type():
code = "goto 100"
with pytest.raises(
RuntimeError, match=r"^Unsupported code_type='fortran' \(supported_code_types=\('c\+\+', 'ptx'\)\)$"
RuntimeError, match=r"^Unsupported code_type='fortran' \(supported_code_types=\('c\+\+', 'ptx', 'nvvm'\)\)$"
):
Program(code, "FORTRAN")

Expand Down Expand Up @@ -150,3 +250,55 @@ def test_program_close():
program = Program(code, "c++")
program.close()
assert program.handle is None


@nvvm_available
def test_nvvm_deferred_import():
"""Test that our deferred NVVM import works correctly"""
from cuda.core.experimental._program import _get_nvvm_module

nvvm = _get_nvvm_module()
assert nvvm is not None


@nvvm_available
def test_nvvm_program_creation(nvvm_ir):
"""Test basic NVVM program creation"""
program = Program(nvvm_ir, "nvvm")
assert program.backend == "NVVM"
assert program.handle is not None
program.close()


@nvvm_available
def test_nvvm_compile_invalid_target(nvvm_ir):
"""Test that NVVM programs reject invalid compilation targets"""
program = Program(nvvm_ir, "nvvm")
with pytest.raises(ValueError, match='NVVM backend only supports target_type="ptx"'):
program.compile("cubin")
program.close()


@nvvm_available
@pytest.mark.parametrize(
"options",
[
ProgramOptions(name="test1", arch="sm_90", device_code_optimize=False),
ProgramOptions(name="test2", arch="sm_100", device_code_optimize=False),
ProgramOptions(name="test3", arch="sm_110", device_code_optimize=True),
],
)
def test_nvvm_program_options(init_cuda, nvvm_ir, options):
"""Test NVVM programs with different options"""
program = Program(nvvm_ir, "nvvm", options)
assert program.backend == "NVVM"

ptx_code = program.compile("ptx")
assert isinstance(ptx_code, ObjectCode)
assert ptx_code.name == options.name

code_content = ptx_code.code
ptx_text = code_content.decode() if isinstance(code_content, bytes) else str(code_content)
assert ".visible .entry simple(" in ptx_text

program.close()
Loading