switch the launch argument order #316

Merged (17 commits) on Mar 3, 2025
32 changes: 15 additions & 17 deletions cuda_core/cuda/core/experimental/_launcher.py
@@ -45,9 +45,6 @@ class LaunchConfig:
Group of threads (Thread Block) that will execute on the same
streaming multiprocessor (SM). Threads within a thread block have
access to shared memory and can be explicitly synchronized.
stream : :obj:`~_stream.Stream`
The stream establishing the stream ordering semantic of a
launch.
shmem_size : int, optional
Dynamic shared-memory size per thread block in bytes.
(Default to size 0)
@@ -58,7 +55,6 @@ class LaunchConfig:
grid: Union[tuple, int] = None
cluster: Union[tuple, int] = None
block: Union[tuple, int] = None
stream: Stream = None
shmem_size: Optional[int] = None

def __post_init__(self):
@@ -72,12 +68,6 @@ def __post_init__(self):
if Device().compute_capability < (9, 0):
raise CUDAError("thread block clusters are not supported on devices with compute capability < 9.0")
self.cluster = self._cast_to_3_tuple(self.cluster)
# we handle "stream=None" in the launch API
if self.stream is not None and not isinstance(self.stream, Stream):
try:
self.stream = Stream._init(self.stream)
except Exception as e:
raise ValueError("stream must either be a Stream object or support __cuda_stream__") from e
if self.shmem_size is None:
self.shmem_size = 0

@@ -105,27 +95,35 @@ def _cast_to_3_tuple(self, cfg):
raise ValueError


def launch(kernel, config, *kernel_args):
def launch(stream, config, kernel, *kernel_args):
"""Launches a :obj:`~_module.Kernel`
object with launch-time configuration.

Parameters
----------
kernel : :obj:`~_module.Kernel`
Kernel to launch.
stream : :obj:`~_stream.Stream`
The stream establishing the stream ordering semantic of a
launch.
config : :obj:`~_launcher.LaunchConfig`
Launch configuration, in line with the options provided by the
:obj:`~_launcher.LaunchConfig` dataclass.
kernel : :obj:`~_module.Kernel`
Kernel to launch.
*kernel_args : Any
Variable length argument list that is provided to the
launching kernel.

"""
if stream is None:
raise ValueError("stream cannot be None, stream must either be a Stream object or support __cuda_stream__")
if not isinstance(stream, Stream):
try:
stream = Stream._init(stream)
except Exception as e:
raise ValueError("stream must either be a Stream object or support __cuda_stream__") from e
if not isinstance(kernel, Kernel):
raise ValueError
config = check_or_create_options(LaunchConfig, config, "launch config")
if config.stream is None:
raise CUDAError("stream cannot be None")

# TODO: can we ensure kernel_args is valid/safe to use here?
# TODO: merge with HelperKernelParams?
@@ -141,7 +139,7 @@ def launch(kernel, config, *kernel_args):
drv_cfg = driver.CUlaunchConfig()
drv_cfg.gridDimX, drv_cfg.gridDimY, drv_cfg.gridDimZ = config.grid
drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block
drv_cfg.hStream = config.stream.handle
drv_cfg.hStream = stream.handle
drv_cfg.sharedMemBytes = config.shmem_size
attrs = [] # TODO: support more attributes
if config.cluster:
@@ -157,6 +155,6 @@
# TODO: check if config has any unsupported attrs
handle_return(
driver.cuLaunchKernel(
int(kernel._handle), *config.grid, *config.block, config.shmem_size, config.stream._handle, args_ptr, 0
int(kernel._handle), *config.grid, *config.block, config.shmem_size, stream._handle, args_ptr, 0
)
)
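For reference, a minimal end-to-end sketch of the new call pattern after this change. The kernel source and names below are illustrative only; the API calls mirror the updated tests and examples in this PR.

    from cuda.core.experimental import Device, LaunchConfig, Program, launch

    dev = Device()
    dev.set_current()
    stream = dev.create_stream()

    # Compile a trivial kernel (same Program/compile/get_kernel flow as in the tests).
    prog = Program('extern "C" __global__ void noop() {}', "c++")
    ker = prog.compile("cubin").get_kernel("noop")

    # The stream is no longer part of LaunchConfig ...
    config = LaunchConfig(grid=1, block=1)

    # ... it is now the first positional argument of launch().
    launch(stream, config, ker)
    stream.sync()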
2 changes: 2 additions & 0 deletions cuda_core/docs/source/release/0.2.0-notes.rst
@@ -14,6 +14,8 @@ Highlights
Breaking Changes
----------------

- The ``stream`` attribute is removed from :class:`~LaunchConfig`. Instead, the :class:`Stream` object should now be passed directly to :func:`~launch` as an argument.
- The signature of :func:`~launch` has changed: the stream is now the first positional argument, and the new signature is ``(stream, config, kernel, *kernel_args)``.
- Change ``__cuda_stream__`` from attribute to method.
- The :meth:`Program.compile` method no longer accepts the ``options`` argument. Instead, you can optionally pass an instance of :class:`ProgramOptions` to the constructor of :class:`Program`.
- :meth:`Device.properties` now provides attribute getters instead of a dictionary interface.
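To make the launch-related breaking changes above concrete, a before/after sketch of a typical call site (variable names are illustrative, following the updated examples in this PR):

    # before: the stream was carried by the launch configuration
    config = LaunchConfig(grid=grid, block=block, stream=s)
    launch(ker, config, *ker_args)

    # after: the stream is passed directly to launch(), ahead of the config and kernel
    config = LaunchConfig(grid=grid, block=block)
    launch(s, config, ker, *ker_args)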
4 changes: 2 additions & 2 deletions cuda_core/examples/jit_lto_fractal.py
@@ -78,7 +78,7 @@ def __init__(self):
# the problem into 16x16 chunks.
self.grid = (self.width / 16, self.height / 16, 1.0)
self.block = (16, 16, 1)
self.config = LaunchConfig(grid=self.grid, block=self.block, stream=self.stream)
self.config = LaunchConfig(grid=self.grid, block=self.block)

def link(self, user_code, target_type):
if target_type == "ltoir":
@@ -103,7 +103,7 @@ def link(self, user_code, target_type):
return linked_code.get_kernel("main_workflow")

def run(self, kernel):
launch(kernel, self.config, self.buffer.data.ptr)
launch(self.stream, self.config, kernel, self.buffer.data.ptr)
self.stream.sync()

# Return the result as a NumPy array (on host).
8 changes: 4 additions & 4 deletions cuda_core/examples/saxpy.py
@@ -54,11 +54,11 @@
# prepare launch
block = 32
grid = int((size + block - 1) // block)
config = LaunchConfig(grid=grid, block=block, stream=s)
config = LaunchConfig(grid=grid, block=block)
ker_args = (a, x.data.ptr, y.data.ptr, out.data.ptr, size)

# launch kernel on stream s
launch(ker, config, *ker_args)
launch(s, config, ker, *ker_args)
s.sync()

# check result
@@ -85,11 +85,11 @@
# prepare launch
block = 64
grid = int((size + block - 1) // block)
config = LaunchConfig(grid=grid, block=block, stream=s)
config = LaunchConfig(grid=grid, block=block)
ker_args = (a, x.data.ptr, y.data.ptr, buf, size)

# launch kernel on stream s
launch(ker, config, *ker_args)
launch(s, config, ker, *ker_args)
s.sync()

# check result
8 changes: 4 additions & 4 deletions cuda_core/examples/simple_multi_gpu_example.py
@@ -78,8 +78,8 @@ def __cuda_stream__(self):
# CUDA streams.
block = 256
grid = (size + block - 1) // block
config0 = LaunchConfig(grid=grid, block=block, stream=stream0)
config1 = LaunchConfig(grid=grid, block=block, stream=stream1)
config0 = LaunchConfig(grid=grid, block=block)
config1 = LaunchConfig(grid=grid, block=block)

# Allocate memory on GPU 0
# Note: This runs on CuPy's current stream for GPU 0
@@ -94,7 +94,7 @@ def __cuda_stream__(self):
stream0.wait(cp_stream0)

# Launch the add kernel on GPU 0 / stream 0
launch(ker_add, config0, a.data.ptr, b.data.ptr, c.data.ptr, cp.uint64(size))
launch(stream0, config0, ker_add, a.data.ptr, b.data.ptr, c.data.ptr, cp.uint64(size))

# Allocate memory on GPU 1
# Note: This runs on CuPy's current stream for GPU 1.
@@ -108,7 +108,7 @@ def __cuda_stream__(self):
stream1.wait(cp_stream1)

# Launch the subtract kernel on GPU 1 / stream 1
launch(ker_sub, config1, x.data.ptr, y.data.ptr, z.data.ptr, cp.uint64(size))
launch(stream1, config1, ker_sub, x.data.ptr, y.data.ptr, z.data.ptr, cp.uint64(size))

# Synchronize both GPUs and validate the results
dev0.set_current()
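Note that the stream passed to launch() does not have to be a cuda.core Stream: as the coercion branch added in _launcher.py shows, any object supporting the __cuda_stream__ protocol is wrapped via Stream._init(). A minimal adaptor sketch, assuming the protocol is a method returning a (version, raw handle) tuple as used by this example; the class and variable names here are illustrative, not part of the PR:

    class StreamAdaptor:
        """Expose a foreign stream (e.g. a CuPy stream) via the __cuda_stream__ protocol."""

        def __init__(self, obj):
            self.obj = obj

        def __cuda_stream__(self):
            # assumed protocol: (version, raw stream handle); CuPy streams expose .ptr
            return (0, self.obj.ptr)

    # launch(StreamAdaptor(some_cupy_stream), config0, ker_add, ...) would then be
    # accepted and coerced to a Stream internally.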
4 changes: 2 additions & 2 deletions cuda_core/examples/strided_memory_view.py
@@ -124,8 +124,8 @@ def my_func(arr, work_stream):
if view.is_device_accessible:
block = 256
grid = (size + block - 1) // block
config = LaunchConfig(grid=grid, block=block, stream=work_stream)
launch(gpu_ker, config, view.ptr, np.uint64(size))
config = LaunchConfig(grid=grid, block=block)
launch(work_stream, config, gpu_ker, view.ptr, np.uint64(size))
# Here we're being conservative and synchronize over our work stream,
# assuming we do not know the data stream; if we know then we could
# just order the data stream after the work stream here, e.g.
4 changes: 2 additions & 2 deletions cuda_core/examples/thread_block_cluster.py
@@ -56,10 +56,10 @@
grid = 4
cluster = 2
block = 32
config = LaunchConfig(grid=grid, cluster=cluster, block=block, stream=dev.default_stream)
config = LaunchConfig(grid=grid, cluster=cluster, block=block)

# launch kernel on the default stream
launch(ker, config)
launch(dev.default_stream, config, ker)
dev.sync()

print("done!")
4 changes: 2 additions & 2 deletions cuda_core/examples/vector_add.py
@@ -48,10 +48,10 @@
# prepare launch
block = 256
grid = (size + block - 1) // block
config = LaunchConfig(grid=grid, block=block, stream=s)
config = LaunchConfig(grid=grid, block=block)

# launch kernel on stream s
launch(ker, config, a.data.ptr, b.data.ptr, c.data.ptr, cp.uint64(size))
launch(s, config, ker, a.data.ptr, b.data.ptr, c.data.ptr, cp.uint64(size))
s.sync()

# check result
38 changes: 24 additions & 14 deletions cuda_core/tests/test_launcher.py
@@ -8,20 +8,18 @@

import pytest

from cuda.core.experimental import Device, LaunchConfig, Stream
from cuda.core.experimental import Device, LaunchConfig, Program, launch


def test_launch_config_init(init_cuda):
config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream=None, shmem_size=0)
config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), shmem_size=0)
assert config.grid == (1, 1, 1)
assert config.block == (1, 1, 1)
assert config.stream is None
assert config.shmem_size == 0

config = LaunchConfig(grid=(2, 2, 2), block=(2, 2, 2), stream=Device().create_stream(), shmem_size=1024)
config = LaunchConfig(grid=(2, 2, 2), block=(2, 2, 2), shmem_size=1024)
assert config.grid == (2, 2, 2)
assert config.block == (2, 2, 2)
assert isinstance(config.stream, Stream)
assert config.shmem_size == 1024


@@ -51,18 +49,30 @@ def test_launch_config_invalid_values():
LaunchConfig(grid=(1, 1, 1), block=(0, 1))


def test_launch_config_stream(init_cuda):
def test_launch_config_shmem_size():
config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), shmem_size=2048)
assert config.shmem_size == 2048

config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1))
assert config.shmem_size == 0


def test_launch_invalid_values(init_cuda):
code = 'extern "C" __global__ void my_kernel() {}'
program = Program(code, "c++")
mod = program.compile("cubin")

stream = Device().create_stream()
config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream=stream, shmem_size=0)
assert config.stream == stream
ker = mod.get_kernel("my_kernel")
config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), shmem_size=0)

with pytest.raises(ValueError):
LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream="invalid_stream", shmem_size=0)
launch(None, ker, config)

with pytest.raises(ValueError):
launch(stream, None, config)

def test_launch_config_shmem_size():
config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream=None, shmem_size=2048)
assert config.shmem_size == 2048
with pytest.raises(ValueError):
launch(stream, ker, None)

config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream=None)
assert config.shmem_size == 0
launch(stream, config, ker)