From 07e4eb5f28f1778b0ea612f922c3aad61a7a8fc7 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Tue, 3 Jun 2025 05:25:14 +0000 Subject: [PATCH 1/5] support cooperative launch --- cuda_core/cuda/core/experimental/_device.py | 11 +++++ .../cuda/core/experimental/_launch_config.py | 11 +++++ cuda_core/cuda/core/experimental/_launcher.py | 17 +++++++ .../core/experimental/_utils/cuda_utils.py | 4 ++ cuda_core/docs/source/release/0.3.0-notes.rst | 3 +- cuda_core/tests/test_launcher.py | 45 +++++++++++++++++++ 6 files changed, 90 insertions(+), 1 deletion(-) diff --git a/cuda_core/cuda/core/experimental/_device.py b/cuda_core/cuda/core/experimental/_device.py index 55afb0eba..ab31b014e 100644 --- a/cuda_core/cuda/core/experimental/_device.py +++ b/cuda_core/cuda/core/experimental/_device.py @@ -701,6 +701,17 @@ def can_use_host_pointer_for_registered_mem(self) -> bool: ) ) + # TODO: A few attrs are missing here (NVIDIA/cuda-python#675) + + @property + def cooperative_launch(self) -> bool: + """ + True if device supports launching cooperative kernels, False if not. 
+ """ + return bool(self._get_cached_attribute(driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH)) + + # TODO: A few attrs are missing here (NVIDIA/cuda-python#675) + @property def max_shared_memory_per_block_optin(self) -> int: """ diff --git a/cuda_core/cuda/core/experimental/_launch_config.py b/cuda_core/cuda/core/experimental/_launch_config.py index bb4e92fb3..43b7c2109 100644 --- a/cuda_core/cuda/core/experimental/_launch_config.py +++ b/cuda_core/cuda/core/experimental/_launch_config.py @@ -58,11 +58,15 @@ class LaunchConfig: cluster: Union[tuple, int] = None block: Union[tuple, int] = None shmem_size: Optional[int] = None + cooperative_launch: Optional[bool] = False def __post_init__(self): _lazy_init() self.grid = cast_to_3_tuple("LaunchConfig.grid", self.grid) self.block = cast_to_3_tuple("LaunchConfig.block", self.block) + # FIXME: Calling Device() strictly speaking is not quite right; we should instead + # look up the device from stream. We probably need to defer the checks related to + # device compute capability or attributes. 
# thread block clusters are supported starting H100 if self.cluster is not None: if not _use_ex: @@ -77,6 +81,8 @@ def __post_init__(self): self.cluster = cast_to_3_tuple("LaunchConfig.cluster", self.cluster) if self.shmem_size is None: self.shmem_size = 0 + if self.cooperative_launch and not Device().properties.cooperative_launch: + raise CUDAError("cooperative kernels are not supported on this device") def _to_native_launch_config(config: LaunchConfig) -> driver.CUlaunchConfig: @@ -92,6 +98,11 @@ def _to_native_launch_config(config: LaunchConfig) -> driver.CUlaunchConfig: dim = attr.value.clusterDim dim.x, dim.y, dim.z = config.cluster attrs.append(attr) + if config.cooperative_launch: + attr = driver.CUlaunchAttribute() + attr.id = driver.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_COOPERATIVE + attr.value.cooperative = 1 + attrs.append(attr) drv_cfg.numAttrs = len(attrs) drv_cfg.attrs = attrs return drv_cfg diff --git a/cuda_core/cuda/core/experimental/_launcher.py b/cuda_core/cuda/core/experimental/_launcher.py index 72afb5ffd..894d661c2 100644 --- a/cuda_core/cuda/core/experimental/_launcher.py +++ b/cuda_core/cuda/core/experimental/_launcher.py @@ -9,6 +9,7 @@ from cuda.core.experimental._stream import Stream from cuda.core.experimental._utils.clear_error_support import assert_type from cuda.core.experimental._utils.cuda_utils import ( + _reduce_tuple, check_or_create_options, driver, get_binding_version, @@ -78,6 +79,8 @@ def launch(stream, config, kernel, *kernel_args): if _use_ex: drv_cfg = _to_native_launch_config(config) drv_cfg.hStream = stream.handle + if config.cooperative_launch: + _check_cooperative_launch(kernel, config, stream) handle_return(driver.cuLaunchKernelEx(drv_cfg, int(kernel._handle), args_ptr, 0)) else: # TODO: check if config has any unsupported attrs @@ -86,3 +89,17 @@ def launch(stream, config, kernel, *kernel_args): int(kernel._handle), *config.grid, *config.block, config.shmem_size, stream.handle, args_ptr, 0 ) ) + + +def 
_check_cooperative_launch(kernel: Kernel, config: LaunchConfig, stream: Stream): + dev = stream.device + num_sm = dev.properties.multiprocessor_count + max_grid_size = ( + kernel.occupancy.max_active_blocks_per_multiprocessor(_reduce_tuple(config.block), config.shmem_size) * num_sm + ) + if _reduce_tuple(config.grid) > max_grid_size: + # For now let's try not to be smart and adjust the grid size behind users' back. + # We explicitly ask users to adjust. + raise ValueError( + "The specified grid size ({} * {} * {}) exceeds the limit ({}).".format(*config.grid, max_grid_size) + ) diff --git a/cuda_core/cuda/core/experimental/_utils/cuda_utils.py b/cuda_core/cuda/core/experimental/_utils/cuda_utils.py index 7cf9be31d..0c8635b6e 100644 --- a/cuda_core/cuda/core/experimental/_utils/cuda_utils.py +++ b/cuda_core/cuda/core/experimental/_utils/cuda_utils.py @@ -48,6 +48,10 @@ def cast_to_3_tuple(label, cfg): return cfg + (1,) * (3 - len(cfg)) +def _reduce_tuple(t: tuple): + return functools.reduce(lambda x, y: x * y, t, 1) + + def _check_driver_error(error): if error == driver.CUresult.CUDA_SUCCESS: return diff --git a/cuda_core/docs/source/release/0.3.0-notes.rst b/cuda_core/docs/source/release/0.3.0-notes.rst index 88a028b51..2d32a3890 100644 --- a/cuda_core/docs/source/release/0.3.0-notes.rst +++ b/cuda_core/docs/source/release/0.3.0-notes.rst @@ -22,6 +22,7 @@ New features - :class:`Kernel` adds :property:`Kernel.num_arguments` and :property:`Kernel.arguments_info` for introspection of kernel arguments. (#612) - Add pythonic access to kernel occupancy calculation functions via :property:`Kernel.occupancy`. (#648) +- Support launching cooperative kernels by setting :property:`LaunchConfig.cooperative_launch` to `True`. New examples ------------ @@ -31,4 +32,4 @@ Fixes and enhancements ---------------------- - An :class:`Event` can now be used to look up its corresponding device and context using the ``.device`` and ``.context`` attributes respectively. 
-- The :func:`launch` function's handling of fp16 scalars was incorrect and is fixed +- The :func:`launch` function's handling of fp16 scalars was incorrect and is fixed. diff --git a/cuda_core/tests/test_launcher.py b/cuda_core/tests/test_launcher.py index 9c72693a1..260b18a43 100644 --- a/cuda_core/tests/test_launcher.py +++ b/cuda_core/tests/test_launcher.py @@ -152,3 +152,48 @@ def test_launch_scalar_argument(python_type, cpp_type, init_value): # Check result assert arr[0] == init_value, f"Expected {init_value}, got {arr[0]}" + + +@pytest.mark.skipif(os.environ.get("CUDA_PATH") is None, reason="need cg header") +def test_cooperative_launch(): + dev = Device() + dev.set_current() + s = dev.create_stream(options={"nonblocking": True}) + + # CUDA kernel that performs a grid-wide sync via cooperative groups + code = r""" + #include <cooperative_groups.h> + + extern "C" __global__ void test_grid_sync() { + namespace cg = cooperative_groups; + auto grid = cg::this_grid(); + grid.sync(); + } + """ + + # Compile to cubin and retrieve the kernel + arch = "".join(f"{i}" for i in dev.compute_capability) + include_path = str(pathlib.Path(os.environ["CUDA_PATH"]) / pathlib.Path("include")) + pro_opts = ProgramOptions(std="c++17", arch=f"sm_{arch}", include_path=include_path) + prog = Program(code, code_type="c++", options=pro_opts) + ker = prog.compile("cubin").get_kernel("test_grid_sync") + + # # Launch without setting cooperative_launch + # # Commented out as this seems to be a sticky error... 
+ # config = LaunchConfig(grid=1, block=1) + # launch(s, config, ker) + # from cuda.core.experimental._utils.cuda_utils import CUDAError + # with pytest.raises(CUDAError) as e: + # s.sync() + # assert "CUDA_ERROR_LAUNCH_FAILED" in str(e) + + # Crazy grid sizes would not work + block = 128 + config = LaunchConfig(grid=dev.properties.max_grid_dim_x // block + 1, block=block, cooperative_launch=True) + with pytest.raises(ValueError): + launch(s, config, ker) + + # This works just fine + config = LaunchConfig(grid=1, block=1, cooperative_launch=True) + launch(s, config, ker) + s.sync() From b717e5309166813777d9c21c4e1a80262a28a908 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Wed, 4 Jun 2025 03:35:59 +0000 Subject: [PATCH 2/5] update dev attr test --- cuda_core/tests/test_device.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cuda_core/tests/test_device.py b/cuda_core/tests/test_device.py index 309d661f5..e59ca56b5 100644 --- a/cuda_core/tests/test_device.py +++ b/cuda_core/tests/test_device.py @@ -191,6 +191,7 @@ def test_compute_capability(): ("concurrent_managed_access", bool), ("compute_preemption_supported", bool), ("can_use_host_pointer_for_registered_mem", bool), + ("cooperative_launch", bool), ("max_shared_memory_per_block_optin", int), ("pageable_memory_access_uses_host_page_tables", bool), ("direct_managed_mem_access_from_host", bool), From 138f0d106ad157b1dd8e9f304f8aa3f6ad85039f Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Fri, 6 Jun 2025 02:42:16 +0000 Subject: [PATCH 3/5] use f-string --- cuda_core/cuda/core/experimental/_launcher.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_launcher.py b/cuda_core/cuda/core/experimental/_launcher.py index 894d661c2..82ab8eda8 100644 --- a/cuda_core/cuda/core/experimental/_launcher.py +++ b/cuda_core/cuda/core/experimental/_launcher.py @@ -100,6 +100,5 @@ def _check_cooperative_launch(kernel: Kernel, config: LaunchConfig, stream: Stre if 
_reduce_tuple(config.grid) > max_grid_size: # For now let's try not to be smart and adjust the grid size behind users' back. # We explicitly ask users to adjust. - raise ValueError( - "The specified grid size ({} * {} * {}) exceeds the limit ({}).".format(*config.grid, max_grid_size) - ) + x, y, z = config.grid + raise ValueError(f"The specified grid size ({x} * {y} * {z}) exceeds the limit ({max_grid_size})") From 167136e3e1a5a9e45a859343377a6af0c069a82e Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Fri, 6 Jun 2025 02:48:15 +0000 Subject: [PATCH 4/5] add a naive skipif_need_cuda_headers fixture --- cuda_core/tests/conftest.py | 4 ++++ cuda_core/tests/test_event.py | 5 +++-- cuda_core/tests/test_launcher.py | 4 +++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/cuda_core/tests/conftest.py b/cuda_core/tests/conftest.py index 5bd5b52b7..cc1d6699b 100644 --- a/cuda_core/tests/conftest.py +++ b/cuda_core/tests/conftest.py @@ -86,3 +86,7 @@ def clean_up_cffi_files(): os.environ.get("CUDA_PYTHON_TESTING_WITH_COMPUTE_SANITIZER", "0") == "1", reason="The compute-sanitizer is running, and this test causes an API error.", ) + + +# TODO: make the fixture more sophisticated using path finder +skipif_need_cuda_headers = pytest.mark.skipif(os.environ.get("CUDA_PATH") is None, reason="need CUDA header") diff --git a/cuda_core/tests/test_event.py b/cuda_core/tests/test_event.py index eb954070e..6cbcc3f06 100644 --- a/cuda_core/tests/test_event.py +++ b/cuda_core/tests/test_event.py @@ -13,6 +13,8 @@ from cuda.core.experimental import Device, EventOptions, LaunchConfig, Program, ProgramOptions, launch from cuda.core.experimental._memory import _DefaultPinnedMemorySource +from conftest import skipif_need_cuda_headers + def test_event_init_disabled(): with pytest.raises(RuntimeError, match=r"^Event objects cannot be instantiated directly\."): @@ -114,9 +116,8 @@ def test_error_timing_recorded(): event3 - event2 -# TODO: improve this once path finder can find headers 
@skipif_testing_with_compute_sanitizer -@pytest.mark.skipif(os.environ.get("CUDA_PATH") is None, reason="need libcu++ header") +@skipif_need_cuda_headers # libcu++ @pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+") def test_error_timing_incomplete(): device = Device() diff --git a/cuda_core/tests/test_launcher.py b/cuda_core/tests/test_launcher.py index 260b18a43..62a0225e6 100644 --- a/cuda_core/tests/test_launcher.py +++ b/cuda_core/tests/test_launcher.py @@ -11,6 +11,8 @@ from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch from cuda.core.experimental._memory import _DefaultPinnedMemorySource +from conftest import skipif_need_cuda_headers + def test_launch_config_init(init_cuda): config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), shmem_size=0) @@ -154,7 +156,7 @@ def test_launch_scalar_argument(python_type, cpp_type, init_value): assert arr[0] == init_value, f"Expected {init_value}, got {arr[0]}" -@pytest.mark.skipif(os.environ.get("CUDA_PATH") is None, reason="need cg header") +@skipif_need_cuda_headers # cg def test_cooperative_launch(): dev = Device() dev.set_current() From 447c40d7d210bec945276348b7dbe8864f607e07 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Fri, 6 Jun 2025 02:53:30 +0000 Subject: [PATCH 5/5] micro-optimization --- cuda_core/cuda/core/experimental/_launcher.py | 6 +++--- cuda_core/cuda/core/experimental/_utils/cuda_utils.py | 4 ++-- cuda_core/tests/test_event.py | 4 +--- cuda_core/tests/test_launcher.py | 3 +-- 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_launcher.py b/cuda_core/cuda/core/experimental/_launcher.py index 82ab8eda8..1177d6034 100644 --- a/cuda_core/cuda/core/experimental/_launcher.py +++ b/cuda_core/cuda/core/experimental/_launcher.py @@ -9,7 +9,7 @@ from cuda.core.experimental._stream import Stream from cuda.core.experimental._utils.clear_error_support import assert_type from 
cuda.core.experimental._utils.cuda_utils import ( - _reduce_tuple, + _reduce_3_tuple, check_or_create_options, driver, get_binding_version, @@ -95,9 +95,9 @@ def _check_cooperative_launch(kernel: Kernel, config: LaunchConfig, stream: Stre dev = stream.device num_sm = dev.properties.multiprocessor_count max_grid_size = ( - kernel.occupancy.max_active_blocks_per_multiprocessor(_reduce_tuple(config.block), config.shmem_size) * num_sm + kernel.occupancy.max_active_blocks_per_multiprocessor(_reduce_3_tuple(config.block), config.shmem_size) * num_sm ) - if _reduce_tuple(config.grid) > max_grid_size: + if _reduce_3_tuple(config.grid) > max_grid_size: # For now let's try not to be smart and adjust the grid size behind users' back. # We explicitly ask users to adjust. x, y, z = config.grid diff --git a/cuda_core/cuda/core/experimental/_utils/cuda_utils.py b/cuda_core/cuda/core/experimental/_utils/cuda_utils.py index 0c8635b6e..48b48d2fb 100644 --- a/cuda_core/cuda/core/experimental/_utils/cuda_utils.py +++ b/cuda_core/cuda/core/experimental/_utils/cuda_utils.py @@ -48,8 +48,8 @@ def cast_to_3_tuple(label, cfg): return cfg + (1,) * (3 - len(cfg)) -def _reduce_tuple(t: tuple): - return functools.reduce(lambda x, y: x * y, t, 1) +def _reduce_3_tuple(t: tuple): + return t[0] * t[1] * t[2] def _check_driver_error(error): diff --git a/cuda_core/tests/test_event.py b/cuda_core/tests/test_event.py index 6cbcc3f06..45d8895a0 100644 --- a/cuda_core/tests/test_event.py +++ b/cuda_core/tests/test_event.py @@ -7,14 +7,12 @@ import numpy as np import pytest -from conftest import skipif_testing_with_compute_sanitizer +from conftest import skipif_need_cuda_headers, skipif_testing_with_compute_sanitizer import cuda.core.experimental from cuda.core.experimental import Device, EventOptions, LaunchConfig, Program, ProgramOptions, launch from cuda.core.experimental._memory import _DefaultPinnedMemorySource -from conftest import skipif_need_cuda_headers - def test_event_init_disabled(): with 
pytest.raises(RuntimeError, match=r"^Event objects cannot be instantiated directly\."): diff --git a/cuda_core/tests/test_launcher.py b/cuda_core/tests/test_launcher.py index 62a0225e6..635d1fcf5 100644 --- a/cuda_core/tests/test_launcher.py +++ b/cuda_core/tests/test_launcher.py @@ -7,12 +7,11 @@ import numpy as np import pytest +from conftest import skipif_need_cuda_headers from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch from cuda.core.experimental._memory import _DefaultPinnedMemorySource -from conftest import skipif_need_cuda_headers - def test_launch_config_init(init_cuda): config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), shmem_size=0)