
StridedMemoryView fails with Jax arrays #285

@leofang


@yangcal was trying out `StridedMemoryView`, but it doesn't seem to work with JAX arrays, at least with CUDA 12.2. Below are the script and the error log:

```python
import cupy as cp
import numpy as np
import jax.numpy as jnp

from cuda.core.experimental.utils import args_viewable_as_strided_memory

@args_viewable_as_strided_memory((0,))  # wrap positional arg 0 in a StridedMemoryView proxy
def parse_tensor(arr):
    view = arr.view(-1)  # -1: skip stream ordering between producer and consumer
    print(type(arr), type(view))
    print(f"shape={view.shape}")
    print(f"strides={view.strides}")

for module in (np, cp, jnp):
    arr = module.eye(2)  # 2x2 identity matrix from each library
    print(f"module={module.__name__}")
    parse_tensor(arr)
```
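The decorator defers the actual DLPack export. It wraps the chosen positional arguments in a proxy, and the export only happens when `.view(stream_ptr)` is called. That is why the traceback below passes through `wrapped_func`, then `_StridedMemoryViewProxy.view`, and only then into the array's `__dlpack__`.

The stdlib-only sketch below mimics that control flow. `viewable_args`, `LazyViewProxy`, and `FakeProducer` are hypothetical names, not cuda.core's implementation. The fake producer only records the `stream` it is handed instead of exporting a real capsule. The proxy also forwards `stream_ptr` unchanged purely for illustration; how cuda.core itself translates it is not shown here.

```python
import functools


class FakeProducer:
    """Hypothetical stand-in for a DLPack producer such as a device array."""

    def __init__(self):
        self.received_streams = []

    def __dlpack__(self, *, stream=None):
        # A real producer would return a PyCapsule; this one only records
        # the stream value the consumer asked it to synchronize with.
        self.received_streams.append(stream)
        return ("capsule", stream)


class LazyViewProxy:
    """Hypothetical proxy: holds the object and exports it only on .view()."""

    def __init__(self, obj):
        self.obj = obj

    def view(self, stream_ptr=None):
        # The export (and any producer-side error) happens here,
        # not when the decorated function is first called.
        return self.obj.__dlpack__(stream=stream_ptr)


def viewable_args(indices):
    """Hypothetical decorator wrapping the given positional args in proxies."""

    def decorator(func):
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            args = list(args)
            for i in indices:
                args[i] = LazyViewProxy(args[i])
            return func(*args, **kwargs)

        return wrapped

    return decorator


@viewable_args((0,))
def parse_tensor(arr):
    return arr.view(-1)


producer = FakeProducer()
result = parse_tensor(producer)
print(result)                     # ('capsule', -1)
print(producer.received_streams)  # [-1]
```

In this model, whatever the producer does with the `stream` value decides success or failure. The wrapper itself never touches the data.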

Error Log:

```
E1211 13:52:21.674104 2256128 ptx_compiler_helpers.cc:71] *** WARNING *** Invoking ptxas with version 12.2.140, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x.y up to and including 12.6.2 miscompile certain edge cases around clamping.
Please upgrade to CUDA 12.6.3 or newer.
module=jax.numpy
Traceback (most recent call last):
  File "/home/scratch.yangg_sw/software/cuda-python/cuda_core/examples/tmp.py", line 18, in <module>
    parse_tensor(arr)
  File "cuda/core/experimental/_memoryview.pyx", line 372, in cuda.core.experimental._memoryview.args_viewable_as_strided_memory.wrapped_func_with_indices.wrapped_func
  File "/home/scratch.yangg_sw/software/cuda-python/cuda_core/examples/tmp.py", line 10, in parse_tensor
    view = arr.view(-1)
           ^^^^^^^^^^^^
  File "cuda/core/experimental/_memoryview.pyx", line 146, in cuda.core.experimental._memoryview._StridedMemoryViewProxy.view
  File "cuda/core/experimental/_memoryview.pyx", line 148, in cuda.core.experimental._memoryview._StridedMemoryViewProxy.view
  File "cuda/core/experimental/_memoryview.pyx", line 180, in cuda.core.experimental._memoryview.view_as_dlpack
  File "/home/Self/marie/miniconda3/envs/jax/lib/python3.12/site-packages/jax/_src/array.py", line 446, in __dlpack__
    return to_dlpack(self, stream=stream,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/Self/marie/miniconda3/envs/jax/lib/python3.12/site-packages/jax/_src/dlpack.py", line 134, in to_dlpack
    return _to_dlpack(
           ^^^^^^^^^^^
  File "/home/Self/marie/miniconda3/envs/jax/lib/python3.12/site-packages/jax/_src/dlpack.py", line 65, in _to_dlpack
    return xla_client._xla.buffer_to_dlpack_managed_tensor(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CUDA error: : CUDA_ERROR_INVALID_HANDLE: invalid resource handle
```

Is this a semantics (usage) error on my side, or a CUDA version issue?
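Part of the answer depends on how the producer interprets the `stream` value it receives. The Python array API standard's `__dlpack__` specification defines these values for CUDA devices:

- `None` and `1` both mean the legacy default stream.
- `2` means the per-thread default stream.
- Larger integers are stream handles.
- `0` is disallowed as ambiguous.
- `-1` asks the producer not to synchronize at all.

The hypothetical helper below (`describe_cuda_dlpack_stream` is an illustrative name, not part of any library) only encodes that table. It does not query JAX or CUDA, so it cannot tell whether a given producer actually honors `-1`.

```python
from typing import Optional


def describe_cuda_dlpack_stream(stream: Optional[int]) -> str:
    """Hypothetical helper: meaning of `stream` passed to __dlpack__ on CUDA,
    following the table in the Python array API standard."""
    if stream is None:
        return "legacy default stream (assumed)"
    # bool is a subclass of int, so reject it explicitly.
    if not isinstance(stream, int) or isinstance(stream, bool):
        raise TypeError("stream must be an int or None")
    if stream == -1:
        return "no synchronization requested"
    if stream == 0:
        raise ValueError("0 is disallowed for CUDA: ambiguous between None, 1 and 2")
    if stream == 1:
        return "legacy default stream"
    if stream == 2:
        return "per-thread default stream"
    if stream > 2:
        return f"stream handle {stream:#x}"
    raise ValueError("stream must be >= -1")


for s in (None, -1, 1, 2, 0x10):
    print(s, "->", describe_cuda_dlpack_stream(s))
```

Under this convention, the `-1` passed to `arr.view(-1)` is a valid request. Whether JAX's `__dlpack__` handles it as "no synchronization" is exactly what the error above calls into question.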

Labels: P0 (High priority - Must do!), bug (Something isn't working), cuda.core (Everything related to the cuda.core module)
