Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 73 additions & 3 deletions src/array_api_extra/_lib/_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from ._utils import _compat
from ._utils._compat import (
array_namespace,
is_array_api_obj,
is_dask_array,
is_jax_array,
is_lazy_array,
is_torch_array,
is_writeable_array,
)
Expand Down Expand Up @@ -73,7 +75,8 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
idx : index, optional
Only `array API standard compliant indices
<https://data-apis.org/array-api/latest/API_specification/indexing.html>`_
are supported.
are supported. The only exception are one-dimensional integer array indices
(not expressed as tuples) along the first axis for set() operations.

You may use two alternate syntaxes::

Expand Down Expand Up @@ -148,6 +151,19 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
>>> xpx.at(jnp.asarray([123]), jnp.asarray([0, 0])).add(1)
Array([125], dtype=int32)

The Array API standard does not support assignment by integer array, even if many
libraries like NumPy do. `xpx.at` works around lack of support by performing an
out-of-place operation. Assignments with multiple occurrences of the same index
always choose the last occurrence. This is consistent with NumPy's behaviour.

>>> import numpy as np
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> xpx.at(np.asarray([0]), np.asarray([0, 0])).set(np.asarray([2, 3]))
array([3])
>>> xpx.at(xp.asarray([0]), xp.asarray([0, 0])).set(xp.asarray([2, 3]))
Array([3], dtype=array_api_strict.int64)
Comment on lines +159 to +165
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This example leaves me confused. I don't think it adds anything?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The aim is to show that np's and xpx's behavior is identical. For torch tensors on the GPU you would see

>>> xpx.at(torch.tensor([0]).cuda(), torch.tensor([0, 0]).cuda()).set(torch.tensor([2, 3]).cuda())
torch.Tensor([2], dtype=torch.int64)


See Also
--------
jax.numpy.ndarray.at : Equivalent array method in JAX.
Expand Down Expand Up @@ -355,9 +371,63 @@ def _op(
# Backends without boolean indexing (other than JAX) crash here
if in_place_op: # add(), subtract(), ...
x[idx] = in_place_op(x[idx], y)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These remain broken.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I'm not sure if we should attempt to fix them. See my general comment.

else: # set()
return x
# set()
try: # We first try to use the backend's __setitem__ if available
x[idx] = y
return x
return x
except IndexError as e:
if "Fancy indexing" not in str(e): # Avoid masking other index errors
raise e
Comment on lines +379 to +381
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is quite fragile as it cherry-picks array-api-strict's behaviour. Different libraries would have different error messages and different exceptions.

Suggested change
except IndexError as e:
if "Fancy indexing" not in str(e): # Avoid masking other index errors
raise e
except Exception as e:

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I thought about this as well. However, I am strongly opposed to a blank except. We would mask errors for regular frameworks that would subsequently enter an unexpected code path which may throw obscure errors. Hence the commend on masking other index errors. This feels almost worse than the added benefit of having array-api-strict support for integer indexing.

# Work around lack of fancy indexing __setitem__
if not (
is_array_api_obj(idx)
and xp.isdtype(idx.dtype, "integral")
and idx.ndim == 1
):
raise
# Vectorize the operation using boolean indexing
# For non-unique indices, take the last occurrence. This requires
# masks for x and y that create matching shapes.
# We first create the mask for x
# Convert negative indices to positive, otherwise they won't get matched
idx = xp.where(idx < 0, x.shape[0] + idx, idx)
u_idx = xp.sort(xp.unique_values(idx))
# Check for out of bounds indices
if not is_lazy_array(u_idx) and (
xp.any(u_idx < 0) or xp.any(u_idx >= x.shape[0])
):
msg = f"index or indices out of bounds for array of shape {x.shape}"
raise IndexError(msg) from e

# Construct a mask for x that is True where x's index is in u_idx.
# Equivalent to np.isin().
x_rng = xp.arange(x.shape[0], device=_compat.device(u_idx))
x_mask = xp.any(x_rng[..., None] == u_idx, axis=-1)
# If y is a scalar or 0D, we are done
if not is_array_api_obj(y) or y.ndim == 0:
x[x_mask] = y
return x
if y.shape[0] != idx.shape[0]:
msg = (
f"shape mismatch: value array of shape {y.shape} could not be "
f"broadcast to indexing result of shape {idx.shape}"
)
raise ValueError(msg) from e
# If not, create a mask for y. Get last occurrence of each unique index
cmp = idx[:, None] == idx[None, :]
total_matches = xp.sum(xp.astype(cmp, xp.int32), axis=-1)
# Ignore later matches
n = idx.shape[0]
lower_tri_mask = xp.tril(xp.ones((n, n), dtype=xp.bool))
masked_cmp = cmp & lower_tri_mask
# For each position i, count how many matches occurred before i
prior_matches = xp.sum(xp.astype(masked_cmp, xp.int32), axis=-1)
# Last occurrence has highest match count
y_mask = prior_matches == total_matches
# Apply the operation only to last occurrences
x[x_mask] = y[y_mask]
return x

def set(
self,
Expand Down
66 changes: 65 additions & 1 deletion tests/test_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
from array_api_extra._lib._at import _AtOp
from array_api_extra._lib._backends import Backend
from array_api_extra._lib._testing import xp_assert_equal
from array_api_extra._lib._utils._compat import array_namespace, is_writeable_array
from array_api_extra._lib._utils._compat import (
array_namespace,
is_array_api_strict_namespace,
is_jax_namespace,
is_numpy_namespace,
is_writeable_array,
)
from array_api_extra._lib._utils._compat import device as get_device
from array_api_extra._lib._utils._typing import Array, Device, SetIndex
from array_api_extra.testing import lazy_xp_function
Expand Down Expand Up @@ -272,6 +278,64 @@ def test_bool_mask_nd(xp: ModuleType):
xp_assert_equal(z, xp.asarray([[0, 2, 3], [4, 0, 0]]))


def test_setitem_int_array_index(xp: ModuleType):
# Single dimension
x = xp.asarray([0.0, 1.0, 2.0])
y = xp.asarray([3.0, 4.0])
idx = xp.asarray([0, 2])
expect = xp.asarray([3.0, 1.0, 4.0])
z = at_op(x, idx, _AtOp.SET, y)
assert isinstance(z, type(x))
xp_assert_equal(z, expect)
# Single dimension, non-unique index
x = xp.asarray([0.0, 1.0, 2.0])
y = xp.asarray([3.0, 4.0, 5.0])
idx = xp.asarray([0, 1, 0])
device_str = str(get_device(x)).lower()
# GPU arrays generally use the first element, but JAX with float64 enabled uses the
# last element.
if ("gpu" in device_str or "cuda" in device_str) and not is_jax_namespace(xp):
expect = xp.asarray([3.0, 4.0, 2.0])
else:
expect = xp.asarray([5.0, 4.0, 2.0]) # CPU arrays use the last
z = at_op(x, idx, _AtOp.SET, y)
assert isinstance(z, type(x))
xp_assert_equal(z, expect)
# Multiple dimensions
x = xp.asarray([[0.0, 1.0], [2.0, 3.0]])
y = xp.asarray([[4.0, 5.0]])
idx = xp.asarray([0])
expect = xp.asarray([[4.0, 5.0], [2.0, 3.0]])
z = at_op(x, idx, _AtOp.SET, y)
xp_assert_equal(z, expect)
# Scalar
x = xp.asarray([0.0, 1.0])
z = at_op(x, xp.asarray([1]), _AtOp.SET, 2.0)
xp_assert_equal(z, xp.asarray([0.0, 2.0]))
# 0D array
x = xp.asarray([0.0, 1.0])
z = at_op(x, xp.asarray([1]), _AtOp.SET, xp.asarray(2.0))
xp_assert_equal(z, xp.asarray([0.0, 2.0]))
# Negative indices
x = xp.asarray([0.0, 1.0])
z = at_op(x, xp.asarray([-1]), _AtOp.SET, 2.0)
xp_assert_equal(z, xp.asarray([0.0, 2.0]))
# Different frameworks have all kinds of different behaviours for negative indices,
# out-of-bounds indices, etc. Therefore, we only test the behaviour of two
# frameworks: numpy because we state in the docs that it is our reference for the
# behaviour of other frameworks with no native support, and array-api-strict.
if is_array_api_strict_namespace(xp) or is_numpy_namespace(xp):
# Test wrong shapes
with pytest.raises(ValueError, match="shape"):
_ = at_op(xp.asarray([0]), xp.asarray([0]), _AtOp.SET, xp.asarray([1, 2]))
# Test positive out of bounds index
with pytest.raises(IndexError, match="out of bounds"):
_ = at_op(xp.asarray([0]), xp.asarray([1]), _AtOp.SET, xp.asarray([1]))
# Test negative out of bounds index
with pytest.raises(IndexError, match="out of bounds"):
_ = at_op(xp.asarray([0]), xp.asarray([-2]), _AtOp.SET, xp.asarray([1]))


@pytest.mark.parametrize("bool_mask", [False, True])
def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
x = xp.asarray([math.inf, 1.0, 2.0])
Expand Down