diff --git a/cuda_core/examples/strided_memory_view_cpu.py b/cuda_core/examples/strided_memory_view_cpu.py
new file mode 100644
index 000000000..0fa8f38e6
--- /dev/null
+++ b/cuda_core/examples/strided_memory_view_cpu.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# ################################################################################
+#
+# This demo aims to illustrate two takeaways:
+#
+#   1. The similarity between CPU and GPU JIT-compilation with C++ sources
+#   2. How to use StridedMemoryView to interface with foreign C/C++ functions
+#
+# To facilitate this demo, we use cffi (https://cffi.readthedocs.io/) for the CPU
+# path, which can be easily installed from pip or conda following their instructions.
+# We also use NumPy as the CPU array container.
+#
+# ################################################################################
+
+import importlib
+import shutil
+import string
+import sys
+import tempfile
+
+try:
+    from cffi import FFI
+except ImportError:
+    print("cffi is not installed, the CPU example will be skipped", file=sys.stderr)
+    FFI = None
+import numpy as np
+
+from cuda.core.experimental.utils import StridedMemoryView, args_viewable_as_strided_memory
+
+# ################################################################################
+#
+# Usually this entire code block is in a separate file, built as a Python extension
+# module that can be imported by users at run time. For illustrative purposes we
+# use JIT compilation to make this demo self-contained.
+#
+# Here we assume an in-place operation, equivalent to the following NumPy code:
+#
+#   >>> arr = ...
+#   >>> assert arr.dtype == np.int32
+#   >>> assert arr.ndim == 1
+#   >>> arr += np.arange(arr.size, dtype=arr.dtype)
+#
+# is implemented for both CPU and GPU at low-level, with the following C function
+# signature:
+func_name = "inplace_plus_arange_N"
+func_sig = f"void {func_name}(int* data, size_t N)"
+
+
+# Now we are prepared to run the code from the user's perspective!
+#
+# ################################################################################
+
+
+# Below, as a user we want to perform the said in-place operation on a CPU
+# array, by calling the corresponding function implemented "elsewhere"
+# (in the body of the run() function).
+
+
+# We assume the 0-th argument supports either DLPack or CUDA Array Interface (both
+# of which are supported by StridedMemoryView).
+@args_viewable_as_strided_memory((0,))
+def my_func(arr):
+    global cpu_func
+    global cpu_prog
+    # Create a memory view over arr (assumed to be a 1D array of int32). Since the
+    # data lives on the host, we pass -1 to indicate that no stream ordering is
+    # needed when creating the view.
+    view = arr.view(-1)
+    assert isinstance(view, StridedMemoryView)
+    assert len(view.shape) == 1
+    assert view.dtype == np.int32
+    assert not view.is_device_accessible
+
+    size = view.shape[0]
+    # DLPack also supports host arrays. Here the data is not device-accessible,
+    # so we call the CPU routine directly.
+    cpu_func(cpu_prog.cast("int*", view.ptr), size)
+
+
+def run():
+    global my_func, cpu_prog, cpu_func
+    if not FFI:
+        return
+    # Here is a concrete (very naive!) implementation on CPU:
+    cpu_code = string.Template(r"""
+    extern "C"
+    $func_sig {
+        for (size_t i = 0; i < N; i++) {
+            data[i] += i;
+        }
+    }
+    """).substitute(func_sig=func_sig)
+    # This is cffi's way of JIT compiling & loading a CPU function. cffi builds an
+    # extension module that has the Python binding to the underlying C function.
+    # For more details, please refer to cffi's documentation.
+    cpu_prog = FFI()
+    cpu_prog.cdef(f"{func_sig};")
+    cpu_prog.set_source(
+        "_cpu_obj",
+        cpu_code,
+        source_extension=".cpp",
+        extra_compile_args=["-std=c++11"],
+    )
+    temp_dir = tempfile.mkdtemp()
+    saved_sys_path = sys.path.copy()
+    try:
+        cpu_prog.compile(tmpdir=temp_dir)
+
+        sys.path.append(temp_dir)
+        cpu_func = getattr(importlib.import_module("_cpu_obj.lib"), func_name)
+
+        # Create input array on CPU
+        arr_cpu = np.zeros(1024, dtype=np.int32)
+        print(f"before: {arr_cpu[:10]=}")
+
+        # Run the workload
+        my_func(arr_cpu)
+
+        # Check the result
+        print(f"after: {arr_cpu[:10]=}")
+        assert np.allclose(arr_cpu, np.arange(1024, dtype=np.int32))
+    finally:
+        sys.path = saved_sys_path
+        # To allow the FFI extension module to unload, we delete our references
+        # to cpu_func and my_func
+        del cpu_func, my_func
+        # clean up temp directory
+        shutil.rmtree(temp_dir)
+
+
+if __name__ == "__main__":
+    run()
diff --git a/cuda_core/examples/strided_memory_view.py b/cuda_core/examples/strided_memory_view_gpu.py
similarity index 60%
rename from cuda_core/examples/strided_memory_view.py
rename to cuda_core/examples/strided_memory_view_gpu.py
index a145f7a4e..10d12fd30 100644
--- a/cuda_core/examples/strided_memory_view.py
+++ b/cuda_core/examples/strided_memory_view_gpu.py
@@ -15,15 +15,9 @@
 #
 # ################################################################################
 
-import importlib
 import string
 import sys
 
-try:
-    from cffi import FFI
-except ImportError:
-    print("cffi is not installed, the CPU example will be skipped", file=sys.stderr)
-    FFI = None
 try:
     import cupy as cp
 except ImportError:
@@ -52,51 +46,6 @@
 func_name = "inplace_plus_arange_N"
 func_sig = f"void {func_name}(int* data, size_t N)"
 
-# Here is a concrete (very naive!) implementation on CPU:
-if FFI:
-    cpu_code = string.Template(r"""
-    extern "C"
-    $func_sig {
-        for (size_t i = 0; i < N; i++) {
-            data[i] += i;
-        }
-    }
-    """).substitute(func_sig=func_sig)
-    # This is cffi's way of JIT compiling & loading a CPU function. cffi builds an
-    # extension module that has the Python binding to the underlying C function.
-    # For more details, please refer to cffi's documentation.
-    cpu_prog = FFI()
-    cpu_prog.cdef(f"{func_sig};")
-    cpu_prog.set_source(
-        "_cpu_obj",
-        cpu_code,
-        source_extension=".cpp",
-        extra_compile_args=["-std=c++11"],
-    )
-    cpu_prog.compile()
-    cpu_func = getattr(importlib.import_module("_cpu_obj.lib"), func_name)
-
-# Here is a concrete (again, very naive!) implementation on GPU:
-if cp:
-    gpu_code = string.Template(r"""
-    extern "C"
-    __global__ $func_sig {
-        const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
-        const size_t stride_size = gridDim.x * blockDim.x;
-        for (size_t i = tid; i < N; i += stride_size) {
-            data[i] += i;
-        }
-    }
-    """).substitute(func_sig=func_sig)
-
-    # To know the GPU's compute capability, we need to identify which GPU to use.
-    dev = Device(0)
-    dev.set_current()
-    arch = "".join(f"{i}" for i in dev.compute_capability)
-    gpu_prog = Program(gpu_code, code_type="c++", options=ProgramOptions(arch=f"sm_{arch}", std="c++11"))
-    mod = gpu_prog.compile(target_type="cubin")
-    gpu_ker = mod.get_kernel(func_name)
-
 # Now we are prepared to run the code from the user's perspective!
 #
 # ################################################################################
@@ -109,7 +58,7 @@
 # We assume the 0-th argument supports either DLPack or CUDA Array Interface (both
 # of which are supported by StridedMemoryView).
 @args_viewable_as_strided_memory((0,))
-def my_func(arr, work_stream):
+def my_func(arr, work_stream, gpu_ker):
     # Create a memory view over arr (assumed to be a 1D array of int32). The stream
     # ordering is taken care of, so that arr can be safely accessed on our work
     # stream (ordered after a data stream on which arr is potentially prepared).
@@ -117,52 +66,64 @@ def my_func(arr, work_stream):
     assert isinstance(view, StridedMemoryView)
     assert len(view.shape) == 1
     assert view.dtype == np.int32
+    assert view.is_device_accessible
 
     size = view.shape[0]
     # DLPack also supports host arrays. We want to know if the array data is
     # accessible from the GPU, and dispatch to the right routine accordingly.
-    if view.is_device_accessible:
-        block = 256
-        grid = (size + block - 1) // block
-        config = LaunchConfig(grid=grid, block=block)
-        launch(work_stream, config, gpu_ker, view.ptr, np.uint64(size))
-        # Here we're being conservative and synchronize over our work stream,
-        # assuming we do not know the data stream; if we know then we could
-        # just order the data stream after the work stream here, e.g.
-        #
-        # data_stream.wait(work_stream)
-        #
-        # without an expensive synchronization (with respect to the host).
-        work_stream.sync()
-    else:
-        cpu_func(cpu_prog.cast("int*", view.ptr), size)
-
-
-# This takes the CPU path
-if FFI:
-    # Create input array on CPU
-    arr_cpu = np.zeros(1024, dtype=np.int32)
-    print(f"before: {arr_cpu[:10]=}")
-
-    # Run the workload
-    my_func(arr_cpu, None)
-
-    # Check the result
-    print(f"after: {arr_cpu[:10]=}")
-    assert np.allclose(arr_cpu, np.arange(1024, dtype=np.int32))
-
-
-# This takes the GPU path
-if cp:
+    block = 256
+    grid = (size + block - 1) // block
+    config = LaunchConfig(grid=grid, block=block)
+    launch(work_stream, config, gpu_ker, view.ptr, np.uint64(size))
+    # Here we're being conservative and synchronize over our work stream,
+    # assuming we do not know the data stream; if we know then we could
+    # just order the data stream after the work stream here, e.g.
+    #
+    # data_stream.wait(work_stream)
+    #
+    # without an expensive synchronization (with respect to the host).
+    work_stream.sync()
+
+
+def run():
+    global my_func
+    if not cp:
+        return None
+    # Here is a concrete (very naive!) implementation on GPU:
+    gpu_code = string.Template(r"""
+    extern "C"
+    __global__ $func_sig {
+        const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
+        const size_t stride_size = gridDim.x * blockDim.x;
+        for (size_t i = tid; i < N; i += stride_size) {
+            data[i] += i;
+        }
+    }
+    """).substitute(func_sig=func_sig)
+
+    # To know the GPU's compute capability, we need to identify which GPU to use.
+    dev = Device(0)
+    dev.set_current()
+    arch = "".join(f"{i}" for i in dev.compute_capability)
+    gpu_prog = Program(gpu_code, code_type="c++", options=ProgramOptions(arch=f"sm_{arch}", std="c++11"))
+    mod = gpu_prog.compile(target_type="cubin")
+    gpu_ker = mod.get_kernel(func_name)
+
     s = dev.create_stream()
-    # Create input array on GPU
-    arr_gpu = cp.ones(1024, dtype=cp.int32)
-    print(f"before: {arr_gpu[:10]=}")
+    try:
+        # Create input array on GPU
+        arr_gpu = cp.ones(1024, dtype=cp.int32)
+        print(f"before: {arr_gpu[:10]=}")
+
+        # Run the workload
+        my_func(arr_gpu, s, gpu_ker)
+
+        # Check the result
+        print(f"after: {arr_gpu[:10]=}")
+        assert cp.allclose(arr_gpu, 1 + cp.arange(1024, dtype=cp.int32))
+    finally:
+        s.close()
 
-    # Run the workload
-    my_func(arr_gpu, s)
-
-    # Check the result
-    print(f"after: {arr_gpu[:10]=}")
-    assert cp.allclose(arr_gpu, 1 + cp.arange(1024, dtype=cp.int32))
-    s.close()
+
+if __name__ == "__main__":
+    run()
diff --git a/cuda_core/tests/conftest.py b/cuda_core/tests/conftest.py
index 5bd5b52b7..4d615d73d 100644
--- a/cuda_core/tests/conftest.py
+++ b/cuda_core/tests/conftest.py
@@ -1,9 +1,7 @@
 # Copyright 2024 NVIDIA Corporation. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
-import glob
 import os
-import sys
 
 try:
     from cuda.bindings import driver
@@ -67,21 +65,6 @@ def pop_all_contexts():
     return pop_all_contexts
 
 
-# samples relying on cffi could fail as the modules cannot be imported
-sys.path.append(os.getcwd())
-
-
-@pytest.fixture(scope="session", autouse=True)
-def clean_up_cffi_files():
-    yield
-    files = glob.glob(os.path.join(os.getcwd(), "_cpu_obj*"))
-    for f in files:
-        try:  # noqa: SIM105
-            os.remove(f)
-        except FileNotFoundError:
-            pass  # noqa: SIM105
-
-
 skipif_testing_with_compute_sanitizer = pytest.mark.skipif(
     os.environ.get("CUDA_PYTHON_TESTING_WITH_COMPUTE_SANITIZER", "0") == "1",
     reason="The compute-sanitizer is running, and this test causes an API error.",
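For reference, the runtime CPU/GPU dispatch that the original combined strided_memory_view.py performed (and that this split removes) can be sketched as follows. This is illustrative only: the function name inplace_plus_arange is hypothetical, gpu_kernel, cpu_function, and ffi stand in for objects prepared the same way as gpu_ker, cpu_func, and cpu_prog in the two examples, and passing work_stream.handle as the stream pointer is an assumption carried over from the original example rather than something shown in this diff.

import numpy as np

from cuda.core.experimental import LaunchConfig, launch
from cuda.core.experimental.utils import StridedMemoryView, args_viewable_as_strided_memory


@args_viewable_as_strided_memory((0,))
def inplace_plus_arange(arr, work_stream, gpu_kernel, cpu_function, ffi):
    # View arr on the work stream; -1 means no stream ordering (host-only data).
    view = arr.view(-1 if work_stream is None else work_stream.handle)
    assert isinstance(view, StridedMemoryView)
    assert len(view.shape) == 1 and view.dtype == np.int32
    size = view.shape[0]
    if view.is_device_accessible:
        # GPU path: launch the JIT-compiled kernel on the work stream, then sync.
        block = 256
        grid = (size + block - 1) // block
        launch(work_stream, LaunchConfig(grid=grid, block=block), gpu_kernel, view.ptr, np.uint64(size))
        work_stream.sync()
    else:
        # CPU path: cast the raw pointer and call the cffi-compiled function.
        cpu_function(ffi.cast("int*", view.ptr), size)

With the callables prepared by either example's run(), the same decorated entry point serves NumPy arrays (CPU path) and CuPy arrays (GPU path) alike.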