data-apis / array-api

RFC document, tooling and other content related to the array API standard
https://data-apis.github.io/array-api/latest/
MIT License

Proposal: add APIs for getting and setting elements via a list of indices (i.e., `take`, `put`, etc) #177

Open kgryte opened 3 years ago

kgryte commented 3 years ago

Proposal

Add APIs for getting and setting elements via a list of indices.

Motivation

Currently, the array API specification does not provide a direct means of extracting and setting a list of elements along an axis. Such operations are relatively common in NumPy usage either via "fancy indexing" or via explicit take and put APIs.

Two main arguments come to mind for supporting at least basic take and put APIs:

  1. Indexing does not currently support providing a list of indices to index into an array. The principal reason for not supporting fancy indexing stems from dynamic shapes and compatibility with accelerator libraries. However, use of fancy indexing is relatively common in NumPy and similar libraries where dynamically extracting rows/cols/values is possible and can be readily implemented.

  2. The output of a subset of APIs currently included in the standard cannot be readily consumed without manual workarounds if a specification-conforming library implemented only the APIs in the standard. For example,

    • argsort returns an array of indices. In NumPy, the output of this function can be consumed by put_along_axis and take_along_axis.
    • unique can return an array of indices if return_index is True.
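For illustration, a minimal sketch of how index-returning functions pair with take. It uses plain NumPy rather than a spec-conforming namespace, and the array values are made up for the example.

```python
import numpy as np

x = np.asarray([3, 1, 2])

# argsort returns indices; take consumes them to produce the sorted values.
order = np.argsort(x)
sorted_x = np.take(x, order)

# unique(return_index=True) returns the indices of first occurrences;
# take consumes them to recover the unique values.
y = np.asarray([5, 3, 5, 1, 3])
values, first_idx = np.unique(y, return_index=True)
recovered = np.take(y, first_idx)
```

Without take (or integer array indexing), both results would have to be assembled element by element.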

Background

The following table summarizes library implementations of such APIs:

| op | NumPy | CuPy | Dask | MXNet | Torch | TensorFlow |
| --- | --- | --- | --- | --- | --- | --- |
| extracting elements along axis | take | take | take | take | take/gather | gather/numpy.take |
| setting elements along axis | put | put | -- | -- | scatter | scatter_nd/tensor_scatter_nd_update |
| extracting elements over matching 1d slices | take_along_axis | take_along_axis | -- | -- | -- | gather_nd/numpy.take_along_axis |
| setting elements over matching 1d slices | put_along_axis | -- | -- | -- | -- | -- |

While most libraries implement some form of take, fewer implement other complementary APIs.

rgommers commented 3 years ago

Thanks @kgryte. A few initial thoughts:

kgryte commented 3 years ago

@rgommers Thanks for the comments.

  1. Correct. I've updated the table with Torch and TF scatter and gather methods.
  2. Correct me if I am wrong, but indices.size need not be fixed and could be data-dependent. For example, if we extract the indices of unique elements from an array, the number of indices cannot necessarily be known AOT.
  3. Not opposed to delaying until V2 (2022).

asmeurer commented 3 years ago

A natural question is if take is supported, is there any reason equivalent indexing shouldn't also be supported. Granted, take only represents a specific subset of general (NumPy) integer array indexing, where indexing is done on a single axis.

kgryte commented 3 years ago

@asmeurer I think that take would be an optional API; whereas indexing semantics should be universal.

rgommers commented 3 years ago

Correct me if I am wrong, but indices.size need not be fixed and could be data-dependent. For example, if we extract the indices of unique elements from an array, the number of indices cannot necessarily be known AOT.

If the size of indices is variable, it's the function that produces indices that is data-dependent. take itself, however, is not. Compare with boolean indexing or nonzero: there the output size is in the range [0, x_input.size], whereas for take it's always indices.size.

kgryte commented 3 years ago

@rgommers Correct; however, I still could imagine that data flows involving a take operation may still be problematic for AOT computational graphs. While the output size is indices.size, an array library may not be able to statically allocate memory for the output of the take operation. This said, accelerator libraries do manage to support similar APIs (e.g., scatter/gather), so probably no need to further belabor this.
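A small NumPy sketch of the distinction being discussed, with made-up values: the output shape of take follows from the shape of indices, whereas boolean indexing and nonzero produce outputs whose size depends on the data.

```python
import numpy as np

x = np.asarray([10, 20, 30, 40])
indices = np.asarray([3, 0, 0])

# take: the output shape is determined by the shape of indices alone.
taken = np.take(x, indices)

# Boolean indexing and nonzero: the output size depends on the values in x.
masked = x[x > 25]
(nz,) = np.nonzero(x > 25)
```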

kgryte commented 3 years ago

@asmeurer Re: integer array indexing. As mentioned during the previous call (03/06/2021), similar to boolean array indexing, we could support a limited form of integer array indexing, where the integer array index is the sole index. Meaning, the spec would not condone mixing boolean with integer or require broadcasting semantics among the various indices.

kgryte commented 3 years ago

Cross-linking to a discussion regarding issues concerning out-of-bounds access in take APIs for accelerator libraries.

thomasjpfan commented 2 years ago

In the ML use case, it is common to want to sample with replacement or shuffle a dataset. This is commonly done by sampling an integer array and using it to subset the dataset:

import numpy.array_api as xp

X = xp.asarray([[1, 2, 3, 4], [2, 3, 4, 5],
                [4, 5, 6, 10], [5, 6, 8, 20]], dtype=xp.float64)

sample_indices = xp.asarray([0, 0, 1, 3])

# Does not work
# X[sample_indices, :]

For libraries that need selection with integer arrays, a work around is to implement take:

def take(X, indices, *, axis):
    # Simple implementation that only works for axis in {0, 1}
    if axis == 0:
        selected = [X[i] for i in indices]
    else:  # axis == 1
        selected = [X[:, i] for i in indices]
    return xp.stack(selected, axis=axis)

take(X, sample_indices, axis=0)

Note that sampling with replacement cannot be done with a boolean mask, because some rows may be selected twice.
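A minimal sketch of that last point, using plain NumPy and made-up data: integer indices preserve how often a row is selected, while a boolean mask can only record whether a row is selected at all.

```python
import numpy as np

X = np.asarray([[1, 2], [3, 4], [5, 6]])
sample_indices = np.asarray([0, 0, 2])

# take keeps the duplicate: row 0 appears twice in the sample.
sampled = np.take(X, sample_indices, axis=0)

# A boolean mask only encodes "include row i or not", so the
# multiplicity of row 0 is lost.
mask = np.zeros(X.shape[0], dtype=bool)
mask[sample_indices] = True
masked = X[mask]
```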

leofang commented 2 years ago

Hi @kmaehashi @asi1024 @emcastillo FYI. In a recent array API call we discussed the proposed take/put APIs, and there were questions regarding how CuPy currently implements these functions, as there could be data/value dependency and people were wondering if we just have to pay the synchronization cost to ensure the behavior is correct. Could you help address? Thanks! (And sorry I dropped the ball here...)

shoyer commented 2 years ago

@asmeurer Re: integer array indexing. As mentioned during the previous call (03/06/2021), similar to boolean array indexing, we could support a limited form of integer array indexing, where the integer array index is the sole index. Meaning, the spec would not condone mixing boolean with integer or require broadcasting semantics among the various indices.

+1 I think "array only" integer indexing would be quite well defined, and would not be problematic for accelerators. The main challenge with NumPy's implementation of "advanced indexing" is handling mixed integer/slice/boolean cases.

rgommers commented 2 years ago

Here is a summary of today's discussion:

Given all that, the proposal is to only add take now, and revisit integer array indexing and put in the future.

asmeurer commented 2 years ago

Something that I think was missed in today's discussion is that take and put aren't exactly the same as integer array indexing. Integer array indices operate on the axes of the array. take and put (at least in NumPy) operate on the flattened array.

>>> a = np.arange(9).reshape((3, 3)) + 10
>>> a[np.array([0, 2]), np.array([1, 2])]
array([11, 18])
>>> np.ravel_multi_index((np.array([0, 2]), np.array([1, 2])), (3, 3))
array([1, 8])
>>> np.take(a, np.ravel_multi_index((np.array([0, 2]), np.array([1, 2])), (3, 3)))
array([11, 18])

np.take also has an axis parameter but that's only equivalent to a single integer array index.
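For illustration, a short NumPy sketch of the axis form next to the flattened form, reusing the array from the session above:

```python
import numpy as np

a = np.arange(9).reshape((3, 3)) + 10
idx = np.asarray([0, 2])

# With axis, take matches a single integer array index placed on that axis.
cols = np.take(a, idx, axis=1)
same_cols = a[:, idx]

# Without axis, take indexes the flattened array instead.
flat = np.take(a, idx)
```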

I'm not sure if there's an easy way within the array API to go from one to the other.

And I hope the "integer array as the sole index" idea above was really meant to be "integer arrays as the sole indices". Just having a single integer array index means you can only index the first dimension of the array. This should also include integer scalars, as those are equivalent to 0-D arrays, unless we want to omit the "all integer array indices are broadcast together" rule.

I agree that NumPy's rules for mixing arrays with slices should not be included, especially the crazy rule about how it handles slices between integer array indices, which is a design mistake in NumPy (slices around integer array indices aren't so bad, and can be useful, but also add complexity to the indexing rules so I can see wanting to omit them).
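A small sketch of the rule in question, assuming NumPy's current advanced-indexing behaviour and an arbitrary 4-D array: when integer array indices are adjacent, the broadcast dimension stays where the indices were, but when a slice separates them it moves to the front of the result.

```python
import numpy as np

a = np.zeros((5, 6, 7, 8))
i = np.asarray([0, 1])

# Adjacent integer array indices (axes 1 and 2): the broadcast
# dimension of length 2 replaces those axes in place.
adjacent = a[:, i, i, :]

# Integer array indices separated by a slice (axes 1 and 3): the
# broadcast dimension moves to the front of the result.
separated = a[:, i, :, i]
```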

shoyer commented 2 years ago

It's definitely possible (but not necessarily easy) to rewrite every call to np.take in terms of __getitem__ with integer arrays. For a library like Xarray, support for all integer indexing (especially with broadcasting) would be sufficient. So from my perspective, support for all integer indexing in __getitem__ and possibly also __setitem__ would be the most useful functionality.

I would not be opposed to adding take if there is interest. It certainly is easier to construct calls to take, and knowing ahead of time that indexing is only going along a certain axis can sometimes allow for significant simplifications to indexing code. There are two alternatives we could consider for filling this same niche (easy integer based indexing along one dimension):

  1. Support for mixed array/slice indexing, like NumPy. But like Aaron says, this is too confusing for the API standard.
  2. We could include oindex, but this proposal never got entirely off the ground (beyond implementations in Xarray/Dask/Zarr).

If we do choose to include take in the standard, the axis argument should be required. Slicing along flattened arrays is not very useful.

asmeurer commented 2 years ago

It's definitely possible (but not necessarily easy) to rewrite every call to np.take in terms of __getitem__ with integer arrays. For a library like Xarray, support for all integer indexing (especially with broadcasting) would be sufficient. So from my perspective, support for all integer indexing in __getitem__ and possibly also __setitem__ would be the most useful functionality.

The suggestion here is to support take but defer support for indexing. So users of the array API would need to rewrite usages of __getitem__ to take, not the other way around.

Slicing along flattened arrays is not very useful.

I've never really used take myself, so I don't have the best context here, but isn't the flattened behavior there to match put, which doesn't have axis?

shoyer commented 2 years ago

What is the concern with supporting integer array indexing in __setitem__? Just the fact that it may not be implemented in otherwise compliant array libraries?

rgommers commented 2 years ago

What is the concern with supporting integer array indexing in __setitem__? Just the fact that it may not be implemented in otherwise compliant array libraries?

That, and also that it's non-deterministic when indices are not unique, as noted in the PyTorch and TF docs on scatter/scatter_nd.
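To illustrate why duplicate indices are awkward to specify, here is a minimal NumPy sketch with made-up values. It sidesteps the ordering question by writing the same value for each duplicate, and contrasts buffered in-place addition with unbuffered accumulation via np.add.at.

```python
import numpy as np

idx = np.asarray([0, 0, 2])

# Buffered in-place add: a[idx] + 1 is computed first and then assigned,
# so each duplicate writes the same value and index 0 grows only once.
a = np.zeros(3, dtype=np.int64)
a[idx] += 1

# Unbuffered accumulation: every occurrence of an index contributes.
b = np.zeros(3, dtype=np.int64)
np.add.at(b, idx, 1)
```

When the duplicates carry different values, which write wins depends on iteration order, and that is the part a parallel implementation cannot easily guarantee.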

shoyer commented 2 years ago

That, and also that it's non-deterministic when indices are not unique, as noted in the PyTorch and TF docs on scatter/scatter_nd.

I think we could probably safely leave this as undefined behavior?

rgommers commented 2 years ago

I think we could probably safely leave this as undefined behavior?

Yes, fair enough.

Let me add another concern though, probably the main one (copied from higher up, with a minor edit: put --> __setitem__): Having a better handle on the topic of mutability looks like a hard requirement before even considering an in-place function like __setitem__.

I think my preferred order of doing things here would be:

  1. Add take with 1-D integer array indices now (see gh-416)
  2. Tighten up mutability specification
  3. Add __getitem__ and __setitem__ (with n-D integer inputs, assuming that behavior aligns across libraries).

arogozhnikov commented 2 years ago

Very interested in having at least a basic version of take to be incorporated into the standard.

Context: experimental version of more verbose indexing, see https://github.com/arogozhnikov/einops/issues/194 for details

rgommers commented 2 years ago

Thanks for the ping on this issue @arogozhnikov - and nice to see the experimental work on indexing in einops. I'd like to see gh-416 finished and merged in the coming days to indeed add take support with 1-D indices.

rgommers commented 1 year ago

Support for take has been merged, see gh-416.

lezcano commented 1 year ago

nit. we also have put_ in PyTorch (but not put...)

honno commented 1 year ago

I think if we were to introduce something like xp.put(x, indices, value) to the spec, we can seemingly all agree on:

  1. Only specifying a single array as the indices argument like we do with xp.take(), leaving other kinds of indices out-of-scope.

    e.g. xp.put(x, indices, np.asarray([42, 7])) with x=xp.arange(5) and indices=xp.asarray([1, 4]) would be supported, but the following equivalent arguments would be out-of-scope:

    indices=(np.asarray(1), np.asarray(4))
    indices=(1, 4)
    indices=(1, np.asarray(4))
    indices=[1, 4]

    Array only indexing makes adoption easier and doesn't cause problems for accelerators.

  2. Keeping in-line with xp.take(), we should specify to only support indices as 1 dimensional.

    • Implicitly I'd be mandating that elements in indices refer to positions in the flattened input array, rather than any fancy broadcasting behaviours/etc.

      e.g. in contrast, t.index_put_() uses the shape of indices to specify multiple elements of the input array.

      >>> t = torch.as_tensor([[True, True]])
      >>> t.index_put_((torch.as_tensor([0]),), torch.as_tensor(False))
      tensor([[False, False]])

  3. Only specifying support for unique indices, e.g. indices=xp.asarray([0, 0]) would be out-of-scope. Consistent duplicate indices behaviour seems too niche and finicky to specify.

    • Interestingly PyTorch has the accumulate keyword for its t.put_()/t.index_put_() methods, where accumulate=False (default) leaves such behaviour unspecified, and accumulate=True puts the sum of the respective values.

The question areas then would be

  1. Should we support broadcasting value to the shape of indices?

    • np.put() broadcasts(?) the value to the indices, e.g.

      >>> a = np.asarray([True, True])
      >>> np.put(a, np.asarray([0, 1]), np.asarray(False))
      >>> a
      array([False, False])

    • On the contrary, torch.put_() requires the indices (index) to share the same shape as the value (source).

    As broadcasting is convenient and very common throughout the spec, IMO I'd specify value can be broadcasted to indices.

    Regardless I think we'd disallow broadcast-incompatible shapes, and value.size > indices.size scenarios.

  2. Should xp.put() return an array? What are the expectations for in-place and out-of-place behaviour?

    • The spec currently always(?) returns arrays for its functions, which seems a nice cadence to maintain.
    • Notably NumPy for its top-level np.put() currently only acts on the array in-place and does not return the modified array, whereas PyTorch has its put-like functions/methods return the modified array (some functions/methods also acting in-place).

    If we mandate xp.put() is to return the modified input array, we could leave in-place behaviour out-of-scope, or slap a copy keyword like we do for xp.asarray() and xp.reshape().

    Worth noting that the name put() also suggests in-place behaviour at this point.
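As a rough sketch of the "return the modified array" option, the following wraps NumPy's in-place np.put in a copying helper. The name put_copy is hypothetical and not part of NumPy or the spec; it only illustrates out-of-place semantics with 1-D indices into the flattened array.

```python
import numpy as np


def put_copy(x, indices, values):
    """Hypothetical out-of-place put: returns a new array, leaves x untouched."""
    out = np.array(x, copy=True)
    # np.put writes into the flattened view of `out` in place and returns None.
    np.put(out, indices, values)
    return out


x = np.arange(5)
y = put_copy(x, np.asarray([1, 4]), np.asarray([42, 7]))
```

A copy keyword, as suggested above, would amount to letting the caller choose between this behaviour and writing into x directly.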

lezcano commented 1 year ago

Some thoughts on @honno's points

  1. I think that's because all those objects are array_likes in NumPy. The same happens for any function that accepts an array in NumPy.
  2. SGTM but perhaps extending it to "indices of dimension at most 1". I believe PyTorch accepts wlog any contiguous array of any size, but I've never seen it used with anything but arrays of dim 0 or 1
  3. Checking for repeated values is indeed too costly. I think it should be left unspecified.
  4. Either SGTM
  5. In PyTorch we return an array for in-place ops, following the C++ convention. This is sometimes used to be able to chain in-place ops. I think it'd be good to mandate this all throughout the API.

arogozhnikov commented 1 year ago

  1. you discussed broadcasting values to indices. What about broadcasting indices to values? (case of specific interest to me)

>>> a = np.asarray([[True, True]])
>>> np.put(a, np.asarray([0]), np.asarray([[False, True]]))
>>> a
array([[False,  True]])

lezcano commented 1 year ago

Broadcasting indices should give, by definition, repeated indices, which should invoke UB (which value is written to the index 0 in your example?).

arogozhnikov commented 1 year ago

no, see my example above. values consist of one row, and index specifies that first row of values should be assigned to first row of result. I am not sure this is strictly the case of broadcasting, but that's a common thing to do.

Compare with:

matrix_n_by_n[[1, 2, 6]] = matrix_3_by_n

asmeurer commented 1 year ago

We should clarify in the spec that behavior on out-of-bounds indices is unspecified. The take spec currently doesn't say anything about this (I'm assuming this is behavior we want since we already say this for basic integer indexing).
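For context, a short sketch of how NumPy alone already offers several out-of-bounds behaviours through the mode parameter of np.take. Other libraries may behave differently, which is part of the case for leaving this unspecified.

```python
import numpy as np

x = np.asarray([10, 20, 30])
bad = np.asarray([0, 4])  # index 4 is out of bounds for size 3

# Default mode="raise": out-of-bounds indices raise IndexError.
try:
    np.take(x, bad)
    raised = False
except IndexError:
    raised = True

# The other modes silently remap the index instead.
wrapped = np.take(x, bad, mode="wrap")  # 4 wraps around to 4 % 3 == 1
clipped = np.take(x, bad, mode="clip")  # 4 is clipped to the last index, 2
```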

rgommers commented 1 year ago

I had a look at implementations in libraries, some updates on what's in the issue description:

def put(*args, **kwargs):
  raise NotImplementedError(
    "jax.numpy.put is not implemented because JAX arrays cannot be modified in-place. "
    "For functional approaches to updating array values, see jax.numpy.ndarray.at: "
    "https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html.")

for similar reasons as it avoids other in-place APIs (xref design_topics/copies_views_and_mutation).

So I think we should consider the addition of put feasible in principle but blocked right now. The JAX issue is most difficult to resolve (can be done, but a lot of work still to deal with read-only views or similar), but the lack of API uniformity makes this a hard sell in general.

lezcano commented 1 year ago

@arogozhnikov I think your example doesn't do what you think it does. Consider

>>> x = np.asarray([[0, 1]])
>>> np.put(x, [0], [[2, 3]])
>>> x
array([[2, 1]])

In this case, it's not that it's being broadcasted, but that np.put just considers the first ind.size elements of v. See https://github.com/numpy/numpy/blob/6073588dd73809a60819d71b9527194195f73f08/numpy/core/src/multiarray/item_selection.c#L439

In general, you are talking about "rows", but put just sees the array as a flat chunk of memory, so there is no concept of rows and columns for this function.
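A minimal NumPy sketch of the difference, with made-up values: np.put addresses positions in the flattened array, while an integer array index in __setitem__ selects along the first axis.

```python
import numpy as np

x = np.asarray([[0, 1], [2, 3]])
np.put(x, [2], [9])   # flat position 2 is element (1, 0)

y = np.asarray([[0, 1], [2, 3]])
y[[1]] = [8, 9]       # integer array index selects the whole of row 1
```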

arogozhnikov commented 1 year ago

indeed, for some reason I thought it was somewhat of a shortcut for x[ind] = val, but the docs say that it operates on the flat array. My bad!

kgryte commented 1 year ago

So I think we should consider the addition of put feasible in principle but blocked right now. The JAX issue is most difficult to resolve (can be done, but a lot of work still to deal with read-only views or similar), but the lack of API uniformity makes this a hard sell in general.

Given the above, I will go ahead and close this issue, as we are unlikely to make progress on put in the near term. This issue can be reopened and revisited once we have a better handle on a path forward.

mdhaber commented 4 months ago

Can this issue be reopened for the take_along_axis portion? As noted in https://github.com/data-apis/array-api/pull/416#issuecomment-1177897469, the functionality is different from take, and although it can be implemented in terms of take (postscript), I haven't found a trivial way. It also looks like there is broad support now - in addition to the implementations mentioned in the top post, there are jax.numpy.take_along_axis and torch.take_along_dim.
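To make the difference concrete, a small NumPy sketch with made-up values: take_along_axis pairs each 1-D slice of the index array with the matching slice of the input, whereas take applies a single 1-D index list to every slice.

```python
import numpy as np

x = np.asarray([[30, 10, 20],
                [5, 25, 15]])
order = np.argsort(x, axis=1)  # a different index list for each row

# take_along_axis pairs row i of `order` with row i of `x`.
per_row = np.take_along_axis(x, order, axis=1)

# take applies one 1-D index list (here, the first row's) to every row.
shared = np.take(x, order[0], axis=1)
```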


In case it is relevant, here is an array-API compatible version of take_along_axis I've been using.

```python3
import numpy as np
import array_api_strict
from array_api_compat import array_namespace

def xp_swapaxes(a, axis1, axis2, *, xp=None):
    xp = array_namespace(a) if xp is None else xp
    axes = list(range(a.ndim))
    axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
    a = xp.permute_dims(a, axes)
    return a

def xp_take_along_axis(arr, indices, axis, *, xp=None):
    xp = array_namespace(arr) if xp is None else xp
    arr = xp_swapaxes(arr, axis, -1, xp=xp)
    indices = xp_swapaxes(indices, axis, -1, xp=xp)

    m = arr.shape[-1]
    n = indices.shape[-1]

    shape = list(arr.shape)
    shape.pop(-1)
    shape = shape + [n,]

    arr = xp.reshape(arr, (-1,))
    indices = xp.reshape(indices, (-1, n))

    offset = (xp.arange(indices.shape[0]) * m)[:, xp.newaxis]
    indices = xp.reshape(offset + indices, (-1,))

    out = xp.take(arr, indices)
    out = xp.reshape(out, shape)
    return xp_swapaxes(out, axis, -1, xp=xp)

rng = np.random.default_rng()
x = rng.random(size=(1000, 1000))
xp = array_api_strict
x = xp.asarray(x)
j = xp.argsort(x, axis=-1)
res = xp_take_along_axis(x, j, axis=-1)
ref = xp.sort(x, axis=-1)
assert xp.all(res == ref)
```

shoyer commented 4 months ago

I would really like to see full integer coordinate-based indexing supported: https://github.com/data-apis/array-api/issues/669