From 8e56d6011d31afb1ca59d25acbd2400f0fd993d0 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Fri, 13 Jun 2025 14:52:21 +0000 Subject: [PATCH 01/15] cythonize event --- .../experimental/{_event.py => _event.pyx} | 78 +++++++++++-------- cuda_core/setup.py | 5 ++ 2 files changed, 52 insertions(+), 31 deletions(-) rename cuda_core/cuda/core/experimental/{_event.py => _event.pyx} (82%) diff --git a/cuda_core/cuda/core/experimental/_event.py b/cuda_core/cuda/core/experimental/_event.pyx similarity index 82% rename from cuda_core/cuda/core/experimental/_event.py rename to cuda_core/cuda/core/experimental/_event.pyx index 800f34c9a..1c1302a9b 100644 --- a/cuda_core/cuda/core/experimental/_event.py +++ b/cuda_core/cuda/core/experimental/_event.pyx @@ -4,14 +4,12 @@ from __future__ import annotations -import weakref from dataclasses import dataclass from typing import TYPE_CHECKING, Optional from cuda.core.experimental._context import Context from cuda.core.experimental._utils.cuda_utils import ( CUDAError, - check_or_create_options, driver, handle_return, ) @@ -25,7 +23,7 @@ @dataclass -class EventOptions: +cdef class EventOptions: """Customizable :obj:`~_event.Event` options. Attributes @@ -49,7 +47,27 @@ class EventOptions: support_ipc: Optional[bool] = False -class Event: +cdef inline EventOptions check_or_create_options(options, str options_description): + """ + Create the specified options dataclass from a dictionary of options or None. + """ + cdef EventOptions opts + if options is None: + opts = EventOptions() + elif isinstance(options, dict): + opts = EventOptions(**options) + elif not isinstance(options, EventOptions): + raise TypeError( + f"The {options_description} must be provided as an object " + f"of type {EventOptions.__name__} or as a dict with valid {options_description}. " + f"The provided object is '{options}'." + ) + + return opts + + + +cdef class Event: """Represent a record at a specific point of execution within a CUDA stream. Applications can asynchronously record events at any point in @@ -77,30 +95,20 @@ class Event: and they should instead be created through a :obj:`~_stream.Stream` object. """ - - class _MembersNeededForFinalize: - __slots__ = ("handle",) - - def __init__(self, event_obj, handle): - self.handle = handle - weakref.finalize(event_obj, self.close) - - def close(self): - if self.handle is not None: - handle_return(driver.cuEventDestroy(self.handle)) - self.handle = None - - def __new__(self, *args, **kwargs): + cdef: + object _handle + bint _timing_disabled + bint _busy_waited + int _device_id + object _ctx_handle + + def __init__(self, *args, **kwargs): raise RuntimeError("Event objects cannot be instantiated directly. Please use Stream APIs (record).") - __slots__ = ("__weakref__", "_mnff", "_timing_disabled", "_busy_waited", "_device_id", "_ctx_handle") - @classmethod - def _init(cls, device_id: int, ctx_handle: Context, options: Optional[EventOptions] = None): - self = super().__new__(cls) - self._mnff = Event._MembersNeededForFinalize(self, None) - - options = check_or_create_options(EventOptions, options, "Event options") + def _init(cls, device_id: int, ctx_handle: Context, opts=None): + cdef Event self = Event.__new__(Event) + cdef EventOptions options = check_or_create_options(opts, "Event options") flags = 0x0 self._timing_disabled = False self._busy_waited = False @@ -112,14 +120,22 @@ def _init(cls, device_id: int, ctx_handle: Context, options: Optional[EventOptio self._busy_waited = True if options.support_ipc: raise NotImplementedError("WIP: https://github.com/NVIDIA/cuda-python/issues/103") - self._mnff.handle = handle_return(driver.cuEventCreate(flags)) + _, self._handle = driver.cuEventCreate(flags) self._device_id = device_id self._ctx_handle = ctx_handle return self + cdef _close(self): + if self._handle is not None: + _ = driver.cuEventDestroy(self._handle) + self._handle = None + def close(self): """Destroy the event.""" - self._mnff.close() + self._close() + + def __dealloc__(self): + self._close() def __isub__(self, other): return NotImplemented @@ -129,7 +145,7 @@ def __rsub__(self, other): def __sub__(self, other): # return self - other (in milliseconds) - err, timing = driver.cuEventElapsedTime(other.handle, self.handle) + err, timing = driver.cuEventElapsedTime(other.handle, self._handle) try: raise_if_driver_error(err) return timing @@ -180,12 +196,12 @@ def sync(self): has been completed. """ - handle_return(driver.cuEventSynchronize(self._mnff.handle)) + handle_return(driver.cuEventSynchronize(self._handle)) @property def is_done(self) -> bool: """Return True if all captured works have been completed, otherwise False.""" - (result,) = driver.cuEventQuery(self._mnff.handle) + (result,) = driver.cuEventQuery(self._handle) if result == driver.CUresult.CUDA_SUCCESS: return True if result == driver.CUresult.CUDA_ERROR_NOT_READY: @@ -201,7 +217,7 @@ def handle(self) -> cuda.bindings.driver.CUevent: This handle is a Python object. To get the memory address of the underlying C handle, call ``int(Event.handle)``. """ - return self._mnff.handle + return self._handle @property def device(self) -> Device: diff --git a/cuda_core/setup.py b/cuda_core/setup.py index f2005c3dd..2eabcf889 100644 --- a/cuda_core/setup.py +++ b/cuda_core/setup.py @@ -9,6 +9,11 @@ from setuptools.command.build_ext import build_ext as _build_ext ext_modules = ( + Extension( + "cuda.core.experimental._event", + sources=["cuda/core/experimental/_event.pyx"], + language="c++", + ), Extension( "cuda.core.experimental._dlpack", sources=["cuda/core/experimental/_dlpack.pyx"], From 9b96cc7cfa6539a818b7d10c11dffc3a92a5694c Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Fri, 13 Jun 2025 16:30:49 +0000 Subject: [PATCH 02/15] cythonize context/event/util --- .../{_context.py => _context.pyx} | 21 +++++++----- cuda_core/cuda/core/experimental/_device.py | 6 ++-- .../_utils/{cuda_utils.py => cuda_utils.pyx} | 8 ++--- cuda_core/setup.py | 33 ++++++++----------- 4 files changed, 35 insertions(+), 33 deletions(-) rename cuda_core/cuda/core/experimental/{_context.py => _context.pyx} (56%) rename cuda_core/cuda/core/experimental/_utils/{cuda_utils.py => cuda_utils.pyx} (96%) diff --git a/cuda_core/cuda/core/experimental/_context.py b/cuda_core/cuda/core/experimental/_context.pyx similarity index 56% rename from cuda_core/cuda/core/experimental/_context.py rename to cuda_core/cuda/core/experimental/_context.pyx index 24e06d69c..205f6c983 100644 --- a/cuda_core/cuda/core/experimental/_context.py +++ b/cuda_core/cuda/core/experimental/_context.pyx @@ -13,16 +13,21 @@ class ContextOptions: pass # TODO -class Context: - __slots__ = ("_handle", "_id") +cdef class Context: - def __new__(self, *args, **kwargs): + cdef: + object _handle + int _device_id + + def __init__(self, *args, **kwargs): raise RuntimeError("Context objects cannot be instantiated directly. Please use Device or Stream APIs.") @classmethod - def _from_ctx(cls, obj, dev_id): - assert_type(obj, driver.CUcontext) - ctx = super().__new__(cls) - ctx._handle = obj - ctx._id = dev_id + def _from_ctx(cls, handle: driver.CUcontext, int device_id): + cdef Context ctx = Context.__new__(Context) + ctx._handle = handle + ctx._device_id = device_id return ctx + + def __eq__(self, other): + return int(self._handle) == int(other._handle) diff --git a/cuda_core/cuda/core/experimental/_device.py b/cuda_core/cuda/core/experimental/_device.py index c9a786070..c89f659a9 100644 --- a/cuda_core/cuda/core/experimental/_device.py +++ b/cuda_core/cuda/core/experimental/_device.py @@ -1237,7 +1237,6 @@ def create_stream(self, obj: Optional[IsStreamT] = None, options: StreamOptions """ return Stream._init(obj=obj, options=options) - @precondition(_check_context_initialized) def create_event(self, options: Optional[EventOptions] = None) -> Event: """Create an Event object without recording it to a Stream. @@ -1256,7 +1255,10 @@ def create_event(self, options: Optional[EventOptions] = None) -> Event: Newly created event object. """ - return Event._init(self._id, self.context._handle, options) + ctx = driver.cuCtxGetCurrent()[1] + if int(ctx) == 0: + raise CUDAError("No context is bound to the calling CPU thread.") + return Event._init(self._id, ctx, options) @precondition(_check_context_initialized) def allocate(self, size, stream: Optional[Stream] = None) -> Buffer: diff --git a/cuda_core/cuda/core/experimental/_utils/cuda_utils.py b/cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx similarity index 96% rename from cuda_core/cuda/core/experimental/_utils/cuda_utils.py rename to cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx index 48b48d2fb..77ce533e6 100644 --- a/cuda_core/cuda/core/experimental/_utils/cuda_utils.py +++ b/cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx @@ -52,7 +52,7 @@ def _reduce_3_tuple(t: tuple): return t[0] * t[1] * t[2] -def _check_driver_error(error): +cpdef inline void _check_driver_error(error) except*: if error == driver.CUresult.CUDA_SUCCESS: return name_err, name = driver.cuGetErrorName(error) @@ -69,7 +69,7 @@ def _check_driver_error(error): raise CUDAError(f"{name}: {desc}") -def _check_runtime_error(error): +cpdef inline void _check_runtime_error(error) except*: if error == runtime.cudaError_t.cudaSuccess: return name_err, name = runtime.cudaGetErrorName(error) @@ -86,7 +86,7 @@ def _check_runtime_error(error): raise CUDAError(f"{name}: {desc}") -def _check_error(error, handle=None): +cdef inline void _check_error(error, handle=None) except*: if isinstance(error, driver.CUresult): _check_driver_error(error) elif isinstance(error, runtime.cudaError_t): @@ -105,7 +105,7 @@ def _check_error(error, handle=None): raise RuntimeError(f"Unknown error type: {error}") -def handle_return(result, handle=None): +def handle_return(tuple result, handle=None): _check_error(result[0], handle=handle) if len(result) == 1: return diff --git a/cuda_core/setup.py b/cuda_core/setup.py index 2eabcf889..f2b84bfaf 100644 --- a/cuda_core/setup.py +++ b/cuda_core/setup.py @@ -2,33 +2,28 @@ # # SPDX-License-Identifier: Apache-2.0 +import glob import os from Cython.Build import cythonize from setuptools import Extension, setup from setuptools.command.build_ext import build_ext as _build_ext -ext_modules = ( - Extension( - "cuda.core.experimental._event", - sources=["cuda/core/experimental/_event.pyx"], - language="c++", - ), - Extension( - "cuda.core.experimental._dlpack", - sources=["cuda/core/experimental/_dlpack.pyx"], - language="c++", - ), - Extension( - "cuda.core.experimental._memoryview", - sources=["cuda/core/experimental/_memoryview.pyx"], - language="c++", - ), + +# It seems setuptools' wildcard support has problems for namespace packages, +# so we explicitly spell out all Extension instances. +root_module = "cuda.core.experimental" +root_path = f"{os.path.sep}".join(root_module.split(".")) + os.path.sep +ext_files = glob.glob(f"{root_path}/**/*.pyx", recursive=True) +def strip_prefix_suffix(filename): + return filename[len(root_path):-4] +module_names = (strip_prefix_suffix(f) for f in ext_files) +ext_modules = tuple( Extension( - "cuda.core.experimental._kernel_arg_handler", - sources=["cuda/core/experimental/_kernel_arg_handler.pyx"], + f"cuda.core.experimental.{mod.replace(os.path.sep, '.')}", + sources=[f"cuda/core/experimental/{mod}.pyx"], language="c++", - ), + ) for mod in module_names ) From 3495a1f2ce138eb994209b28f8c7542085cb89d1 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Wed, 25 Jun 2025 01:21:46 +0000 Subject: [PATCH 03/15] Cython 3.0+ supports __del__ for cdef classes --- cuda_core/cuda/core/experimental/_event.pyx | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_event.pyx b/cuda_core/cuda/core/experimental/_event.pyx index 1c1302a9b..31c6756d1 100644 --- a/cuda_core/cuda/core/experimental/_event.pyx +++ b/cuda_core/cuda/core/experimental/_event.pyx @@ -125,17 +125,14 @@ cdef class Event: self._ctx_handle = ctx_handle return self - cdef _close(self): + cpdef close(self): + """Destroy the event.""" if self._handle is not None: _ = driver.cuEventDestroy(self._handle) self._handle = None - def close(self): - """Destroy the event.""" - self._close() - - def __dealloc__(self): - self._close() + def __del__(self): + self.close() def __isub__(self, other): return NotImplemented From 6f346e49c9b29701a4150818e19ce741f7a06ace Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Wed, 25 Jun 2025 01:23:19 +0000 Subject: [PATCH 04/15] inline precondition to reduce overhead --- cuda_core/cuda/core/experimental/_device.py | 41 ++++++++++----- .../core/experimental/_utils/cuda_utils.pyx | 52 +++++++++++-------- 2 files changed, 58 insertions(+), 35 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_device.py b/cuda_core/cuda/core/experimental/_device.py index c89f659a9..99aaefdbf 100644 --- a/cuda_core/cuda/core/experimental/_device.py +++ b/cuda_core/cuda/core/experimental/_device.py @@ -17,7 +17,6 @@ _check_driver_error, driver, handle_return, - precondition, runtime, ) @@ -1017,12 +1016,31 @@ def __new__(cls, device_id: Optional[int] = None): except IndexError: raise ValueError(f"device_id must be within [0, {len(devices)}), got {device_id}") from None - def _check_context_initialized(self, *args, **kwargs): + def _check_context_initialized(self): if not self._has_inited: raise CUDAError( f"Device {self._id} is not yet initialized, perhaps you forgot to call .set_current() first?" ) + def _get_current_context(self, check_consistency=False) -> driver.CUcontext: + err, ctx = driver.cuCtxGetCurrent() + + # TODO: We want to just call this: + #_check_driver_error(err) + # but even the simplest success check causes 50-100 ns. Wait until we cythonize this file... + if ctx is None: + _check_driver_error(err) + + if int(ctx) == 0: + raise CUDAError("No context is bound to the calling CPU thread.") + if check_consistency: + err, dev = driver.cuCtxGetDevice() + if err != _SUCCESS: + handle_return((err,)) + if int(dev) != self._id: + raise CUDAError("Internal error (current device is not equal to Device.device_id)") + return ctx + @property def device_id(self) -> int: """Return device ordinal.""" @@ -1083,7 +1101,6 @@ def compute_capability(self) -> ComputeCapability: return cc @property - @precondition(_check_context_initialized) def context(self) -> Context: """Return the current :obj:`~_context.Context` associated with this device. @@ -1092,9 +1109,8 @@ def context(self) -> Context: Device must be initialized. """ - ctx = handle_return(driver.cuCtxGetCurrent()) - if int(ctx) == 0: - raise CUDAError("No context is bound to the calling CPU thread.") + self._check_context_initialized() + ctx = self._get_current_context(check_consistency=True) return Context._from_ctx(ctx, self._id) @property @@ -1206,7 +1222,6 @@ def create_context(self, options: ContextOptions = None) -> Context: """ raise NotImplementedError("WIP: https://github.com/NVIDIA/cuda-python/issues/189") - @precondition(_check_context_initialized) def create_stream(self, obj: Optional[IsStreamT] = None, options: StreamOptions = None) -> Stream: """Create a Stream object. @@ -1235,6 +1250,7 @@ def create_stream(self, obj: Optional[IsStreamT] = None, options: StreamOptions Newly created stream object. """ + self._check_context_initialized() return Stream._init(obj=obj, options=options) def create_event(self, options: Optional[EventOptions] = None) -> Event: @@ -1255,12 +1271,10 @@ def create_event(self, options: Optional[EventOptions] = None) -> Event: Newly created event object. """ - ctx = driver.cuCtxGetCurrent()[1] - if int(ctx) == 0: - raise CUDAError("No context is bound to the calling CPU thread.") + self._check_context_initialized() + ctx = self._get_current_context() return Event._init(self._id, ctx, options) - @precondition(_check_context_initialized) def allocate(self, size, stream: Optional[Stream] = None) -> Buffer: """Allocate device memory from a specified stream. @@ -1287,11 +1301,11 @@ def allocate(self, size, stream: Optional[Stream] = None) -> Buffer: Newly created buffer object. """ + self._check_context_initialized() if stream is None: stream = default_stream() return self._mr.allocate(size, stream) - @precondition(_check_context_initialized) def sync(self): """Synchronize the device. @@ -1300,9 +1314,9 @@ def sync(self): Device must be initialized. """ + self._check_context_initialized() handle_return(runtime.cudaDeviceSynchronize()) - @precondition(_check_context_initialized) def create_graph_builder(self) -> GraphBuilder: """Create a new :obj:`~_graph.GraphBuilder` object. @@ -1312,4 +1326,5 @@ def create_graph_builder(self) -> GraphBuilder: Newly created graph builder object. """ + self._check_context_initialized() return GraphBuilder._init(stream=self.create_stream(), is_stream_owner=True) diff --git a/cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx b/cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx index 77ce533e6..61ca4da9b 100644 --- a/cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx +++ b/cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx @@ -52,26 +52,29 @@ def _reduce_3_tuple(t: tuple): return t[0] * t[1] * t[2] -cpdef inline void _check_driver_error(error) except*: - if error == driver.CUresult.CUDA_SUCCESS: - return +cdef object _DRIVER_SUCCESS = driver.CUresult.CUDA_SUCCESS + + +cpdef inline int _check_driver_error(error) except?-1: + if error == _DRIVER_SUCCESS: + return 0 name_err, name = driver.cuGetErrorName(error) - if name_err != driver.CUresult.CUDA_SUCCESS: + if name_err != _DRIVER_SUCCESS: raise CUDAError(f"UNEXPECTED ERROR CODE: {error}") name = name.decode() expl = DRIVER_CU_RESULT_EXPLANATIONS.get(int(error)) if expl is not None: raise CUDAError(f"{name}: {expl}") desc_err, desc = driver.cuGetErrorString(error) - if desc_err != driver.CUresult.CUDA_SUCCESS: + if desc_err != _DRIVER_SUCCESS: raise CUDAError(f"{name}") desc = desc.decode() raise CUDAError(f"{name}: {desc}") -cpdef inline void _check_runtime_error(error) except*: +cpdef inline int _check_runtime_error(error) except?-1: if error == runtime.cudaError_t.cudaSuccess: - return + return 0 name_err, name = runtime.cudaGetErrorName(error) if name_err != runtime.cudaError_t.cudaSuccess: raise CUDAError(f"UNEXPECTED ERROR CODE: {error}") @@ -86,30 +89,35 @@ cpdef inline void _check_runtime_error(error) except*: raise CUDAError(f"{name}: {desc}") -cdef inline void _check_error(error, handle=None) except*: +cpdef inline int _check_nvrtc_error(error, handle=None) except?-1: + if error == nvrtc.nvrtcResult.NVRTC_SUCCESS: + return 0 + err = f"{error}: {nvrtc.nvrtcGetErrorString(error)[1].decode()}" + if handle is not None: + _, logsize = nvrtc.nvrtcGetProgramLogSize(handle) + log = b" " * logsize + _ = nvrtc.nvrtcGetProgramLog(handle, log) + err += f", compilation log:\n\n{log.decode('utf-8', errors='backslashreplace')}" + raise NVRTCError(err) + + +cdef inline int _check_error(error, handle=None) except?-1: if isinstance(error, driver.CUresult): - _check_driver_error(error) + return _check_driver_error(error) elif isinstance(error, runtime.cudaError_t): - _check_runtime_error(error) + return _check_runtime_error(error) elif isinstance(error, nvrtc.nvrtcResult): - if error == nvrtc.nvrtcResult.NVRTC_SUCCESS: - return - err = f"{error}: {nvrtc.nvrtcGetErrorString(error)[1].decode()}" - if handle is not None: - _, logsize = nvrtc.nvrtcGetProgramLogSize(handle) - log = b" " * logsize - _ = nvrtc.nvrtcGetProgramLog(handle, log) - err += f", compilation log:\n\n{log.decode('utf-8', errors='backslashreplace')}" - raise NVRTCError(err) + return _check_nvrtc_error(error, handle=handle) else: raise RuntimeError(f"Unknown error type: {error}") def handle_return(tuple result, handle=None): _check_error(result[0], handle=handle) - if len(result) == 1: + cdef int out_len = len(result) + if out_len == 1: return - elif len(result) == 2: + elif out_len == 2: return result[1] else: return result[1:] @@ -144,7 +152,7 @@ def _handle_boolean_option(option: bool) -> str: return "true" if bool(option) else "false" -def precondition(checker: Callable[..., None], what: str = "") -> Callable: +def precondition(checker: Callable[..., None], str what="") -> Callable: """ A decorator that adds checks to ensure any preconditions are met. From 4282b999b0b03d2b12a9d9bceadb0ce8ab699cbc Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Wed, 25 Jun 2025 02:58:43 +0000 Subject: [PATCH 05/15] centralize check_or_create_options --- cuda_core/cuda/core/experimental/_event.pyx | 30 ++++--------------- .../core/experimental/_utils/cuda_utils.pxd | 8 +++++ .../core/experimental/_utils/cuda_utils.pyx | 20 ++++++------- 3 files changed, 23 insertions(+), 35 deletions(-) create mode 100644 cuda_core/cuda/core/experimental/_utils/cuda_utils.pxd diff --git a/cuda_core/cuda/core/experimental/_event.pyx b/cuda_core/cuda/core/experimental/_event.pyx index 31c6756d1..5e970c1aa 100644 --- a/cuda_core/cuda/core/experimental/_event.pyx +++ b/cuda_core/cuda/core/experimental/_event.pyx @@ -4,6 +4,11 @@ from __future__ import annotations +from cuda.core.experimental._utils.cuda_utils cimport ( + _check_driver_error as raise_if_driver_error, + check_or_create_options, +) + from dataclasses import dataclass from typing import TYPE_CHECKING, Optional @@ -13,9 +18,6 @@ from cuda.core.experimental._utils.cuda_utils import ( driver, handle_return, ) -from cuda.core.experimental._utils.cuda_utils import ( - _check_driver_error as raise_if_driver_error, -) if TYPE_CHECKING: import cuda.bindings @@ -47,26 +49,6 @@ cdef class EventOptions: support_ipc: Optional[bool] = False -cdef inline EventOptions check_or_create_options(options, str options_description): - """ - Create the specified options dataclass from a dictionary of options or None. - """ - cdef EventOptions opts - if options is None: - opts = EventOptions() - elif isinstance(options, dict): - opts = EventOptions(**options) - elif not isinstance(options, EventOptions): - raise TypeError( - f"The {options_description} must be provided as an object " - f"of type {EventOptions.__name__} or as a dict with valid {options_description}. " - f"The provided object is '{options}'." - ) - - return opts - - - cdef class Event: """Represent a record at a specific point of execution within a CUDA stream. @@ -108,7 +90,7 @@ cdef class Event: @classmethod def _init(cls, device_id: int, ctx_handle: Context, opts=None): cdef Event self = Event.__new__(Event) - cdef EventOptions options = check_or_create_options(opts, "Event options") + cdef EventOptions options = check_or_create_options(EventOptions, opts, "Event options") flags = 0x0 self._timing_disabled = False self._busy_waited = False diff --git a/cuda_core/cuda/core/experimental/_utils/cuda_utils.pxd b/cuda_core/cuda/core/experimental/_utils/cuda_utils.pxd new file mode 100644 index 000000000..1dc6bd1eb --- /dev/null +++ b/cuda_core/cuda/core/experimental/_utils/cuda_utils.pxd @@ -0,0 +1,8 @@ +# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: Apache-2.0 + +cpdef int _check_driver_error(error) except?-1 +cpdef int _check_runtime_error(error) except?-1 +cpdef int _check_nvrtc_error(error) except?-1 +cpdef check_or_create_options(type cls, options, str options_description=*, bint keep_none=*) diff --git a/cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx b/cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx index 61ca4da9b..3dc5601b5 100644 --- a/cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx +++ b/cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx @@ -1,12 +1,12 @@ # Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. # -# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE +# SPDX-License-Identifier: Apache-2.0 import functools import importlib.metadata from collections import namedtuple from collections.abc import Sequence -from typing import Callable, Dict +from typing import Callable try: from cuda.bindings import driver, nvrtc, runtime @@ -123,27 +123,25 @@ def handle_return(tuple result, handle=None): return result[1:] -def check_or_create_options(cls, options, options_description, *, keep_none=False): +cpdef check_or_create_options(type cls, options, str options_description="", bint keep_none=False): """ Create the specified options dataclass from a dictionary of options or None. """ - if options is None: if keep_none: return options - options = cls() - elif isinstance(options, Dict): - options = cls(**options) - - if not isinstance(options, cls): + return cls() + elif isinstance(options, cls): + return options + elif isinstance(options, dict): + return cls(**options) + else: raise TypeError( f"The {options_description} must be provided as an object " f"of type {cls.__name__} or as a dict with valid {options_description}. " f"The provided object is '{options}'." ) - return options - def _handle_boolean_option(option: bool) -> str: """ From 7b459540a8d343de10bee7d05b7f0614d481b17f Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Wed, 25 Jun 2025 03:02:16 +0000 Subject: [PATCH 06/15] add back error check --- cuda_core/cuda/core/experimental/_event.pyx | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_event.pyx b/cuda_core/cuda/core/experimental/_event.pyx index 5e970c1aa..83e5e430e 100644 --- a/cuda_core/cuda/core/experimental/_event.pyx +++ b/cuda_core/cuda/core/experimental/_event.pyx @@ -102,7 +102,8 @@ cdef class Event: self._busy_waited = True if options.support_ipc: raise NotImplementedError("WIP: https://github.com/NVIDIA/cuda-python/issues/103") - _, self._handle = driver.cuEventCreate(flags) + err, self._handle = driver.cuEventCreate(flags) + raise_if_driver_error(err) self._device_id = device_id self._ctx_handle = ctx_handle return self @@ -110,7 +111,8 @@ cdef class Event: cpdef close(self): """Destroy the event.""" if self._handle is not None: - _ = driver.cuEventDestroy(self._handle) + err, = driver.cuEventDestroy(self._handle) + raise_if_driver_error(err) self._handle = None def __del__(self): @@ -180,7 +182,7 @@ cdef class Event: @property def is_done(self) -> bool: """Return True if all captured works have been completed, otherwise False.""" - (result,) = driver.cuEventQuery(self._handle) + result, = driver.cuEventQuery(self._handle) if result == driver.CUresult.CUDA_SUCCESS: return True if result == driver.CUresult.CUDA_ERROR_NOT_READY: From 3ff5e94c3854f85e1eb74e570ab2249a2b6df996 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Wed, 25 Jun 2025 04:10:46 +0000 Subject: [PATCH 07/15] cythonize stream + bug fixes --- cuda_core/cuda/core/experimental/_context.pyx | 2 +- cuda_core/cuda/core/experimental/_device.py | 2 +- cuda_core/cuda/core/experimental/_event.pyx | 10 +- .../experimental/{_stream.py => _stream.pyx} | 129 +++++++++--------- cuda_core/tests/test_cuda_utils.py | 4 +- cuda_core/tests/test_stream.py | 2 +- 6 files changed, 77 insertions(+), 72 deletions(-) rename cuda_core/cuda/core/experimental/{_stream.py => _stream.pyx} (80%) diff --git a/cuda_core/cuda/core/experimental/_context.pyx b/cuda_core/cuda/core/experimental/_context.pyx index 205f6c983..d6abf65c1 100644 --- a/cuda_core/cuda/core/experimental/_context.pyx +++ b/cuda_core/cuda/core/experimental/_context.pyx @@ -16,7 +16,7 @@ class ContextOptions: cdef class Context: cdef: - object _handle + readonly object _handle int _device_id def __init__(self, *args, **kwargs): diff --git a/cuda_core/cuda/core/experimental/_device.py b/cuda_core/cuda/core/experimental/_device.py index 99aaefdbf..b5430eedb 100644 --- a/cuda_core/cuda/core/experimental/_device.py +++ b/cuda_core/cuda/core/experimental/_device.py @@ -1251,7 +1251,7 @@ def create_stream(self, obj: Optional[IsStreamT] = None, options: StreamOptions """ self._check_context_initialized() - return Stream._init(obj=obj, options=options) + return Stream._init(obj=obj, options=options, device_id=self._id) def create_event(self, options: Optional[EventOptions] = None) -> Event: """Create an Event object without recording it to a Stream. diff --git a/cuda_core/cuda/core/experimental/_event.pyx b/cuda_core/cuda/core/experimental/_event.pyx index 83e5e430e..d3907d389 100644 --- a/cuda_core/cuda/core/experimental/_event.pyx +++ b/cuda_core/cuda/core/experimental/_event.pyx @@ -88,19 +88,19 @@ cdef class Event: raise RuntimeError("Event objects cannot be instantiated directly. Please use Stream APIs (record).") @classmethod - def _init(cls, device_id: int, ctx_handle: Context, opts=None): + def _init(cls, device_id: int, ctx_handle: Context, options: Optional[EventOptions] = None): cdef Event self = Event.__new__(Event) - cdef EventOptions options = check_or_create_options(EventOptions, opts, "Event options") + cdef EventOptions opts = check_or_create_options(EventOptions, options, "Event options") flags = 0x0 self._timing_disabled = False self._busy_waited = False - if not options.enable_timing: + if not opts.enable_timing: flags |= driver.CUevent_flags.CU_EVENT_DISABLE_TIMING self._timing_disabled = True - if options.busy_waited_sync: + if opts.busy_waited_sync: flags |= driver.CUevent_flags.CU_EVENT_BLOCKING_SYNC self._busy_waited = True - if options.support_ipc: + if opts.support_ipc: raise NotImplementedError("WIP: https://github.com/NVIDIA/cuda-python/issues/103") err, self._handle = driver.cuEventCreate(flags) raise_if_driver_error(err) diff --git a/cuda_core/cuda/core/experimental/_stream.py b/cuda_core/cuda/core/experimental/_stream.pyx similarity index 80% rename from cuda_core/cuda/core/experimental/_stream.py rename to cuda_core/cuda/core/experimental/_stream.pyx index ea488f9fc..165ea9908 100644 --- a/cuda_core/cuda/core/experimental/_stream.py +++ b/cuda_core/cuda/core/experimental/_stream.pyx @@ -4,9 +4,13 @@ from __future__ import annotations +from cuda.core.experimental._utils.cuda_utils cimport ( + _check_driver_error as raise_if_driver_error, + check_or_create_options, +) + import os import warnings -import weakref from dataclasses import dataclass from typing import TYPE_CHECKING, Optional, Protocol, Tuple, Union @@ -18,16 +22,14 @@ from cuda.core.experimental._graph import GraphBuilder from cuda.core.experimental._utils.clear_error_support import assert_type from cuda.core.experimental._utils.cuda_utils import ( - check_or_create_options, driver, get_device_from_ctx, handle_return, - runtime, ) @dataclass -class StreamOptions: +cdef class StreamOptions: """Customizable :obj:`~_stream.Stream` options. Attributes @@ -85,7 +87,7 @@ def _try_to_get_stream_ptr(obj: IsStreamT): return driver.CUstream(info[1]) -class Stream: +cdef class Stream: """Represent a queue of GPU operations that are executed in a specific order. Applications use streams to control the order of execution for @@ -103,35 +105,27 @@ class Stream: """ - class _MembersNeededForFinalize: - __slots__ = ("handle", "owner", "builtin") - - def __init__(self, stream_obj, handle, owner, builtin): - self.handle = handle - self.owner = owner - self.builtin = builtin - weakref.finalize(stream_obj, self.close) + cdef: + object _handle + object _owner + object _builtin + object _nonblocking + object _priority + object _device_id + object _ctx_handle - def close(self): - if self.owner is None: - if self.handle and not self.builtin: - handle_return(driver.cuStreamDestroy(self.handle)) - else: - self.owner = None - self.handle = None - - def __new__(self, *args, **kwargs): + def __init__(self, *args, **kwargs): raise RuntimeError( "Stream objects cannot be instantiated directly. " "Please use Device APIs (create_stream) or other Stream APIs (from_handle)." ) - __slots__ = ("__weakref__", "_mnff", "_nonblocking", "_priority", "_device_id", "_ctx_handle") - @classmethod def _legacy_default(cls): - self = super().__new__(cls) - self._mnff = Stream._MembersNeededForFinalize(self, driver.CUstream(driver.CU_STREAM_LEGACY), None, True) + cdef Stream self = Stream.__new__(Stream) + self._handle = driver.CUstream(driver.CU_STREAM_LEGACY) + self._owner = None + self._builtin = True self._nonblocking = None # delayed self._priority = None # delayed self._device_id = None # delayed @@ -140,8 +134,10 @@ def _legacy_default(cls): @classmethod def _per_thread_default(cls): - self = super().__new__(cls) - self._mnff = Stream._MembersNeededForFinalize(self, driver.CUstream(driver.CU_STREAM_PER_THREAD), None, True) + cdef Stream self = Stream.__new__(Stream) + self._handle = driver.CUstream(driver.CU_STREAM_PER_THREAD) + self._owner = None + self._builtin = True self._nonblocking = None # delayed self._priority = None # delayed self._device_id = None # delayed @@ -149,57 +145,65 @@ def _per_thread_default(cls): return self @classmethod - def _init(cls, obj: Optional[IsStreamT] = None, *, options: Optional[StreamOptions] = None): - self = super().__new__(cls) - self._mnff = Stream._MembersNeededForFinalize(self, None, None, False) + def _init(cls, obj: Optional[IsStreamT] = None, *, options: Optional[StreamOptions] = None, device_id: int = None): + cdef Stream self = Stream.__new__(Stream) + self._handle = None + self._owner = None + self._builtin = False if obj is not None and options is not None: raise ValueError("obj and options cannot be both specified") if obj is not None: - self._mnff.handle = _try_to_get_stream_ptr(obj) + self._handle = _try_to_get_stream_ptr(obj) # TODO: check if obj is created under the current context/device - self._mnff.owner = obj + self._owner = obj self._nonblocking = None # delayed self._priority = None # delayed self._device_id = None # delayed self._ctx_handle = None # delayed return self - options = check_or_create_options(StreamOptions, options, "Stream options") - nonblocking = options.nonblocking - priority = options.priority + cdef StreamOptions opts = check_or_create_options(StreamOptions, options, "Stream options") + nonblocking = opts.nonblocking + priority = opts.priority flags = driver.CUstream_flags.CU_STREAM_NON_BLOCKING if nonblocking else driver.CUstream_flags.CU_STREAM_DEFAULT - - high, low = handle_return(runtime.cudaDeviceGetStreamPriorityRange()) + err, high, low = driver.cuCtxGetStreamPriorityRange() + raise_if_driver_error(err) if priority is not None: if not (low <= priority <= high): raise ValueError(f"{priority=} is out of range {[low, high]}") else: priority = high - self._mnff.handle = handle_return(driver.cuStreamCreateWithPriority(flags, priority)) - self._mnff.owner = None + self._handle = handle_return(driver.cuStreamCreateWithPriority(flags, priority)) + self._owner = None self._nonblocking = nonblocking self._priority = priority - # don't defer this because we will have to pay a cost for context - # switch later - self._device_id = int(handle_return(driver.cuCtxGetDevice())) + self._device_id = device_id self._ctx_handle = None # delayed return self - def close(self): + def __del__(self): + self.close() + + cpdef close(self): """Destroy the stream. Destroy the stream if we own it. Borrowed foreign stream object will instead have their references released. """ - self._mnff.close() + if self._owner is None: + if self._handle and not self._builtin: + handle_return(driver.cuStreamDestroy(self._handle)) + else: + self._owner = None + self._handle = None def __cuda_stream__(self) -> Tuple[int, int]: """Return an instance of a __cuda_stream__ protocol.""" - return (0, self.handle) + return (0, int(self.handle)) @property def handle(self) -> cuda.bindings.driver.CUstream: @@ -210,13 +214,13 @@ def handle(self) -> cuda.bindings.driver.CUstream: This handle is a Python object. To get the memory address of the underlying C handle, call ``int(Stream.handle)``. """ - return self._mnff.handle + return self._handle @property def is_nonblocking(self) -> bool: """Return True if this is a nonblocking stream, otherwise False.""" if self._nonblocking is None: - flag = handle_return(driver.cuStreamGetFlags(self._mnff.handle)) + flag = handle_return(driver.cuStreamGetFlags(self._handle)) if flag == driver.CUstream_flags.CU_STREAM_NON_BLOCKING: self._nonblocking = True else: @@ -227,13 +231,13 @@ def is_nonblocking(self) -> bool: def priority(self) -> int: """Return the stream priority.""" if self._priority is None: - prio = handle_return(driver.cuStreamGetPriority(self._mnff.handle)) + prio = handle_return(driver.cuStreamGetPriority(self._handle)) self._priority = prio return self._priority def sync(self): """Synchronize the stream.""" - handle_return(driver.cuStreamSynchronize(self._mnff.handle)) + handle_return(driver.cuStreamSynchronize(self._handle)) def record(self, event: Event = None, options: EventOptions = None) -> Event: """Record an event onto the stream. @@ -259,8 +263,8 @@ def record(self, event: Event = None, options: EventOptions = None) -> Event: # and CU_EVENT_RECORD_EXTERNAL, can be set in EventOptions. if event is None: event = Event._init(self._device_id, self._ctx_handle, options) - assert_type(event, Event) - handle_return(driver.cuEventRecord(event.handle, self._mnff.handle)) + err, = driver.cuEventRecord(event.handle, self._handle) + raise_if_driver_error(err) return event def wait(self, event_or_stream: Union[Event, Stream]): @@ -281,7 +285,7 @@ def wait(self, event_or_stream: Union[Event, Stream]): stream = event_or_stream else: try: - stream = Stream._init(event_or_stream) + stream = Stream._init(obj=event_or_stream) except Exception as e: raise ValueError( "only an Event, Stream, or object supporting __cuda_stream__ can be waited," @@ -292,7 +296,7 @@ def wait(self, event_or_stream: Union[Event, Stream]): discard_event = True # TODO: support flags other than 0? - handle_return(driver.cuStreamWaitEvent(self._mnff.handle, event, 0)) + handle_return(driver.cuStreamWaitEvent(self._handle, event, 0)) if discard_event: handle_return(driver.cuEventDestroy(event)) @@ -308,21 +312,22 @@ def device(self) -> Device: """ from cuda.core.experimental._device import Device # avoid circular import + self._get_device_and_context() + return Device(self._device_id) + cdef _get_device_and_context(self): + # Get the stream context first + if self._ctx_handle is None: + err, self._ctx_handle = driver.cuStreamGetCtx(self._handle) + raise_if_driver_error(err) if self._device_id is None: - # Get the stream context first - if self._ctx_handle is None: - self._ctx_handle = handle_return(driver.cuStreamGetCtx(self._mnff.handle)) self._device_id = get_device_from_ctx(self._ctx_handle) - return Device(self._device_id) + raise_if_driver_error(err) @property def context(self) -> Context: """Return the :obj:`~_context.Context` associated with this stream.""" - if self._ctx_handle is None: - self._ctx_handle = handle_return(driver.cuStreamGetCtx(self._mnff.handle)) - if self._device_id is None: - self._device_id = get_device_from_ctx(self._ctx_handle) + self._get_device_and_context() return Context._from_ctx(self._ctx_handle, self._device_id) @staticmethod diff --git a/cuda_core/tests/test_cuda_utils.py b/cuda_core/tests/test_cuda_utils.py index 5f94e545f..77d9e457a 100644 --- a/cuda_core/tests/test_cuda_utils.py +++ b/cuda_core/tests/test_cuda_utils.py @@ -44,7 +44,7 @@ def test_check_driver_error(): num_unexpected = 0 for error in driver.CUresult: if error == driver.CUresult.CUDA_SUCCESS: - assert cuda_utils._check_driver_error(error) is None + assert cuda_utils._check_driver_error(error) == 0 else: with pytest.raises(cuda_utils.CUDAError) as e: cuda_utils._check_driver_error(error) @@ -63,7 +63,7 @@ def test_check_runtime_error(): num_unexpected = 0 for error in runtime.cudaError_t: if error == runtime.cudaError_t.cudaSuccess: - assert cuda_utils._check_runtime_error(error) is None + assert cuda_utils._check_runtime_error(error) == 0 else: with pytest.raises(cuda_utils.CUDAError) as e: cuda_utils._check_runtime_error(error) diff --git a/cuda_core/tests/test_stream.py b/cuda_core/tests/test_stream.py index a73655f1a..daa8dac6b 100644 --- a/cuda_core/tests/test_stream.py +++ b/cuda_core/tests/test_stream.py @@ -52,7 +52,7 @@ def test_stream_record(init_cuda): def test_stream_record_invalid_event(init_cuda): stream = Device().create_stream(options=StreamOptions()) - with pytest.raises(TypeError): + with pytest.raises(AttributeError): stream.record(event="invalid_event") From 227d9c1c38c5c38a0ab32f70a009e5e2f80d34cb Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Wed, 25 Jun 2025 04:13:20 +0000 Subject: [PATCH 08/15] make linter happy --- cuda_core/cuda/core/experimental/_device.py | 2 +- cuda_core/setup.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_device.py b/cuda_core/cuda/core/experimental/_device.py index b5430eedb..f0c300c2a 100644 --- a/cuda_core/cuda/core/experimental/_device.py +++ b/cuda_core/cuda/core/experimental/_device.py @@ -1026,7 +1026,7 @@ def _get_current_context(self, check_consistency=False) -> driver.CUcontext: err, ctx = driver.cuCtxGetCurrent() # TODO: We want to just call this: - #_check_driver_error(err) + # _check_driver_error(err) # but even the simplest success check causes 50-100 ns. Wait until we cythonize this file... if ctx is None: _check_driver_error(err) diff --git a/cuda_core/setup.py b/cuda_core/setup.py index f2b84bfaf..26b9a5ab0 100644 --- a/cuda_core/setup.py +++ b/cuda_core/setup.py @@ -9,21 +9,25 @@ from setuptools import Extension, setup from setuptools.command.build_ext import build_ext as _build_ext - # It seems setuptools' wildcard support has problems for namespace packages, # so we explicitly spell out all Extension instances. root_module = "cuda.core.experimental" root_path = f"{os.path.sep}".join(root_module.split(".")) + os.path.sep ext_files = glob.glob(f"{root_path}/**/*.pyx", recursive=True) + + def strip_prefix_suffix(filename): - return filename[len(root_path):-4] + return filename[len(root_path) : -4] + + module_names = (strip_prefix_suffix(f) for f in ext_files) ext_modules = tuple( Extension( f"cuda.core.experimental.{mod.replace(os.path.sep, '.')}", sources=[f"cuda/core/experimental/{mod}.pyx"], language="c++", - ) for mod in module_names + ) + for mod in module_names ) From e96bb4a5770181ac1d6ffe80296c1c0b6cef538c Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Wed, 25 Jun 2025 04:31:24 +0000 Subject: [PATCH 09/15] bug fix --- cuda_core/cuda/core/experimental/_stream.pyx | 1 + 1 file changed, 1 insertion(+) diff --git a/cuda_core/cuda/core/experimental/_stream.pyx b/cuda_core/cuda/core/experimental/_stream.pyx index 165ea9908..19bd8cc1f 100644 --- a/cuda_core/cuda/core/experimental/_stream.pyx +++ b/cuda_core/cuda/core/experimental/_stream.pyx @@ -262,6 +262,7 @@ cdef class Stream: # on the stream. Event flags such as disabling timing, nonblocking, # and CU_EVENT_RECORD_EXTERNAL, can be set in EventOptions. if event is None: + self._get_device_and_context() event = Event._init(self._device_id, self._ctx_handle, options) err, = driver.cuEventRecord(event.handle, self._handle) raise_if_driver_error(err) From 48de1b3518733adffec10502b385721eb97681d8 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Wed, 25 Jun 2025 18:50:12 +0000 Subject: [PATCH 10/15] Cython mis-compiles Optional types --- cuda_core/cuda/core/experimental/_device.py | 2 +- cuda_core/cuda/core/experimental/_event.pyx | 2 +- cuda_core/cuda/core/experimental/_stream.pyx | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_device.py b/cuda_core/cuda/core/experimental/_device.py index f0c300c2a..1a71d998f 100644 --- a/cuda_core/cuda/core/experimental/_device.py +++ b/cuda_core/cuda/core/experimental/_device.py @@ -1222,7 +1222,7 @@ def create_context(self, options: ContextOptions = None) -> Context: """ raise NotImplementedError("WIP: https://github.com/NVIDIA/cuda-python/issues/189") - def create_stream(self, obj: Optional[IsStreamT] = None, options: StreamOptions = None) -> Stream: + def create_stream(self, obj: Optional[IsStreamT] = None, options: Optional[StreamOptions] = None) -> Stream: """Create a Stream object. New stream objects can be created in two different ways: diff --git a/cuda_core/cuda/core/experimental/_event.pyx b/cuda_core/cuda/core/experimental/_event.pyx index d3907d389..74ac2bb89 100644 --- a/cuda_core/cuda/core/experimental/_event.pyx +++ b/cuda_core/cuda/core/experimental/_event.pyx @@ -88,7 +88,7 @@ cdef class Event: raise RuntimeError("Event objects cannot be instantiated directly. Please use Stream APIs (record).") @classmethod - def _init(cls, device_id: int, ctx_handle: Context, options: Optional[EventOptions] = None): + def _init(cls, device_id: int, ctx_handle: Context, options=None): cdef Event self = Event.__new__(Event) cdef EventOptions opts = check_or_create_options(EventOptions, options, "Event options") flags = 0x0 diff --git a/cuda_core/cuda/core/experimental/_stream.pyx b/cuda_core/cuda/core/experimental/_stream.pyx index 19bd8cc1f..277f11c05 100644 --- a/cuda_core/cuda/core/experimental/_stream.pyx +++ b/cuda_core/cuda/core/experimental/_stream.pyx @@ -145,7 +145,7 @@ cdef class Stream: return self @classmethod - def _init(cls, obj: Optional[IsStreamT] = None, *, options: Optional[StreamOptions] = None, device_id: int = None): + def _init(cls, obj: Optional[IsStreamT] = None, options=None, device_id: int = None): cdef Stream self = Stream.__new__(Stream) self._handle = None self._owner = None From b305ea22fdb3d2872d61cb136f5bf0fc6054fccb Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com> Date: Mon, 30 Jun 2025 12:10:27 -0500 Subject: [PATCH 11/15] Modified _get_context_device() helper routine To fix test failures with CTK 11.8 and driver 535.247.01 only attempt to query _ctx_handle if _device_id is None. Ensure that context handle is set in Stream.context property --- cuda_core/cuda/core/experimental/_stream.pyx | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_stream.pyx b/cuda_core/cuda/core/experimental/_stream.pyx index 277f11c05..168d968ec 100644 --- a/cuda_core/cuda/core/experimental/_stream.pyx +++ b/cuda_core/cuda/core/experimental/_stream.pyx @@ -316,18 +316,21 @@ cdef class Stream: self._get_device_and_context() return Device(self._device_id) - cdef _get_device_and_context(self): - # Get the stream context first + cdef void _get_context(Stream self) except *: if self._ctx_handle is None: err, self._ctx_handle = driver.cuStreamGetCtx(self._handle) raise_if_driver_error(err) + + cdef void _get_device_and_context(Stream self) except *: if self._device_id is None: + # Get the stream context first + self._get_context() self._device_id = get_device_from_ctx(self._ctx_handle) - raise_if_driver_error(err) @property def context(self) -> Context: """Return the :obj:`~_context.Context` associated with this stream.""" + self._get_context() self._get_device_and_context() return Context._from_ctx(self._ctx_handle, self._device_id) From 30de7203b98f5e3f1e67f37ea8182eca08df7b9c Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com> Date: Mon, 30 Jun 2025 13:04:37 -0500 Subject: [PATCH 12/15] Verify that stream.context handle is not None in test_stream_context --- cuda_core/tests/test_stream.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cuda_core/tests/test_stream.py b/cuda_core/tests/test_stream.py index daa8dac6b..7a3ff8b2c 100644 --- a/cuda_core/tests/test_stream.py +++ b/cuda_core/tests/test_stream.py @@ -80,6 +80,7 @@ def test_stream_context(init_cuda): stream = Device().create_stream(options=StreamOptions()) context = stream.context assert context is not None + assert context._handle is not None def test_stream_from_foreign_stream(init_cuda): From cc6339ee85db759005bfc96583a18006ef095765 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Mon, 30 Jun 2025 22:11:27 +0000 Subject: [PATCH 13/15] cache success enums --- cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx b/cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx index 3dc5601b5..f14addd43 100644 --- a/cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx +++ b/cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx @@ -53,6 +53,8 @@ def _reduce_3_tuple(t: tuple): cdef object _DRIVER_SUCCESS = driver.CUresult.CUDA_SUCCESS +cdef object _RUNTIME_SUCCESS = runtime.cudaError_t.cudaSuccess +cdef object _NVRTC_SUCCESS = nvrtc.nvrtcResult.NVRTC_SUCCESS cpdef inline int _check_driver_error(error) except?-1: @@ -73,24 +75,24 @@ cpdef inline int _check_driver_error(error) except?-1: cpdef inline int _check_runtime_error(error) except?-1: - if error == runtime.cudaError_t.cudaSuccess: + if error == _RUNTIME_SUCCESS: return 0 name_err, name = runtime.cudaGetErrorName(error) - if name_err != runtime.cudaError_t.cudaSuccess: + if name_err != _RUNTIME_SUCCESS: raise CUDAError(f"UNEXPECTED ERROR CODE: {error}") name = name.decode() expl = RUNTIME_CUDA_ERROR_EXPLANATIONS.get(int(error)) if expl is not None: raise CUDAError(f"{name}: {expl}") desc_err, desc = runtime.cudaGetErrorString(error) - if desc_err != runtime.cudaError_t.cudaSuccess: + if desc_err != _RUNTIME_SUCCESS: raise CUDAError(f"{name}") desc = desc.decode() raise CUDAError(f"{name}: {desc}") cpdef inline int _check_nvrtc_error(error, handle=None) except?-1: - if error == nvrtc.nvrtcResult.NVRTC_SUCCESS: + if error == _NVRTC_SUCCESS: return 0 err = f"{error}: {nvrtc.nvrtcGetErrorString(error)[1].decode()}" if handle is not None: From e95c4b1886b32d708852c5d9e2e56f714ac56b32 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Mon, 30 Jun 2025 22:45:43 +0000 Subject: [PATCH 14/15] nit: avoid cdef void --- cuda_core/cuda/core/experimental/_stream.pyx | 6 ++++-- cuda_core/cuda/core/experimental/_utils/__init__.pxd | 0 2 files changed, 4 insertions(+), 2 deletions(-) create mode 100644 cuda_core/cuda/core/experimental/_utils/__init__.pxd diff --git a/cuda_core/cuda/core/experimental/_stream.pyx b/cuda_core/cuda/core/experimental/_stream.pyx index 168d968ec..dc8d8e942 100644 --- a/cuda_core/cuda/core/experimental/_stream.pyx +++ b/cuda_core/cuda/core/experimental/_stream.pyx @@ -316,16 +316,18 @@ cdef class Stream: self._get_device_and_context() return Device(self._device_id) - cdef void _get_context(Stream self) except *: + cdef int _get_context(Stream self) except?-1: if self._ctx_handle is None: err, self._ctx_handle = driver.cuStreamGetCtx(self._handle) raise_if_driver_error(err) + return 0 - cdef void _get_device_and_context(Stream self) except *: + cdef int _get_device_and_context(Stream self) except?-1: if self._device_id is None: # Get the stream context first self._get_context() self._device_id = get_device_from_ctx(self._ctx_handle) + return 0 @property def context(self) -> Context: diff --git a/cuda_core/cuda/core/experimental/_utils/__init__.pxd b/cuda_core/cuda/core/experimental/_utils/__init__.pxd new file mode 100644 index 000000000..e69de29bb From f4531e5a141d29e5d2181e3cda620a833f6012c9 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com> Date: Tue, 1 Jul 2025 08:46:32 -0500 Subject: [PATCH 15/15] In Event.close set handle to None before raising error --- cuda_core/cuda/core/experimental/_event.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuda_core/cuda/core/experimental/_event.pyx b/cuda_core/cuda/core/experimental/_event.pyx index 74ac2bb89..944969852 100644 --- a/cuda_core/cuda/core/experimental/_event.pyx +++ b/cuda_core/cuda/core/experimental/_event.pyx @@ -112,8 +112,8 @@ cdef class Event: """Destroy the event.""" if self._handle is not None: err, = driver.cuEventDestroy(self._handle) - raise_if_driver_error(err) self._handle = None + raise_if_driver_error(err) def __del__(self): self.close()