diff --git a/src/array_api_extra/_lib/_at.py b/src/array_api_extra/_lib/_at.py
index fb2d6ab7..e2ee4fa0 100644
--- a/src/array_api_extra/_lib/_at.py
+++ b/src/array_api_extra/_lib/_at.py
@@ -11,8 +11,10 @@
 from ._utils import _compat
 from ._utils._compat import (
     array_namespace,
+    is_array_api_obj,
     is_dask_array,
     is_jax_array,
+    is_lazy_array,
     is_torch_array,
     is_writeable_array,
 )
@@ -73,7 +75,8 @@ class at:  # pylint: disable=invalid-name  # numpydoc ignore=PR02
     idx : index, optional
         Only `array API standard compliant indices
         <https://data-apis.org/array-api/latest/API_specification/indexing.html>`_
-        are supported.
+        are supported. The only exception is one-dimensional integer array indices
+        (not expressed as tuples) along the first axis for set() operations.
 
         You may use two alternate syntaxes::
 
@@ -148,6 +151,19 @@ class at:  # pylint: disable=invalid-name  # numpydoc ignore=PR02
     >>> xpx.at(jnp.asarray([123]), jnp.asarray([0, 0])).add(1)
     Array([125], dtype=int32)
 
+    The Array API standard does not support assignment by integer array, even if
+    many libraries like NumPy do. `xpx.at` works around this lack of support by
+    performing an out-of-place operation. Assignments with multiple occurrences of the
+    same index always choose the last occurrence, consistently with NumPy's behaviour.
+
+    >>> import numpy as np
+    >>> import array_api_strict as xp
+    >>> import array_api_extra as xpx
+    >>> xpx.at(np.asarray([0]), np.asarray([0, 0])).set(np.asarray([2, 3]))
+    array([3])
+    >>> xpx.at(xp.asarray([0]), xp.asarray([0, 0])).set(xp.asarray([2, 3]))
+    Array([3], dtype=array_api_strict.int64)
+
     See Also
     --------
     jax.numpy.ndarray.at : Equivalent array method in JAX.
@@ -355,9 +371,63 @@ def _op(
         # Backends without boolean indexing (other than JAX) crash here
         if in_place_op:  # add(), subtract(), ...
             x[idx] = in_place_op(x[idx], y)
-        else:  # set()
+            return x
+        # set()
+        try:  # We first try to use the backend's __setitem__ if available
             x[idx] = y
-        return x
+            return x
+        except IndexError as e:
+            if "Fancy indexing" not in str(e):  # Avoid masking other index errors
+                raise e
+            # Work around lack of fancy indexing __setitem__
+            if not (
+                is_array_api_obj(idx)
+                and xp.isdtype(idx.dtype, "integral")
+                and idx.ndim == 1
+            ):
+                raise
+            # Vectorize the operation using boolean indexing
+            # For non-unique indices, take the last occurrence. This requires
+            # masks for x and y that create matching shapes.
+            # We first create the mask for x
+            # Convert negative indices to positive, otherwise they won't get matched
+            idx = xp.where(idx < 0, x.shape[0] + idx, idx)
+            u_idx = xp.sort(xp.unique_values(idx))
+            # Check for out of bounds indices
+            if not is_lazy_array(u_idx) and (
+                xp.any(u_idx < 0) or xp.any(u_idx >= x.shape[0])
+            ):
+                msg = f"index or indices out of bounds for array of shape {x.shape}"
+                raise IndexError(msg) from e
+
+            # Construct a mask for x that is True where x's index is in u_idx.
+            # Equivalent to np.isin().
+            x_rng = xp.arange(x.shape[0], device=_compat.device(u_idx))
+            x_mask = xp.any(x_rng[..., None] == u_idx, axis=-1)
+            # If y is a scalar or 0D, we are done
+            if not is_array_api_obj(y) or y.ndim == 0:
+                x[x_mask] = y
+                return x
+            if y.shape[0] != idx.shape[0]:
+                msg = (
+                    f"shape mismatch: value array of shape {y.shape} could not be "
+                    f"broadcast to indexing result of shape {idx.shape}"
+                )
+                raise ValueError(msg) from e
+            # If not, create a mask for y. Get last occurrence of each unique index
+            cmp = idx[:, None] == idx[None, :]
+            total_matches = xp.sum(xp.astype(cmp, xp.int32), axis=-1)
+            # Ignore later matches
+            n = idx.shape[0]
+            lower_tri_mask = xp.tril(xp.ones((n, n), dtype=xp.bool))
+            masked_cmp = cmp & lower_tri_mask
+            # For each position i, count how many matches occurred before i
+            prior_matches = xp.sum(xp.astype(masked_cmp, xp.int32), axis=-1)
+            # Last occurrence has highest match count
+            y_mask = prior_matches == total_matches
+            # Apply to last occurrences only, reordered to ascending index order
+            x[x_mask] = xp.take(y[y_mask], xp.argsort(idx[y_mask]), axis=0)
+            return x
 
     def set(
         self,
diff --git a/tests/test_at.py b/tests/test_at.py
index 9558f7b8..1843f969 100644
--- a/tests/test_at.py
+++ b/tests/test_at.py
@@ -11,7 +11,13 @@
 from array_api_extra._lib._at import _AtOp
 from array_api_extra._lib._backends import Backend
 from array_api_extra._lib._testing import xp_assert_equal
-from array_api_extra._lib._utils._compat import array_namespace, is_writeable_array
+from array_api_extra._lib._utils._compat import (
+    array_namespace,
+    is_array_api_strict_namespace,
+    is_jax_namespace,
+    is_numpy_namespace,
+    is_writeable_array,
+)
 from array_api_extra._lib._utils._compat import device as get_device
 from array_api_extra._lib._utils._typing import Array, Device, SetIndex
 from array_api_extra.testing import lazy_xp_function
@@ -272,6 +278,64 @@ def test_bool_mask_nd(xp: ModuleType):
     xp_assert_equal(z, xp.asarray([[0, 2, 3], [4, 0, 0]]))
 
 
+def test_setitem_int_array_index(xp: ModuleType):
+    # Single dimension
+    x = xp.asarray([0.0, 1.0, 2.0])
+    y = xp.asarray([3.0, 4.0])
+    idx = xp.asarray([0, 2])
+    expect = xp.asarray([3.0, 1.0, 4.0])
+    z = at_op(x, idx, _AtOp.SET, y)
+    assert isinstance(z, type(x))
+    xp_assert_equal(z, expect)
+    # Single dimension, non-unique index
+    x = xp.asarray([0.0, 1.0, 2.0])
+    y = xp.asarray([3.0, 4.0, 5.0])
+    idx = xp.asarray([0, 1, 0])
+    device_str = str(get_device(x)).lower()
+    # GPU arrays generally use the first element, but JAX with float64 enabled uses
+    # the last element.
+    if ("gpu" in device_str or "cuda" in device_str) and not is_jax_namespace(xp):
+        expect = xp.asarray([3.0, 4.0, 2.0])
+    else:
+        expect = xp.asarray([5.0, 4.0, 2.0])  # CPU arrays use the last
+    z = at_op(x, idx, _AtOp.SET, y)
+    assert isinstance(z, type(x))
+    xp_assert_equal(z, expect)
+    # Multiple dimensions
+    x = xp.asarray([[0.0, 1.0], [2.0, 3.0]])
+    y = xp.asarray([[4.0, 5.0]])
+    idx = xp.asarray([0])
+    expect = xp.asarray([[4.0, 5.0], [2.0, 3.0]])
+    z = at_op(x, idx, _AtOp.SET, y)
+    xp_assert_equal(z, expect)
+    # Scalar
+    x = xp.asarray([0.0, 1.0])
+    z = at_op(x, xp.asarray([1]), _AtOp.SET, 2.0)
+    xp_assert_equal(z, xp.asarray([0.0, 2.0]))
+    # 0D array
+    x = xp.asarray([0.0, 1.0])
+    z = at_op(x, xp.asarray([1]), _AtOp.SET, xp.asarray(2.0))
+    xp_assert_equal(z, xp.asarray([0.0, 2.0]))
+    # Negative indices
+    x = xp.asarray([0.0, 1.0])
+    z = at_op(x, xp.asarray([-1]), _AtOp.SET, 2.0)
+    xp_assert_equal(z, xp.asarray([0.0, 2.0]))
+    # Different frameworks have all kinds of different behaviours for negative indices,
+    # out-of-bounds indices, etc. Therefore, we only test the behaviour of two
+    # frameworks: numpy because we state in the docs that it is our reference for the
+    # behaviour of other frameworks with no native support, and array-api-strict.
+    if is_array_api_strict_namespace(xp) or is_numpy_namespace(xp):
+        # Test wrong shapes
+        with pytest.raises(ValueError, match="shape"):
+            _ = at_op(xp.asarray([0]), xp.asarray([0]), _AtOp.SET, xp.asarray([1, 2]))
+        # Test positive out of bounds index
+        with pytest.raises(IndexError, match="out of bounds"):
+            _ = at_op(xp.asarray([0]), xp.asarray([1]), _AtOp.SET, xp.asarray([1]))
+        # Test negative out of bounds index
+        with pytest.raises(IndexError, match="out of bounds"):
+            _ = at_op(xp.asarray([0]), xp.asarray([-2]), _AtOp.SET, xp.asarray([1]))
+
+
 @pytest.mark.parametrize("bool_mask", [False, True])
 def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
     x = xp.asarray([math.inf, 1.0, 2.0])
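Note on the masking scheme in the _op() fallback above: the following is a minimal
standalone sketch, not part of the patch, written in plain NumPy (np.unique, np.tril,
np.argsort and np.take stand in for their array-API counterparts). It shows how the
boolean masks for x and y, plus the final reordering to ascending index order,
reproduce NumPy's native last-occurrence semantics for a non-unique integer index.

    import numpy as np

    x = np.asarray([0.0, 1.0, 2.0])
    y = np.asarray([3.0, 4.0, 5.0])
    idx = np.asarray([0, 1, 0])  # non-unique: index 0 appears twice

    # Mask for x: True where a position of x appears in idx (equivalent to np.isin)
    u_idx = np.sort(np.unique(idx))
    x_mask = np.any(np.arange(x.shape[0])[..., None] == u_idx, axis=-1)

    # Mask for y: keep only the last occurrence of each index value
    cmp = idx[:, None] == idx[None, :]
    total_matches = cmp.sum(axis=-1)
    prior_matches = (cmp & np.tril(np.ones((idx.shape[0],) * 2, dtype=bool))).sum(axis=-1)
    y_mask = prior_matches == total_matches

    # Scatter: reorder the surviving values to ascending index order, then assign
    out = x.copy()
    out[x_mask] = np.take(y[y_mask], np.argsort(idx[y_mask]), axis=0)
    print(out)  # [5. 4. 2.] -- the last occurrence wins

    ref = x.copy()
    ref[idx] = y  # native NumPy fancy assignment, for comparison
    assert np.array_equal(out, ref)

The trade-off of the pairwise comparison matrix is O(n^2) work in the length of the
index, but it keeps the fallback fully vectorised, with no data-dependent Python loops.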