Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,7 @@ def __imod__(self, other: Array | float, /) -> Array:
"""
Performs the operation __imod__.
"""
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__imod__")
if other is NotImplemented:
return other
Expand All @@ -1126,6 +1127,7 @@ def __imul__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __imul__.
"""
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__imul__")
if other is NotImplemented:
return other
Expand All @@ -1148,6 +1150,7 @@ def __ior__(self, other: Array | int, /) -> Array:
"""
Performs the operation __ior__.
"""
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "integer or boolean", "__ior__")
if other is NotImplemented:
return other
Expand All @@ -1170,6 +1173,7 @@ def __ipow__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __ipow__.
"""
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__ipow__")
if other is NotImplemented:
return other
Expand All @@ -1182,6 +1186,7 @@ def __rpow__(self, other: Array | complex, /) -> Array:
"""
from ._elementwise_functions import pow # type: ignore[attr-defined]

self._check_type_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__rpow__")
if other is NotImplemented:
return other
Expand All @@ -1193,6 +1198,7 @@ def __irshift__(self, other: Array | int, /) -> Array:
"""
Performs the operation __irshift__.
"""
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "integer", "__irshift__")
if other is NotImplemented:
return other
Expand All @@ -1215,6 +1221,7 @@ def __isub__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __isub__.
"""
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__isub__")
if other is NotImplemented:
return other
Expand All @@ -1237,6 +1244,7 @@ def __itruediv__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __itruediv__.
"""
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "floating-point", "__itruediv__")
if other is NotImplemented:
return other
Expand All @@ -1259,6 +1267,7 @@ def __ixor__(self, other: Array | int, /) -> Array:
"""
Performs the operation __ixor__.
"""
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "integer or boolean", "__ixor__")
if other is NotImplemented:
return other
Expand Down
7 changes: 7 additions & 0 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,13 @@ def _array_vals():
getattr(x, _op)(y)
else:
assert_raises(TypeError, lambda: getattr(x, _op)(y))
# finally, test that array op ndarray raises
# XXX: as long as there is __array__ or __buffer__, __rop__s
# still return ndarrays
if not _op.startswith("__r"):
with assert_raises(TypeError):
getattr(x, _op)(y._array)


for op, dtypes in unary_op_dtypes.items():
for a in _array_vals():
Expand Down
Loading