data-apis / array-api-compat

Compatibility layer for common array libraries to support the Array API
https://data-apis.org/array-api-compat/
MIT License

`torch` support indexing with negative step #144

Open mdhaber opened 3 months ago

mdhaber commented 3 months ago

The array API standard seems to support negative steps:

The basic slice syntax is i:j:k where i is the starting index, j is the stopping index, and k is the step (k != 0).

But tensors created through array_api_compat's torch namespace do not:

from array_api_compat import torch
x = torch.arange(10)
x[::-1]  # ValueError: step must be greater than zero

Adding support for negative steps would be appreciated! (In the meantime, I can use flip.) Thanks for considering it.

asmeurer commented 3 months ago

It would have to be done via a wrapper function, since we don't wrap the tensor objects. Does torch have a function that returns a reversed view of a tensor? Or maybe it could be done by manipulating the strides?

mdhaber commented 3 months ago

It doesn't look like it. One of the most recent requests for this feature is pytorch/pytorch#59786, which links to one of the very old requests, pytorch/pytorch#229. It looks like it's just not implemented. flip is the substitute, but it makes a copy.

asmeurer commented 2 months ago

Hmm. If you try to do this with strides, you get an error:

>>> a = torch.arange(10)
>>> torch.as_strided(a, a.shape, (-8,))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: as_strided: Negative strides are not supported at the moment, got strides: [-8]

So I think torch just fundamentally doesn't support reversed views right now. The best we can do is a helper that translates a slice into an equivalent operation using flip, which would make a copy, as you noted.

What generality of slices do you need support for? Steps less than -1? Start and stop? Slices in multidimensional indices?
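A minimal sketch of such a helper for a single axis, covering steps less than -1 and arbitrary start/stop. `split_negative_step` and `flip_getitem` are hypothetical names, not part of array-api-compat. The idea: a negative-step slice visits the same indices as some positive-step slice, just in reverse order, so only the final reversal needs flip (and therefore a copy).

```python
def split_negative_step(s: slice, n: int) -> tuple[slice, bool]:
    """Hypothetical helper: rewrite a 1-D slice as (positive-step slice, flip flag).

    For a sequence x of length n: x[s] == flip(x[p]) if flipped else x[p].
    """
    start, stop, step = s.indices(n)  # normalizes None and negative bounds
    if step > 0:
        return slice(start, stop, step), False
    selected = range(start, stop, step)  # indices x[s] visits, in order
    if len(selected) == 0:
        return slice(0, 0, 1), False
    # Same indices in ascending order: from the last one visited to the first.
    return slice(selected[-1], selected[0] + 1, -step), True


def flip_getitem(x, s: slice):
    """Hypothetical 1-D wrapper for torch tensors; makes a copy when step < 0."""
    p, flipped = split_negative_step(s, x.shape[0])
    y = x[p]  # positive-step slicing, which torch supports as a view
    return y.flip(0) if flipped else y  # Tensor.flip returns a copy
```

For example, `flip_getitem(torch.arange(10), slice(None, None, -2))` slices `[1:10:2]` and then flips, giving the elements 9, 7, 5, 3, 1 as a new tensor. Multidimensional indices would presumably need the same rewrite applied per axis, followed by one flip over all affected dimensions.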

mdhaber commented 2 months ago

For the time being I've used flip, and I probably won't change that by the time a patch is available in SciPy. So from that perspective I don't need anything; I just thought I should report the issue. In the context where I encountered it, though, it was just [::-1].

mdhaber commented 2 months ago

This reminds me of how we've discussed getting around the fact that JAX can't mutate arrays: a function that mutates elements at specified indices if possible and copies otherwise (maybe with the JAX .at syntax). Is array_api_compat the place for that, or does each project need to implement its own?
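A minimal sketch of the "mutate if possible, copy otherwise" function described above. `set_at` is a hypothetical name, and the dispatch rule is an assumption: any object exposing an `.at` attribute is treated as a JAX-style immutable array whose `x.at[idx].set(value)` returns a new array.

```python
def set_at(x, idx, value):
    """Hypothetical helper: x[idx] = value, in place when the array allows it.

    Assumes objects with an ``.at`` attribute follow JAX's functional-update
    convention and return a new array; everything else is mutated in place.
    Callers must always use the return value.
    """
    if hasattr(x, "at"):
        return x.at[idx].set(value)  # JAX-style: new array with the update
    x[idx] = value                   # NumPy/torch-style: mutate in place
    return x
```

Writing `x = set_at(x, 3, 0.0)` works for both kinds of array; the catch, as with negative steps, is that callers who hold another reference to `x` see the update only in the in-place case.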

ISTM we might want a similar thing here, because it might not be OK to copy if the user is expecting a view.

Another possibility is simply to re-raise the error with a message explaining why array_api_compat can't implement negative steps and recommending flip if copies are OK.

asmeurer commented 2 months ago

array-api-compat generally isn't the place to implement new APIs that aren't in the standard (see https://data-apis.org/array-api-compat/#scope).

However, something that could be in scope for array-api-compat is helper functionality to work around differences in how libraries handle copies vs. views. I don't know what that would look like exactly, but if you have any proposals for things that could help, I'm open to hearing them. We should open a new issue to discuss this.

asmeurer commented 2 months ago

At any rate, this issue makes me realize that a function that converts a slice into strides and an offset could be generally useful. I might implement it in ndindex at some point https://github.com/Quansight-Labs/ndindex/issues/180.
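A rough sketch of that slice-to-strides conversion, restricted to a tuple with one slice per axis. `slice_to_strides` and `read_2d` are hypothetical names; this is not ndindex's API. Strides and offsets are counted in elements, which is torch's convention (NumPy's `ndarray.strides` counts bytes, so a reversed 64-bit integer array has stride -8 in NumPy but would need stride -1 in torch).

```python
def slice_to_strides(index, shape, strides, offset=0):
    """Hypothetical helper: map a tuple of slices to (shape, strides, offset).

    Strides and offset are in elements, not bytes.
    """
    if not (len(index) == len(shape) == len(strides)):
        raise ValueError("index, shape and strides must have the same length")
    new_shape, new_strides = [], []
    for s, n, st in zip(index, shape, strides):
        start, stop, step = s.indices(n)  # normalizes None and negative bounds
        new_shape.append(len(range(start, stop, step)))
        new_strides.append(step * st)     # a negative step gives a negative stride
        offset += start * st              # first selected element along this axis
    return tuple(new_shape), tuple(new_strides), offset


def read_2d(buffer, shape, strides, offset):
    """Materialize a 2-D strided view of a flat buffer as nested lists."""
    rows, cols = shape
    return [[buffer[offset + i * strides[0] + j * strides[1]] for j in range(cols)]
            for i in range(rows)]


buffer = list(range(12))  # a 3x4 row-major array: [[0..3], [4..7], [8..11]]
shape, strides, offset = slice_to_strides(
    (slice(None, None, -1), slice(1, None, 2)), (3, 4), (4, 1))
print(shape, strides, offset)                   # (3, 2) (-4, 2) 9
print(read_2d(buffer, shape, strides, offset))  # [[9, 11], [5, 7], [1, 3]]
```

For a plain reversal of a length-10 tensor, `slice_to_strides((slice(None, None, -1),), (10,), (1,))` gives shape `(10,)`, stride `(-1,)` and offset 9, which is exactly the kind of negative-stride view that `torch.as_strided` refuses to build, per the error earlier in the thread.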