Skip to content

Commit 8cc1b68

Browse files
authored
Add get_cuda_native_handle (#773)
* add get_cuda_native_handle * ensure the path variable is from the same scope * address review comment + add missing API ref * fix linter error
1 parent 4c21824 commit 8cc1b68

File tree

9 files changed

+298
-4
lines changed

9 files changed

+298
-4
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ cuda_bindings/cuda/bindings/runtime.pxd
4848
cuda_bindings/cuda/bindings/runtime.pyx
4949
cuda_bindings/cuda/bindings/nvrtc.pxd
5050
cuda_bindings/cuda/bindings/nvrtc.pyx
51+
cuda_bindings/cuda/bindings/utils/_get_handle.pyx
5152

5253
# Distribution / packaging
5354
.Python
@@ -181,4 +182,4 @@ dmypy.json
181182
cython_debug/
182183

183184
# Dont ignore
184-
!.github/actions/build/
185+
!.github/actions/build/

cuda_bindings/cuda/bindings/utils/__init__.pxd

Whitespace-only changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
33

4+
from ._get_handle import get_cuda_native_handle
45
from ._ptx_utils import get_minimal_required_cuda_ver_from_ptx_ver, get_ptx_ver
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
3+
4+
# This code was automatically generated with version 12.9.0. Do not modify it directly.
5+
6+
from libc.stdint cimport uintptr_t
7+
cimport cython
8+
9+
from cuda.bindings cimport driver, runtime, cydriver, cyruntime
10+
11+
12+
cdef dict _handle_getters = None
13+
14+
@cython.embedsignature(True)
15+
def get_cuda_native_handle(obj) -> int:
16+
""" Returns the address of the provided CUDA Python object as Python int.
17+
18+
Parameters
19+
----------
20+
obj : Any
21+
CUDA Python object
22+
23+
Returns
24+
-------
25+
int : The object address.
26+
"""
27+
global _handle_getters
28+
obj_type = type(obj)
29+
if _handle_getters is None:
30+
_handle_getters = dict()
31+
{{if 'CUcontext' in found_types}}
32+
def CUcontext_getter(driver.CUcontext x): return <uintptr_t><void*><cydriver.CUcontext>(x._pvt_ptr[0])
33+
_handle_getters[driver.CUcontext] = CUcontext_getter
34+
{{endif}}
35+
{{if 'CUmodule' in found_types}}
36+
def CUmodule_getter(driver.CUmodule x): return <uintptr_t><void*><cydriver.CUmodule>(x._pvt_ptr[0])
37+
_handle_getters[driver.CUmodule] = CUmodule_getter
38+
{{endif}}
39+
{{if 'CUfunction' in found_types}}
40+
def CUfunction_getter(driver.CUfunction x): return <uintptr_t><void*><cydriver.CUfunction>(x._pvt_ptr[0])
41+
_handle_getters[driver.CUfunction] = CUfunction_getter
42+
{{endif}}
43+
{{if 'CUlibrary' in found_types}}
44+
def CUlibrary_getter(driver.CUlibrary x): return <uintptr_t><void*><cydriver.CUlibrary>(x._pvt_ptr[0])
45+
_handle_getters[driver.CUlibrary] = CUlibrary_getter
46+
{{endif}}
47+
{{if 'CUkernel' in found_types}}
48+
def CUkernel_getter(driver.CUkernel x): return <uintptr_t><void*><cydriver.CUkernel>(x._pvt_ptr[0])
49+
_handle_getters[driver.CUkernel] = CUkernel_getter
50+
{{endif}}
51+
{{if 'CUarray' in found_types}}
52+
def CUarray_getter(driver.CUarray x): return <uintptr_t><void*><cydriver.CUarray>(x._pvt_ptr[0])
53+
_handle_getters[driver.CUarray] = CUarray_getter
54+
{{endif}}
55+
{{if 'CUmipmappedArray' in found_types}}
56+
def CUmipmappedArray_getter(driver.CUmipmappedArray x): return <uintptr_t><void*><cydriver.CUmipmappedArray>(x._pvt_ptr[0])
57+
_handle_getters[driver.CUmipmappedArray] = CUmipmappedArray_getter
58+
{{endif}}
59+
{{if 'CUtexref' in found_types}}
60+
def CUtexref_getter(driver.CUtexref x): return <uintptr_t><void*><cydriver.CUtexref>(x._pvt_ptr[0])
61+
_handle_getters[driver.CUtexref] = CUtexref_getter
62+
{{endif}}
63+
{{if 'CUsurfref' in found_types}}
64+
def CUsurfref_getter(driver.CUsurfref x): return <uintptr_t><void*><cydriver.CUsurfref>(x._pvt_ptr[0])
65+
_handle_getters[driver.CUsurfref] = CUsurfref_getter
66+
{{endif}}
67+
{{if 'CUevent' in found_types}}
68+
def CUevent_getter(driver.CUevent x): return <uintptr_t><void*><cydriver.CUevent>(x._pvt_ptr[0])
69+
_handle_getters[driver.CUevent] = CUevent_getter
70+
{{endif}}
71+
{{if 'CUstream' in found_types}}
72+
def CUstream_getter(driver.CUstream x): return <uintptr_t><void*><cydriver.CUstream>(x._pvt_ptr[0])
73+
_handle_getters[driver.CUstream] = CUstream_getter
74+
{{endif}}
75+
{{if 'CUgraphicsResource' in found_types}}
76+
def CUgraphicsResource_getter(driver.CUgraphicsResource x): return <uintptr_t><void*><cydriver.CUgraphicsResource>(x._pvt_ptr[0])
77+
_handle_getters[driver.CUgraphicsResource] = CUgraphicsResource_getter
78+
{{endif}}
79+
{{if 'CUexternalMemory' in found_types}}
80+
def CUexternalMemory_getter(driver.CUexternalMemory x): return <uintptr_t><void*><cydriver.CUexternalMemory>(x._pvt_ptr[0])
81+
_handle_getters[driver.CUexternalMemory] = CUexternalMemory_getter
82+
{{endif}}
83+
{{if 'CUexternalSemaphore' in found_types}}
84+
def CUexternalSemaphore_getter(driver.CUexternalSemaphore x): return <uintptr_t><void*><cydriver.CUexternalSemaphore>(x._pvt_ptr[0])
85+
_handle_getters[driver.CUexternalSemaphore] = CUexternalSemaphore_getter
86+
{{endif}}
87+
{{if 'CUgraph' in found_types}}
88+
def CUgraph_getter(driver.CUgraph x): return <uintptr_t><void*><cydriver.CUgraph>(x._pvt_ptr[0])
89+
_handle_getters[driver.CUgraph] = CUgraph_getter
90+
{{endif}}
91+
{{if 'CUgraphNode' in found_types}}
92+
def CUgraphNode_getter(driver.CUgraphNode x): return <uintptr_t><void*><cydriver.CUgraphNode>(x._pvt_ptr[0])
93+
_handle_getters[driver.CUgraphNode] = CUgraphNode_getter
94+
{{endif}}
95+
{{if 'CUgraphExec' in found_types}}
96+
def CUgraphExec_getter(driver.CUgraphExec x): return <uintptr_t><void*><cydriver.CUgraphExec>(x._pvt_ptr[0])
97+
_handle_getters[driver.CUgraphExec] = CUgraphExec_getter
98+
{{endif}}
99+
{{if 'CUmemoryPool' in found_types}}
100+
def CUmemoryPool_getter(driver.CUmemoryPool x): return <uintptr_t><void*><cydriver.CUmemoryPool>(x._pvt_ptr[0])
101+
_handle_getters[driver.CUmemoryPool] = CUmemoryPool_getter
102+
{{endif}}
103+
{{if 'CUuserObject' in found_types}}
104+
def CUuserObject_getter(driver.CUuserObject x): return <uintptr_t><void*><cydriver.CUuserObject>(x._pvt_ptr[0])
105+
_handle_getters[driver.CUuserObject] = CUuserObject_getter
106+
{{endif}}
107+
{{if 'CUgraphDeviceNode' in found_types}}
108+
def CUgraphDeviceNode_getter(driver.CUgraphDeviceNode x): return <uintptr_t><void*><cydriver.CUgraphDeviceNode>(x._pvt_ptr[0])
109+
_handle_getters[driver.CUgraphDeviceNode] = CUgraphDeviceNode_getter
110+
{{endif}}
111+
{{if 'CUasyncCallbackHandle' in found_types}}
112+
def CUasyncCallbackHandle_getter(driver.CUasyncCallbackHandle x): return <uintptr_t><void*><cydriver.CUasyncCallbackHandle>(x._pvt_ptr[0])
113+
_handle_getters[driver.CUasyncCallbackHandle] = CUasyncCallbackHandle_getter
114+
{{endif}}
115+
{{if 'CUgreenCtx' in found_types}}
116+
def CUgreenCtx_getter(driver.CUgreenCtx x): return <uintptr_t><void*><cydriver.CUgreenCtx>(x._pvt_ptr[0])
117+
_handle_getters[driver.CUgreenCtx] = CUgreenCtx_getter
118+
{{endif}}
119+
{{if 'CUlinkState' in found_types}}
120+
def CUlinkState_getter(driver.CUlinkState x): return <uintptr_t><void*><cydriver.CUlinkState>(x._pvt_ptr[0])
121+
_handle_getters[driver.CUlinkState] = CUlinkState_getter
122+
{{endif}}
123+
{{if 'CUdevResourceDesc' in found_types}}
124+
def CUdevResourceDesc_getter(driver.CUdevResourceDesc x): return <uintptr_t><void*><cydriver.CUdevResourceDesc>(x._pvt_ptr[0])
125+
_handle_getters[driver.CUdevResourceDesc] = CUdevResourceDesc_getter
126+
{{endif}}
127+
{{if 'CUlogsCallbackHandle' in found_types}}
128+
def CUlogsCallbackHandle_getter(driver.CUlogsCallbackHandle x): return <uintptr_t><void*><cydriver.CUlogsCallbackHandle>(x._pvt_ptr[0])
129+
_handle_getters[driver.CUlogsCallbackHandle] = CUlogsCallbackHandle_getter
130+
{{endif}}
131+
{{if True}}
132+
def CUeglStreamConnection_getter(driver.CUeglStreamConnection x): return <uintptr_t><void*><cydriver.CUeglStreamConnection>(x._pvt_ptr[0])
133+
_handle_getters[driver.CUeglStreamConnection] = CUeglStreamConnection_getter
134+
{{endif}}
135+
{{if True}}
136+
def EGLImageKHR_getter(runtime.EGLImageKHR x): return <uintptr_t><void*><cyruntime.EGLImageKHR>(x._pvt_ptr[0])
137+
_handle_getters[runtime.EGLImageKHR] = EGLImageKHR_getter
138+
{{endif}}
139+
{{if True}}
140+
def EGLStreamKHR_getter(runtime.EGLStreamKHR x): return <uintptr_t><void*><cyruntime.EGLStreamKHR>(x._pvt_ptr[0])
141+
_handle_getters[runtime.EGLStreamKHR] = EGLStreamKHR_getter
142+
{{endif}}
143+
{{if True}}
144+
def EGLSyncKHR_getter(runtime.EGLSyncKHR x): return <uintptr_t><void*><cyruntime.EGLSyncKHR>(x._pvt_ptr[0])
145+
_handle_getters[runtime.EGLSyncKHR] = EGLSyncKHR_getter
146+
{{endif}}
147+
{{if 'cudaArray_t' in found_types}}
148+
def cudaArray_t_getter(runtime.cudaArray_t x): return <uintptr_t><void*><cyruntime.cudaArray_t>(x._pvt_ptr[0])
149+
_handle_getters[runtime.cudaArray_t] = cudaArray_t_getter
150+
{{endif}}
151+
{{if 'cudaArray_const_t' in found_types}}
152+
def cudaArray_const_t_getter(runtime.cudaArray_const_t x): return <uintptr_t><void*><cyruntime.cudaArray_const_t>(x._pvt_ptr[0])
153+
_handle_getters[runtime.cudaArray_const_t] = cudaArray_const_t_getter
154+
{{endif}}
155+
{{if 'cudaMipmappedArray_t' in found_types}}
156+
def cudaMipmappedArray_t_getter(runtime.cudaMipmappedArray_t x): return <uintptr_t><void*><cyruntime.cudaMipmappedArray_t>(x._pvt_ptr[0])
157+
_handle_getters[runtime.cudaMipmappedArray_t] = cudaMipmappedArray_t_getter
158+
{{endif}}
159+
{{if 'cudaMipmappedArray_const_t' in found_types}}
160+
def cudaMipmappedArray_const_t_getter(runtime.cudaMipmappedArray_const_t x): return <uintptr_t><void*><cyruntime.cudaMipmappedArray_const_t>(x._pvt_ptr[0])
161+
_handle_getters[runtime.cudaMipmappedArray_const_t] = cudaMipmappedArray_const_t_getter
162+
{{endif}}
163+
{{if 'cudaStream_t' in found_types}}
164+
def cudaStream_t_getter(runtime.cudaStream_t x): return <uintptr_t><void*><cyruntime.cudaStream_t>(x._pvt_ptr[0])
165+
_handle_getters[runtime.cudaStream_t] = cudaStream_t_getter
166+
{{endif}}
167+
{{if 'cudaEvent_t' in found_types}}
168+
def cudaEvent_t_getter(runtime.cudaEvent_t x): return <uintptr_t><void*><cyruntime.cudaEvent_t>(x._pvt_ptr[0])
169+
_handle_getters[runtime.cudaEvent_t] = cudaEvent_t_getter
170+
{{endif}}
171+
{{if 'cudaGraphicsResource_t' in found_types}}
172+
def cudaGraphicsResource_t_getter(runtime.cudaGraphicsResource_t x): return <uintptr_t><void*><cyruntime.cudaGraphicsResource_t>(x._pvt_ptr[0])
173+
_handle_getters[runtime.cudaGraphicsResource_t] = cudaGraphicsResource_t_getter
174+
{{endif}}
175+
{{if 'cudaExternalMemory_t' in found_types}}
176+
def cudaExternalMemory_t_getter(runtime.cudaExternalMemory_t x): return <uintptr_t><void*><cyruntime.cudaExternalMemory_t>(x._pvt_ptr[0])
177+
_handle_getters[runtime.cudaExternalMemory_t] = cudaExternalMemory_t_getter
178+
{{endif}}
179+
{{if 'cudaExternalSemaphore_t' in found_types}}
180+
def cudaExternalSemaphore_t_getter(runtime.cudaExternalSemaphore_t x): return <uintptr_t><void*><cyruntime.cudaExternalSemaphore_t>(x._pvt_ptr[0])
181+
_handle_getters[runtime.cudaExternalSemaphore_t] = cudaExternalSemaphore_t_getter
182+
{{endif}}
183+
{{if 'cudaGraph_t' in found_types}}
184+
def cudaGraph_t_getter(runtime.cudaGraph_t x): return <uintptr_t><void*><cyruntime.cudaGraph_t>(x._pvt_ptr[0])
185+
_handle_getters[runtime.cudaGraph_t] = cudaGraph_t_getter
186+
{{endif}}
187+
{{if 'cudaGraphNode_t' in found_types}}
188+
def cudaGraphNode_t_getter(runtime.cudaGraphNode_t x): return <uintptr_t><void*><cyruntime.cudaGraphNode_t>(x._pvt_ptr[0])
189+
_handle_getters[runtime.cudaGraphNode_t] = cudaGraphNode_t_getter
190+
{{endif}}
191+
{{if 'cudaUserObject_t' in found_types}}
192+
def cudaUserObject_t_getter(runtime.cudaUserObject_t x): return <uintptr_t><void*><cyruntime.cudaUserObject_t>(x._pvt_ptr[0])
193+
_handle_getters[runtime.cudaUserObject_t] = cudaUserObject_t_getter
194+
{{endif}}
195+
{{if 'cudaFunction_t' in found_types}}
196+
def cudaFunction_t_getter(runtime.cudaFunction_t x): return <uintptr_t><void*><cyruntime.cudaFunction_t>(x._pvt_ptr[0])
197+
_handle_getters[runtime.cudaFunction_t] = cudaFunction_t_getter
198+
{{endif}}
199+
{{if 'cudaKernel_t' in found_types}}
200+
def cudaKernel_t_getter(runtime.cudaKernel_t x): return <uintptr_t><void*><cyruntime.cudaKernel_t>(x._pvt_ptr[0])
201+
_handle_getters[runtime.cudaKernel_t] = cudaKernel_t_getter
202+
{{endif}}
203+
{{if 'cudaLibrary_t' in found_types}}
204+
def cudaLibrary_t_getter(runtime.cudaLibrary_t x): return <uintptr_t><void*><cyruntime.cudaLibrary_t>(x._pvt_ptr[0])
205+
_handle_getters[runtime.cudaLibrary_t] = cudaLibrary_t_getter
206+
{{endif}}
207+
{{if 'cudaMemPool_t' in found_types}}
208+
def cudaMemPool_t_getter(runtime.cudaMemPool_t x): return <uintptr_t><void*><cyruntime.cudaMemPool_t>(x._pvt_ptr[0])
209+
_handle_getters[runtime.cudaMemPool_t] = cudaMemPool_t_getter
210+
{{endif}}
211+
{{if 'cudaGraphExec_t' in found_types}}
212+
def cudaGraphExec_t_getter(runtime.cudaGraphExec_t x): return <uintptr_t><void*><cyruntime.cudaGraphExec_t>(x._pvt_ptr[0])
213+
_handle_getters[runtime.cudaGraphExec_t] = cudaGraphExec_t_getter
214+
{{endif}}
215+
{{if 'cudaGraphDeviceNode_t' in found_types}}
216+
def cudaGraphDeviceNode_t_getter(runtime.cudaGraphDeviceNode_t x): return <uintptr_t><void*><cyruntime.cudaGraphDeviceNode_t>(x._pvt_ptr[0])
217+
_handle_getters[runtime.cudaGraphDeviceNode_t] = cudaGraphDeviceNode_t_getter
218+
{{endif}}
219+
{{if 'cudaAsyncCallbackHandle_t' in found_types}}
220+
def cudaAsyncCallbackHandle_t_getter(runtime.cudaAsyncCallbackHandle_t x): return <uintptr_t><void*><cyruntime.cudaAsyncCallbackHandle_t>(x._pvt_ptr[0])
221+
_handle_getters[runtime.cudaAsyncCallbackHandle_t] = cudaAsyncCallbackHandle_t_getter
222+
{{endif}}
223+
{{if True}}
224+
def cudaEglStreamConnection_getter(runtime.cudaEglStreamConnection x): return <uintptr_t><void*><cyruntime.cudaEglStreamConnection>(x._pvt_ptr[0])
225+
_handle_getters[runtime.cudaEglStreamConnection] = cudaEglStreamConnection_getter
226+
{{endif}}
227+
try:
228+
return _handle_getters[obj_type](obj)
229+
except KeyError:
230+
raise TypeError("Unknown type: " + str(obj_type)) from None

cuda_bindings/docs/source/module/utils.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
44
.. module:: cuda.bindings.utils
55

6-
Utils module
7-
============
6+
utils
7+
=====
88

99
Functions
1010
---------
1111

1212
.. autosummary::
1313
:toctree: generated/
1414

15+
get_cuda_native_handle
1516
get_minimal_required_cuda_ver_from_ptx_ver
1617
get_ptx_ver

cuda_bindings/docs/source/release/12.X.Y-notes.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ Released on MM DD, 2025
1212
Highlights
1313
----------
1414

15+
* A utility module :mod:`cuda.bindings.utils` is added
16+
17+
* Using ``int(cuda_obj)`` to retrieve the underlying address of a CUDA object is deprecated and
18+
subject to future removal. Please switch to use :func:`~cuda.bindings.utils.get_cuda_native_handle`
19+
instead.
20+
1521
* The ``cuda.bindings.cufile`` Python module was added, wrapping the
1622
`cuFile C APIs <https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html>`_.
1723
Supported on Linux only.

cuda_bindings/docs/source/tips_and_tricks.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ All CUDA C types are exposed to Python as Python classes. For example, the :clas
1111

1212
There is an important distinction between the ``getPtr()`` method and the behaviour of ``__int__()``. Since a ``CUstream`` is itself just a pointer, calling ``instance_of_CUstream.getPtr()`` returns the pointer *to* the pointer, instead of the value of the ``CUstream`` C object that is the pointer to the underlying stream handle. ``int(instance_of_CUstream)`` returns the value of the ``CUstream`` converted to a Python int and is the actual address of the underlying handle.
1313

14+
.. warning::
15+
16+
Using ``int(cuda_obj)`` to retrieve the underlying address of a CUDA object is deprecated and
17+
subject to future removal. Please switch to use :func:`~cuda.bindings.utils.get_cuda_native_handle`
18+
instead.
19+
1420

1521
Lifetime management of the CUDA objects
1622
=======================================

cuda_bindings/setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ def generate_output(infile, local):
221221
os.path.join("cuda", "bindings", "_lib"),
222222
os.path.join("cuda", "bindings", "_lib", "cyruntime"),
223223
os.path.join("cuda", "bindings", "_internal"),
224+
os.path.join("cuda", "bindings", "utils"),
224225
]
225226
input_files = []
226227
for path in path_list:
@@ -287,6 +288,7 @@ def prep_extensions(sources, libraries):
287288

288289
# new path for the bindings from cybind
289290
def rename_architecture_specific_files():
291+
path = os.path.join("cuda", "bindings", "_internal")
290292
if sys.platform == "linux":
291293
src_files = glob.glob(os.path.join(path, "*_linux.pyx"))
292294
elif sys.platform == "win32":
@@ -341,6 +343,7 @@ def do_cythonize(extensions):
341343
(["cuda/bindings/_lib/utils.pyx", "cuda/bindings/_lib/param_packer.cpp"], None),
342344
(["cuda/bindings/_lib/cyruntime/cyruntime.pyx"], None),
343345
(["cuda/bindings/_lib/cyruntime/utils.pyx"], None),
346+
(["cuda/bindings/utils/*.pyx"], None),
344347
# public
345348
*(([f], None) for f in cuda_bindings_files),
346349
# public (deprecated, to be removed)

cuda_bindings/tests/test_utils.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
33

4+
import random
5+
46
import pytest
57

6-
from cuda.bindings.utils import get_minimal_required_cuda_ver_from_ptx_ver, get_ptx_ver
8+
from cuda.bindings import driver, runtime
9+
from cuda.bindings.utils import get_cuda_native_handle, get_minimal_required_cuda_ver_from_ptx_ver, get_ptx_ver
710

811
ptx_88_kernel = r"""
912
.version 8.8
@@ -41,3 +44,46 @@ def test_ptx_utils(kernel, actual_ptx_ver, min_cuda_ver):
4144
assert ptx_ver == actual_ptx_ver
4245
cuda_ver = get_minimal_required_cuda_ver_from_ptx_ver(ptx_ver)
4346
assert cuda_ver == min_cuda_ver
47+
48+
49+
@pytest.mark.parametrize(
50+
"target",
51+
(
52+
driver.CUcontext,
53+
driver.CUstream,
54+
driver.CUevent,
55+
driver.CUmodule,
56+
driver.CUlibrary,
57+
driver.CUfunction,
58+
driver.CUkernel,
59+
driver.CUgraph,
60+
driver.CUgraphNode,
61+
driver.CUgraphExec,
62+
driver.CUmemoryPool,
63+
runtime.cudaStream_t,
64+
runtime.cudaEvent_t,
65+
runtime.cudaGraph_t,
66+
runtime.cudaGraphNode_t,
67+
runtime.cudaGraphExec_t,
68+
runtime.cudaMemPool_t,
69+
),
70+
)
71+
def test_get_handle(target):
72+
ptr = random.randint(1, 1024)
73+
obj = target(ptr)
74+
handle = get_cuda_native_handle(obj)
75+
assert handle == ptr
76+
77+
78+
@pytest.mark.parametrize(
79+
"target",
80+
(
81+
(1, 2, 3, 4),
82+
[5, 6],
83+
{},
84+
None,
85+
),
86+
)
87+
def test_get_handle_error(target):
88+
with pytest.raises(TypeError) as e:
89+
handle = get_cuda_native_handle(target)

0 commit comments

Comments
 (0)