diff --git a/cuda_core/cuda/core/experimental/_module.py b/cuda_core/cuda/core/experimental/_module.py index 041602691..33827fcef 100644 --- a/cuda_core/cuda/core/experimental/_module.py +++ b/cuda_core/cuda/core/experimental/_module.py @@ -224,10 +224,9 @@ class ObjectCode: Note ---- This class has no default constructor. If you already have a cubin that you would - like to load, use the :meth:`from_cubin` alternative constructor. For all other - possible code types (ex: "ptx"), only :class:`~cuda.core.experimental.Program` - accepts them and returns an :class:`ObjectCode` instance with its - :meth:`~cuda.core.experimental.Program.compile` method. + like to load, use the :meth:`from_cubin` alternative constructor. Constructing directly + from all other possible code types should be avoided in favor of compilation through + :class:`~cuda.core.experimental.Program`. Note ---- @@ -278,6 +277,22 @@ def from_cubin(module: Union[bytes, str], *, symbol_mapping: Optional[dict] = No """ return ObjectCode._init(module, "cubin", symbol_mapping=symbol_mapping) + @staticmethod + def from_ptx(module: Union[bytes, str], *, symbol_mapping: Optional[dict] = None) -> "ObjectCode": + """Create an :class:`ObjectCode` instance from an existing PTX. + + Parameters + ---------- + module : Union[bytes, str] + Either a bytes object containing the in-memory ptx code to load, or + a file path string pointing to the on-disk ptx file to load. + symbol_mapping : Optional[dict] + A dictionary specifying how the unmangled symbol names (as keys) + should be mapped to the mangled names before trying to retrieve + them (default to no mappings). + """ + return ObjectCode._init(module, "ptx", symbol_mapping=symbol_mapping) + # TODO: do we want to unload in a finalizer? Probably not.. 
def _lazy_load_module(self, *args, **kwargs): diff --git a/cuda_core/cuda/core/experimental/_program.py b/cuda_core/cuda/core/experimental/_program.py index 662add23c..8e7fea245 100644 --- a/cuda_core/cuda/core/experimental/_program.py +++ b/cuda_core/cuda/core/experimental/_program.py @@ -425,7 +425,8 @@ def close(self): self._linker.close() self._mnff.close() - def _can_load_generated_ptx(self): + @staticmethod + def _can_load_generated_ptx(): driver_ver = handle_return(driver.cuDriverGetVersion()) nvrtc_major, nvrtc_minor = handle_return(nvrtc.nvrtcVersion()) return nvrtc_major * 1000 + nvrtc_minor * 10 <= driver_ver diff --git a/cuda_core/tests/test_linker.py b/cuda_core/tests/test_linker.py index 59e8ca573..b6dd2a350 100644 --- a/cuda_core/tests/test_linker.py +++ b/cuda_core/tests/test_linker.py @@ -118,6 +118,14 @@ def test_linker_link_cubin(compile_ptx_functions): assert isinstance(linked_code, ObjectCode) +def test_linker_link_ptx_multiple(compile_ptx_functions): + ptxes = tuple(ObjectCode.from_ptx(obj.code) for obj in compile_ptx_functions) + options = LinkerOptions(arch=ARCH) + linker = Linker(*ptxes, options=options) + linked_code = linker.link("cubin") + assert isinstance(linked_code, ObjectCode) + + def test_linker_link_invalid_target_type(compile_ptx_functions): options = LinkerOptions(arch=ARCH) linker = Linker(*compile_ptx_functions, options=options) diff --git a/cuda_core/tests/test_module.py b/cuda_core/tests/test_module.py index 8528c4d53..4bc7ceab3 100644 --- a/cuda_core/tests/test_module.py +++ b/cuda_core/tests/test_module.py @@ -13,25 +13,25 @@ from cuda.core.experimental import ObjectCode, Program, ProgramOptions, system +SAXPY_KERNEL = """ +template +__global__ void saxpy(const T a, + const T* x, + const T* y, + T* out, + size_t N) { + const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; + for (size_t i=tid; i - __global__ void saxpy(const T a, - const T* x, - const T* y, - T* out, - size_t N) { - const unsigned int tid = 
threadIdx.x + blockIdx.x * blockDim.x; - for (size_t i=tid; i", "saxpy"), @@ -41,6 +41,17 @@ def get_saxpy_kernel(init_cuda): return mod.get_kernel("saxpy"), mod +@pytest.fixture(scope="function") +def get_saxpy_kernel_ptx(init_cuda): + prog = Program(SAXPY_KERNEL, code_type="c++") + mod = prog.compile( + "ptx", + name_expressions=("saxpy", "saxpy"), + ) + ptx = mod._module + return ptx, mod + + def test_get_kernel(init_cuda): kernel = """extern "C" __global__ void ABC() { }""" @@ -100,6 +111,16 @@ def test_object_code_load_cubin(get_saxpy_kernel): mod.get_kernel("saxpy") # force loading +def test_object_code_load_ptx(get_saxpy_kernel_ptx): + ptx, mod = get_saxpy_kernel_ptx + sym_map = mod._sym_map + mod_obj = ObjectCode.from_ptx(ptx, symbol_mapping=sym_map) + assert mod.code == ptx + if not Program._can_load_generated_ptx(): + pytest.skip("PTX version too new for current driver") + mod_obj.get_kernel("saxpy") # force loading + + def test_object_code_load_cubin_from_file(get_saxpy_kernel, tmp_path): _, mod = get_saxpy_kernel cubin = mod._module