ENH: `at`: add `__setitem__` fancy indexing fallback #395
base: main
Changes from all commits
7f42065
9df1c65
2c5e0aa
6fe2ffb
acd5365
07e6bda
@@ -11,8 +11,10 @@
from ._utils import _compat
from ._utils._compat import (
    array_namespace,
    is_array_api_obj,
    is_dask_array,
    is_jax_array,
    is_lazy_array,
    is_torch_array,
    is_writeable_array,
)
@@ -73,7 +75,8 @@ class at:  # pylint: disable=invalid-name  # numpydoc ignore=PR02
    idx : index, optional
        Only `array API standard compliant indices
        <https://data-apis.org/array-api/latest/API_specification/indexing.html>`_
        are supported.
        are supported. The only exception are one-dimensional integer array indices
        (not expressed as tuples) along the first axis for set() operations.

        You may use two alternate syntaxes::

@@ -148,6 +151,19 @@ class at:  # pylint: disable=invalid-name  # numpydoc ignore=PR02
    >>> xpx.at(jnp.asarray([123]), jnp.asarray([0, 0])).add(1)
    Array([125], dtype=int32)

    The Array API standard does not support assignment by integer array, even if many
    libraries like NumPy do. `xpx.at` works around lack of support by performing an
    out-of-place operation. Assignments with multiple occurrences of the same index
    always choose the last occurrence. This is consistent with NumPy's behaviour.

    >>> import numpy as np
    >>> import array_api_strict as xp
    >>> import array_api_extra as xpx
    >>> xpx.at(np.asarray([0]), np.asarray([0, 0])).set(np.asarray([2, 3]))
    array([3])
    >>> xpx.at(xp.asarray([0]), xp.asarray([0, 0])).set(xp.asarray([2, 3]))
    Array([3], dtype=array_api_strict.int64)

    See Also
    --------
    jax.numpy.ndarray.at : Equivalent array method in JAX.

@@ -355,9 +371,63 @@ def _op(
        # Backends without boolean indexing (other than JAX) crash here
        if in_place_op:  # add(), subtract(), ...
            x[idx] = in_place_op(x[idx], y)
Review comment: These remain broken.

Reply: Yes. I'm not sure if we should attempt to fix them. See my general comment.

        else:  # set()
            return x
        # set()
        try:  # We first try to use the backend's __setitem__ if available
            x[idx] = y
        return x
            return x
        except IndexError as e:
            if "Fancy indexing" not in str(e):  # Avoid masking other index errors
                raise e
Review comment on lines +379 to +381: This is quite fragile as it cherry-picks array-api-strict's behaviour. Different libraries would have different error messages and different exceptions.

Reply: Yes, I thought about this as well. However, I am strongly opposed to a bare except. We would mask errors for regular frameworks, which would then enter an unexpected code path that may throw obscure errors. Hence the comment on masking other index errors. This feels almost worse than the added benefit of having array-api-strict support for integer indexing.
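
The trade-off discussed in this thread can be illustrated with a small, self-contained sketch. Everything in it is hypothetical: `StrictLikeArray` is a stand-in for a backend that rejects integer-array assignment, its error message is chosen only to mirror the string check in the diff, and `set_with_fallback` and `elementwise_fallback` are illustrative helpers, not part of array-api-extra. The point is only that filtering on the message lets unrelated `IndexError`s propagate, at the cost of depending on one library's wording.

```python
class StrictLikeArray:
    """Hypothetical backend that refuses assignment by a list of indices."""

    def __init__(self, data):
        self.data = list(data)

    def __setitem__(self, idx, value):
        if isinstance(idx, list):
            # Assumed message text, picked to match the check in the diff
            raise IndexError("Fancy indexing is not supported")
        # A plain integer index may still raise a genuine IndexError
        self.data[idx] = value


def elementwise_fallback(x, idx, value):
    """Hypothetical fallback: assign one element at a time, last write wins."""
    for i, v in zip(idx, value):
        x.data[i] = v


def set_with_fallback(x, idx, value, fallback):
    """Try the native __setitem__; use `fallback` only for the known error."""
    try:
        x[idx] = value
    except IndexError as e:
        if "Fancy indexing" not in str(e):
            raise  # unrelated index errors propagate unchanged
        fallback(x, idx, value)
    return x


x = StrictLikeArray([0, 0, 0])
set_with_fallback(x, [0, 2], [5, 7], elementwise_fallback)
print(x.data)
```

A backend that signals the same condition with a different message or exception type would skip the fallback entirely, which is the fragility the reviewer points out; a bare `except` would instead route every failure, including a genuine out-of-range index, into the fallback.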

            # Work around lack of fancy indexing __setitem__
            if not (
                is_array_api_obj(idx)
                and xp.isdtype(idx.dtype, "integral")
                and idx.ndim == 1
            ):
                raise
            # Vectorize the operation using boolean indexing
            # For non-unique indices, take the last occurrence. This requires
            # masks for x and y that create matching shapes.
            # We first create the mask for x
            # Convert negative indices to positive, otherwise they won't get matched
            idx = xp.where(idx < 0, x.shape[0] + idx, idx)
            u_idx = xp.sort(xp.unique_values(idx))
            # Check for out of bounds indices
            if not is_lazy_array(u_idx) and (
                xp.any(u_idx < 0) or xp.any(u_idx >= x.shape[0])
            ):
                msg = f"index or indices out of bounds for array of shape {x.shape}"
                raise IndexError(msg) from e

            # Construct a mask for x that is True where x's index is in u_idx.
            # Equivalent to np.isin().
            x_rng = xp.arange(x.shape[0], device=_compat.device(u_idx))
            x_mask = xp.any(x_rng[..., None] == u_idx, axis=-1)
            # If y is a scalar or 0D, we are done
            if not is_array_api_obj(y) or y.ndim == 0:
                x[x_mask] = y
                return x
            if y.shape[0] != idx.shape[0]:
                msg = (
                    f"shape mismatch: value array of shape {y.shape} could not be "
                    f"broadcast to indexing result of shape {idx.shape}"
                )
                raise ValueError(msg) from e
            # If not, create a mask for y. Get last occurrence of each unique index
            cmp = idx[:, None] == idx[None, :]
            total_matches = xp.sum(xp.astype(cmp, xp.int32), axis=-1)
            # Ignore later matches
            n = idx.shape[0]
            lower_tri_mask = xp.tril(xp.ones((n, n), dtype=xp.bool))
            masked_cmp = cmp & lower_tri_mask
            # For each position i, count how many matches occurred before i
            prior_matches = xp.sum(xp.astype(masked_cmp, xp.int32), axis=-1)
            # Last occurrence has highest match count
            y_mask = prior_matches == total_matches
            # Apply the operation only to last occurrences
            x[x_mask] = y[y_mask]
            return x

    def set(
        self,
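
The masking steps in the fallback above can be traced with plain Python lists. The following is a minimal sketch under stated assumptions, not the PR's implementation: `isin_mask`, `last_occurrence_mask`, and `set_last_wins` are hypothetical names, and lists stand in for arrays so the block runs without any array library. It differs from the diff in one deliberate respect: each kept value is written to its own index, because a boolean-mask assignment fills the selected positions in ascending order, which, as far as I can tell, lines up with `y[y_mask]` only when the surviving indices already appear in ascending order in `idx`.

```python
def isin_mask(n_x, idx):
    """True at position j when j is targeted by some index (like np.isin)."""
    return [any(j == i for i in idx) for j in range(n_x)]


def last_occurrence_mask(idx):
    """True at position i when idx[i] does not appear again later in idx."""
    n = len(idx)
    # cmp[i][k] is True where idx[i] == idx[k]
    cmp = [[idx[i] == idx[k] for k in range(n)] for i in range(n)]
    total_matches = [sum(row) for row in cmp]
    # Lower triangle (k <= i): matches at or before position i
    prior_matches = [sum(cmp[i][k] for k in range(i + 1)) for i in range(n)]
    # Last occurrence: every match has been seen by position i
    return [p == t for p, t in zip(prior_matches, total_matches)]


def set_last_wins(x, idx, y):
    """Out-of-place x[idx] = y along the first axis; last duplicate wins."""
    n_x = len(x)
    # Convert negative indices to positive so they can be matched
    idx = [i + n_x if i < 0 else i for i in idx]
    if any(i < 0 or i >= n_x for i in idx):
        raise IndexError(f"index or indices out of bounds for length {n_x}")
    if len(y) != len(idx):
        raise ValueError("shape mismatch between value and index arrays")
    out = list(x)
    for i, keep in enumerate(last_occurrence_mask(idx)):
        if keep:
            # Pair each kept value with its own index, not with mask order
            out[idx[i]] = y[i]
    return out


print(isin_mask(3, [2, 0]))
print(last_occurrence_mask([0, 1, 0, 1, 2]))
print(set_last_wins([0], [0, 0], [2, 3]))
```

The last call mirrors the docstring example in the diff, where index `0` occurs twice and the second value is the one that is kept. Like the code in the diff, the comparison matrix costs quadratic memory in the number of indices.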
Review comment: This example leaves me confused. I don't think it adds anything?

Reply: The aim is to show that np's and xpx's behavior is identical. For torch tensors on the GPU you would see