|
16 | 16 | from __future__ import annotations
|
17 | 17 |
|
18 | 18 | import operator
|
| 19 | +import sys |
19 | 20 | from collections.abc import Iterator
|
20 | 21 | from enum import IntEnum
|
21 | 22 | from types import EllipsisType, ModuleType
|
@@ -67,8 +68,6 @@ def __hash__(self) -> int:
|
67 | 68 | CPU_DEVICE = Device()
|
68 | 69 | ALL_DEVICES = (CPU_DEVICE, Device("device1"), Device("device2"))
|
69 | 70 |
|
70 |
| -_default = object() |
71 |
| - |
72 | 71 |
|
73 | 72 | class Array:
|
74 | 73 | """
|
@@ -149,29 +148,40 @@ def __repr__(self) -> str:
|
149 | 148 |
|
150 | 149 | __str__ = __repr__
|
151 | 150 |
|
152 |
| - # `__array__` was implemented historically for compatibility, and removing it has |
153 |
| - # caused issues for some libraries (see |
154 |
| - # https://github.com/data-apis/array-api-strict/issues/67). |
155 |
| - |
156 |
| - # Instead of `__array__` we now implement the buffer protocol. |
157 |
| - # Note that it makes array-apis-strict requiring python>=3.12 |
158 | 151 | def __buffer__(self, flags):
|
159 | 152 | if self._device != CPU_DEVICE:
|
160 |
| - raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.") |
| 153 | + raise RuntimeError( |
| 154 | + # NumPy swallows this exception and falls back to __array__. |
| 155 | + f"Can't extract host buffer from array on the '{self._device}' device." |
| 156 | + ) |
161 | 157 | return self._array.__buffer__(flags)
|
162 | 158 |
|
163 |
| - # We do not define __release_buffer__, per the discussion at |
164 |
| - # https://github.com/data-apis/array-api-strict/pull/115#pullrequestreview-2917178729 |
165 |
| - |
166 |
| - def __array__(self, *args, **kwds): |
167 |
| - # a stub for python < 3.12; otherwise numpy silently produces object arrays |
168 |
| - import sys |
169 |
| - minor, major = sys.version_info.minor, sys.version_info.major |
170 |
| - if major < 3 or minor < 12: |
| 159 | + # `__array__` is not part of the Array API. Ideally we want to support |
| 160 | + # `xp.asarray(Array)` exclusively through the __buffer__ protocol; however this is |
| 161 | + # only possible on Python >=3.12. Additionally, when __buffer__ raises (e.g. because |
| 162 | + # the array is not on the CPU device, NumPy will try to fall back on __array__ but, |
| 163 | + # if that doesn't exist, create a scalar numpy array of objects which contains the |
| 164 | + # array_api_strict.Array. So we can't get rid of __array__ entirely. |
| 165 | + def __array__( |
| 166 | + self, dtype: None | np.dtype[Any] = None, copy: None | bool = None |
| 167 | + ) -> npt.NDArray[Any]: |
| 168 | + if self._device != CPU_DEVICE: |
| 169 | + # We arrive here from np.asarray() on Python >=3.12 when __buffer__ raises. |
| 170 | + raise RuntimeError( |
| 171 | + f"Can't convert array on the '{self._device}' device to a " |
| 172 | + "NumPy array." |
| 173 | + ) |
| 174 | + if sys.version_info >= (3, 12): |
171 | 175 | raise TypeError(
|
172 |
| - "Interoperation with NumPy requires python >= 3.12. Please upgrade." |
| 176 | + "The __array__ method is not supported by the Array API. " |
| 177 | + "Please use the __buffer__ interface instead." |
173 | 178 | )
|
174 | 179 |
|
| 180 | + # copy keyword is new in 2.0 |
| 181 | + if np.__version__[0] < '2': |
| 182 | + return np.asarray(self._array, dtype=dtype) |
| 183 | + return np.asarray(self._array, dtype=dtype, copy=copy) |
| 184 | + |
175 | 185 | # These are various helper functions to make the array behavior match the
|
176 | 186 | # spec in places where it either deviates from or is more strict than
|
177 | 187 | # NumPy behavior
|
|
0 commit comments