Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
dc3222d
nvvm ir integration
abhilash1910 Aug 26, 2025
028a294
add test
abhilash1910 Aug 27, 2025
a418f4b
remove nvvm error handling from utils
abhilash1910 Sep 1, 2025
92af3dd
use version dependent nvvm inclusion
abhilash1910 Sep 1, 2025
bdd1671
fix nvvm compilation flow and test
abhilash1910 Sep 1, 2025
f4151bb
Merge branch 'main' into nvvm
abhilash1910 Sep 1, 2025
dc9a4e3
refactor
abhilash1910 Sep 3, 2025
22d18b9
fix unwanted rebase
abhilash1910 Sep 3, 2025
0f7fda4
fix core linter errors
abhilash1910 Sep 3, 2025
9ed8051
refactor tests
abhilash1910 Sep 3, 2025
bccda47
refactor
abhilash1910 Sep 3, 2025
64436c1
refactor
abhilash1910 Sep 3, 2025
58317d0
ruff format
abhilash1910 Sep 4, 2025
caf4f22
ruff format
abhilash1910 Sep 4, 2025
88237bc
revert changes to cuda_utils
abhilash1910 Sep 4, 2025
5e2e137
new line
abhilash1910 Sep 4, 2025
0301a4d
fix CI rm list import
abhilash1910 Sep 5, 2025
28e2d4b
use noqa
abhilash1910 Sep 7, 2025
af1008f
format
abhilash1910 Sep 7, 2025
a85a44f
verify and skip 110
abhilash1910 Sep 8, 2025
0be06aa
add flags and lto
abhilash1910 Sep 8, 2025
1c63a11
rename gpu-arch to arch
abhilash1910 Sep 15, 2025
cab6db0
change libnvvm version check
abhilash1910 Sep 15, 2025
3abcd38
format
abhilash1910 Sep 15, 2025
caf634c
compute 90
abhilash1910 Sep 15, 2025
64d19ab
Apply suggestions from code review
leofang Sep 15, 2025
d5d216c
Merge branch 'main' into nvvm
leofang Sep 15, 2025
c5993dc
update test
abhilash1910 Sep 16, 2025
5d5b1d3
use exception manager
abhilash1910 Sep 16, 2025
f6b5528
format
abhilash1910 Sep 16, 2025
c7fad0a
format ruff
abhilash1910 Sep 16, 2025
2e6e02b
[pre-commit.ci] auto code formatting
pre-commit-ci[bot] Sep 16, 2025
680d790
add release notes
abhilash1910 Sep 16, 2025
c55fa59
[pre-commit.ci] auto code formatting
pre-commit-ci[bot] Sep 16, 2025
2cbee7f
rectify quotes
abhilash1910 Sep 16, 2025
63e8d57
refix format
abhilash1910 Sep 16, 2025
94c2e56
refresh
abhilash1910 Sep 17, 2025
34bf2cc
[pre-commit.ci] auto code formatting
pre-commit-ci[bot] Sep 17, 2025
6b130bb
use major minor
abhilash1910 Sep 17, 2025
d96c848
fix test
leofang Sep 17, 2025
fcd7c0c
Merge branch 'main' into nvvm
leofang Sep 17, 2025
8331ecf
fix IR - again
leofang Sep 17, 2025
2fa944e
fix nvvm option handling
leofang Sep 17, 2025
4d32276
remove redundant IR & fix linter
leofang Sep 17, 2025
e5b5ea4
avoid extra copy + ensure compiled objcode loadable
leofang Sep 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions cuda_core/cuda/core/experimental/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,24 @@ def from_ltoir(module: Union[bytes, str], *, name: str = "", symbol_mapping: Opt
them (default to no mappings).
"""
return ObjectCode._init(module, "ltoir", name=name, symbol_mapping=symbol_mapping)

@staticmethod
def from_nvvm(module: Union[bytes, str], *, name: str = "", symbol_mapping: Optional[dict] = None) -> "ObjectCode":
"""Create an :class:`ObjectCode` instance from an existing NVVM IR.

Parameters
----------
module : Union[bytes, str]
Either a bytes object containing the in-memory NVVM IR code to load, or
a file path string pointing to the on-disk NVVM IR file to load.
name : Optional[str]
A human-readable identifier representing this code object.
symbol_mapping : Optional[dict]
A dictionary specifying how the unmangled symbol names (as keys)
should be mapped to the mangled names before trying to retrieve
them (default to no mappings).
"""
return ObjectCode._init(module, "nvvm", name=name, symbol_mapping=symbol_mapping)

@staticmethod
def from_fatbin(
Expand Down
65 changes: 60 additions & 5 deletions cuda_core/cuda/core/experimental/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
is_nested_sequence,
is_sequence,
nvrtc,
nvvm,
)


Expand Down Expand Up @@ -370,22 +371,26 @@ class Program:
code : Any
String of the CUDA Runtime Compilation program.
code_type : Any
String of the code type. Currently ``"ptx"`` and ``"c++"`` are supported.
String of the code type. Currently ``"ptx"``, ``"c++"``, and ``"nvvm"`` are supported.
options : ProgramOptions, optional
A ProgramOptions object to customize the compilation process.
See :obj:`ProgramOptions` for more information.
"""

class _MembersNeededForFinalize:
__slots__ = "handle"
__slots__ = "handle", "backend"

def __init__(self, program_obj, handle):
def __init__(self, program_obj, handle, backend="NVRTC"):
self.handle = handle
self.backend = backend
weakref.finalize(program_obj, self.close)

def close(self):
if self.handle is not None:
handle_return(nvrtc.nvrtcDestroyProgram(self.handle))
if self.backend == "NVRTC":
handle_return(nvrtc.nvrtcDestroyProgram(self.handle))
elif self.backend == "NVVM":
handle_return(nvvm.destroy_program(self.handle))
self.handle = None

__slots__ = ("__weakref__", "_mnff", "_backend", "_linker", "_options")
Expand All @@ -402,6 +407,7 @@ def __init__(self, code, code_type, options: ProgramOptions = None):
# TODO: allow tuples once NVIDIA/cuda-python#72 is resolved

self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), options._name, 0, [], []))
self._mnff.backend = "NVRTC"
self._backend = "NVRTC"
self._linker = None

Expand All @@ -411,8 +417,21 @@ def __init__(self, code, code_type, options: ProgramOptions = None):
ObjectCode._init(code.encode(), code_type), options=self._translate_program_options(options)
)
self._backend = self._linker.backend

elif code_type == "nvvm":
if isinstance(code, str):
code = code.encode('utf-8')
elif not isinstance(code, (bytes, bytearray)):
raise TypeError("NVVM IR code must be provided as str, bytes, or bytearray")

self._mnff.handle = nvvm.create_program()
self._mnff.backend = "NVVM"
nvvm.add_module_to_program(self._mnff.handle, code, len(code), options._name.decode())
self._backend = "NVVM"
self._linker = None

else:
supported_code_types = ("c++", "ptx")
supported_code_types = ("c++", "ptx", "nvvm")
assert code_type not in supported_code_types, f"{code_type=}"
raise RuntimeError(f"Unsupported {code_type=} ({supported_code_types=})")

Expand Down Expand Up @@ -513,6 +532,42 @@ def compile(self, target_type, name_expressions=(), logs=None):

return ObjectCode._init(data, target_type, symbol_mapping=symbol_mapping, name=self._options.name)

elif self._backend == "NVVM":
if target_type != "ptx":
raise ValueError(f'NVVM backend only supports target_type="ptx", got "{target_type}"')

nvvm_options = []
if self._options.arch is not None:
arch = self._options.arch
if arch.startswith("sm_"):
arch = f"compute_{arch[3:]}"
nvvm_options.append(f"-arch={arch}")
else:
major, minor = Device().compute_capability
nvvm_options.append(f"-arch=compute_{major}{minor}")

if self._options.debug:
nvvm_options.append("-g")
if self._options.device_code_optimize is False:
nvvm_options.append("-opt=0")
elif self._options.device_code_optimize is True:
nvvm_options.append("-opt=3")

nvvm.compile_program(self._mnff.handle, len(nvvm_options), nvvm_options)

size = nvvm.get_compiled_result_size(self._mnff.handle)
data = bytearray(size)
nvvm.get_compiled_result(self._mnff.handle, data)

if logs is not None:
logsize = nvvm.get_program_log_size(self._mnff.handle)
if logsize > 1:
log = bytearray(logsize)
nvvm.get_program_log(self._mnff.handle, log)
logs.write(log.decode("utf-8", errors="backslashreplace"))

return ObjectCode._init(data, target_type, name=self._options.name)

supported_backends = ("nvJitLink", "driver")
if self._backend not in supported_backends:
raise ValueError(f'Unsupported backend="{self._backend}" ({supported_backends=})')
Expand Down
27 changes: 25 additions & 2 deletions cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ from collections.abc import Sequence
from typing import Callable

try:
    from cuda.bindings import driver, nvrtc, nvvm, runtime
except ImportError:
    from cuda import cuda as driver
    from cuda import cudart as runtime
    from cuda import nvrtc
    # NOTE: the legacy ``cuda`` namespace never shipped an ``nvvm`` module, so
    # ``from cuda import nvvm`` can never succeed; nvvm bindings only exist in
    # ``cuda.bindings``. Import it from there even on the legacy fallback path.
    from cuda.bindings import nvvm

from cuda.core.experimental._utils.driver_cu_result_explanations import DRIVER_CU_RESULT_EXPLANATIONS
from cuda.core.experimental._utils.runtime_cuda_error_explanations import RUNTIME_CUDA_ERROR_EXPLANATIONS
Expand All @@ -27,6 +27,10 @@ class NVRTCError(CUDAError):
pass


class NVVMError(CUDAError):
    """Raised when an NVVM (libnvvm) API call returns a non-success result."""
    pass


ComputeCapability = namedtuple("ComputeCapability", ("major", "minor"))


Expand Down Expand Up @@ -55,6 +59,7 @@ def _reduce_3_tuple(t: tuple):
cdef object _DRIVER_SUCCESS = driver.CUresult.CUDA_SUCCESS
cdef object _RUNTIME_SUCCESS = runtime.cudaError_t.cudaSuccess
cdef object _NVRTC_SUCCESS = nvrtc.nvrtcResult.NVRTC_SUCCESS
cdef object _NVVM_SUCCESS = nvvm.Result.SUCCESS


cpdef inline int _check_driver_error(error) except?-1:
Expand Down Expand Up @@ -103,13 +108,31 @@ cpdef inline int _check_nvrtc_error(error, handle=None) except?-1:
raise NVRTCError(err)


cpdef inline int _check_nvvm_error(error, handle=None) except?-1:
    # Raise NVVMError for any non-success nvvm.Result; returns 0 on success.
    # When a program handle is given, the NVVM compilation log is appended to
    # the error message to aid debugging.
    if error == _NVVM_SUCCESS:
        return 0
    err = f"{error}: {nvvm.get_error_string(error)}"
    if handle is not None:
        try:
            logsize = nvvm.get_program_log_size(handle)
            # A size of 1 means the log is just the NUL terminator (empty).
            if logsize > 1:
                log = bytearray(logsize)
                nvvm.get_program_log(handle, log)
                err += f", compilation log:\n\n{log.decode('utf-8', errors='backslashreplace')}"
        except Exception as e:
            # Fetching the log failed: still surface the original NVVM error,
            # chaining the log-retrieval failure as the cause.
            raise NVVMError(err) from e
    raise NVVMError(err)


cdef inline int _check_error(error, handle=None) except?-1:
    # Dispatch on the bindings' result-enum type to the matching checker.
    if isinstance(error, driver.CUresult):
        return _check_driver_error(error)
    elif isinstance(error, runtime.cudaError_t):
        return _check_runtime_error(error)
    elif isinstance(error, nvrtc.nvrtcResult):
        # handle, when provided, is used to fetch the NVRTC program log.
        return _check_nvrtc_error(error, handle=handle)
    elif isinstance(error, nvvm.Result):
        # handle, when provided, is used to fetch the NVVM program log.
        return _check_nvvm_error(error, handle=handle)
    else:
        raise RuntimeError(f"Unknown error type: {error}")

Expand Down
72 changes: 71 additions & 1 deletion cuda_core/tests/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,37 @@

is_culink_backend = _linker._decide_nvjitlink_or_driver()

# Minimal NVVM IR module adapted from the libNVVM "simple" sample: @simple is
# the kernel entry point and @ave is a plain device helper. Per the NVVM IR
# spec, a function is only treated as a kernel if it is listed in the
# !nvvm.annotations metadata with the "kernel" tag — without it, @simple would
# compile to an ordinary device function instead of a PTX .entry, so the
# annotation lines at the bottom are required.
nvvm_ir = """
target triple = "nvptx64-unknown-cuda"
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"

define i32 @ave(i32 %a, i32 %b) {
entry:
%add = add nsw i32 %a, %b
%div = sdiv i32 %add, 2
ret i32 %div
}

define void @simple(i32* %data) {
entry:
%0 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%1 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%mul = mul i32 %0, %1
%2 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%add = add i32 %mul, %2
%call = call i32 @ave(i32 %add, i32 %add)
%idxprom = sext i32 %add to i64
store i32 %call, i32* %data, align 4
ret void
}

declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() nounwind readnone

declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x() nounwind readnone

declare i32 @llvm.nvvm.read.ptx.sreg.tid.x() nounwind readnone

!nvvm.annotations = !{!0}
!0 = !{void (i32*)* @simple, !"kernel", i32 1}

"""

@pytest.fixture(scope="module")
def ptx_code_object():
Expand Down Expand Up @@ -92,7 +123,7 @@ def test_program_init_valid_code_type():
def test_program_init_invalid_code_type():
    """An unrecognized code_type must raise with the supported types listed."""
    expected = r"^Unsupported code_type='fortran' \(supported_code_types=\('c\+\+', 'ptx', 'nvvm'\)\)$"
    with pytest.raises(RuntimeError, match=expected):
        Program("goto 100", "FORTRAN")

Expand Down Expand Up @@ -150,3 +181,42 @@ def test_program_close():
program = Program(code, "c++")
program.close()
assert program.handle is None

# ProgramOptions variants exercised against the NVVM backend: explicit name,
# optimization toggle, explicit arch, and debug info.
nvvm_options = [
    ProgramOptions(name="nvvm_test"),
    ProgramOptions(device_code_optimize=True),
    ProgramOptions(arch="sm_90"),
    ProgramOptions(debug=True),
]

@pytest.mark.parametrize("options", nvvm_options)
def test_nvvm_program_with_various_options(init_cuda, options):
    """The sample NVVM IR compiles to PTX under each ProgramOptions variant."""
    prog = Program(nvvm_ir, "nvvm", options)
    assert prog.backend == "NVVM"
    prog.compile("ptx")
    prog.close()
    assert prog.handle is None


def test_nvvm_program_creation():
    """Creating an NVVM-backed Program yields a live handle and NVVM backend."""
    program = Program(nvvm_ir, "nvvm")
    assert program.backend == "NVVM"
    assert program.handle is not None
    # Release the underlying nvvmProgram deterministically instead of leaking
    # it to the weakref finalizer, matching the other tests in this file.
    program.close()
    assert program.handle is None


def test_nvvm_compile_invalid_target():
    """The NVVM backend rejects any target other than "ptx"."""
    program = Program(nvvm_ir, "nvvm")
    # Pin the message so an unrelated ValueError cannot satisfy this test.
    with pytest.raises(ValueError, match=r'NVVM backend only supports target_type="ptx"'):
        program.compile("cubin")
    # Clean up the NVVM program handle (the failed compile does not do so).
    program.close()
    assert program.handle is None


def test_nvvm_compile_valid_target_type(init_cuda):
    """Compiling NVVM IR to PTX yields an ObjectCode whose kernel is retrievable."""
    # Use ProgramOptions directly, consistent with the other tests here.
    program = Program(nvvm_ir, "nvvm", options=ProgramOptions(name="nvvm_test"))
    ptx_object_code = program.compile("ptx")
    assert isinstance(ptx_object_code, ObjectCode)
    assert ptx_object_code.name == "nvvm_test"

    # The sample IR defines @simple (and @ave); there is no symbol named
    # "nvvm_kernel", so look up the entry point that actually exists.
    ptx_kernel = ptx_object_code.get_kernel("simple")
    assert isinstance(ptx_kernel, Kernel)

    program.close()