Closed
Labels: P0 (High priority - Must do!), bug (Something isn't working), cuda.core (Everything related to the cuda.core module)
Description
@yangcal was trying StridedMemoryView, but it doesn't seem to work for JAX arrays, at least with CUDA 12.2. Below are the script and the error log:
import cupy as cp
import numpy as np
import jax.numpy as jnp
from cuda.core.experimental.utils import args_viewable_as_strided_memory


@args_viewable_as_strided_memory((0,))
def parse_tensor(arr):
    view = arr.view(-1)
    print(type(arr), type(view))
    print(f"shape={view.shape}")
    print(f"strides={view.strides}")


for module in (np, cp, jnp):
    arr = module.eye(2)
    print(f"module={module.__name__}")
    parse_tensor(arr)
Error Log:
E1211 13:52:21.674104 2256128 ptx_compiler_helpers.cc:71] *** WARNING *** Invoking ptxas with version 12.2.140, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x.y up to and including 12.6.2 miscompile certain edge cases around clamping.
Please upgrade to CUDA 12.6.3 or newer.
module=jax.numpy
Traceback (most recent call last):
  File "/home/scratch.yangg_sw/software/cuda-python/cuda_core/examples/tmp.py", line 18, in <module>
    parse_tensor(arr)
  File "cuda/core/experimental/_memoryview.pyx", line 372, in cuda.core.experimental._memoryview.args_viewable_as_strided_memory.wrapped_func_with_indices.wrapped_func
  File "/home/scratch.yangg_sw/software/cuda-python/cuda_core/examples/tmp.py", line 10, in parse_tensor
    view = arr.view(-1)
           ^^^^^^^^^^^^
  File "cuda/core/experimental/_memoryview.pyx", line 146, in cuda.core.experimental._memoryview._StridedMemoryViewProxy.view
  File "cuda/core/experimental/_memoryview.pyx", line 148, in cuda.core.experimental._memoryview._StridedMemoryViewProxy.view
  File "cuda/core/experimental/_memoryview.pyx", line 180, in cuda.core.experimental._memoryview.view_as_dlpack
  File "/home/Self/marie/miniconda3/envs/jax/lib/python3.12/site-packages/jax/_src/array.py", line 446, in __dlpack__
    return to_dlpack(self, stream=stream,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/Self/marie/miniconda3/envs/jax/lib/python3.12/site-packages/jax/_src/dlpack.py", line 134, in to_dlpack
    return _to_dlpack(
           ^^^^^^^^^^^
  File "/home/Self/marie/miniconda3/envs/jax/lib/python3.12/site-packages/jax/_src/dlpack.py", line 65, in _to_dlpack
    return xla_client._xla.buffer_to_dlpack_managed_tensor(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CUDA error: : CUDA_ERROR_INVALID_HANDLE: invalid resource handle
Is this a semantic error on my side, or a CUDA version issue?
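Since the traceback ends inside JAX's own `__dlpack__`, one way to narrow this down might be to call the DLPack protocol methods on each array directly, without going through cuda.core. The sketch below does that for `stream=None` and `stream=-1`. It assumes (not confirmed here) that the `-1` given to `view()` is what ends up as the `stream` argument of `__dlpack__`. `probe_dlpack` is a hypothetical helper written for this report, not part of any library.

```python
import importlib


def probe_dlpack(arr, stream=None):
    """Call the DLPack protocol methods on `arr` directly and report the outcome.

    Hypothetical diagnostic helper; it only uses `__dlpack_device__` and
    `__dlpack__`, which are part of the DLPack data interchange protocol.
    """
    result = {"type": type(arr).__name__, "stream": stream}
    try:
        # (device_type, device_id); in the DLPack enum, 1 is CPU and 2 is CUDA.
        result["device"] = tuple(arr.__dlpack_device__())
        # Hold the capsule only long enough to confirm that the export succeeded.
        capsule = arr.__dlpack__(stream=stream)
        del capsule
        result["ok"] = True
    except Exception as exc:
        result["ok"] = False
        result["error"] = f"{type(exc).__name__}: {exc}"
    return result


if __name__ == "__main__":
    for name in ("numpy", "cupy", "jax.numpy"):
        try:
            module = importlib.import_module(name)
            arr = module.eye(2)
        except Exception as exc:
            # Library missing or no usable device: skip it instead of aborting.
            print(f"module={name}: skipped ({type(exc).__name__})")
            continue
        # Assumption: None and -1 are the two stream values worth comparing.
        for stream in (None, -1):
            print(f"module={name}", probe_dlpack(arr, stream=stream))
```

If the JAX array fails here as well, the problem would seem to lie in the DLPack export (or in the stream value it receives) rather than in StridedMemoryView itself. NumPy is expected to reject any stream other than `None`, so a failure there for `stream=-1` is not informative.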