thomasjpfan closed this issue 1 year ago.
Is there an upstream PyTorch issue for adding this to PyTorch itself? It seems impossible to search for in the GitHub issue tracker. @lezcano
Is this issue still up to date? It seems that the latest release of array-api-compat (1.3) makes xp.take accept an axis parameter but also makes it mandatory:
```python
>>> import array_api_compat
>>> from array_api_compat import get_namespace
>>> array_api_compat.__version__
'1.3'
>>> import torch
>>> x = torch.randn(3, 3)
>>> x
tensor([[ 0.5549,  0.1046,  0.8453],
        [ 0.8627, -0.7882, -0.6627],
        [-1.4976, -0.1424,  0.4018]])
>>> indices = torch.tensor([1, 0])
>>> xp = get_namespace(x)
>>> xp.take(x, indices)
Traceback (most recent call last):
  Cell In[15], line 1
    xp.take(x, indices)
TypeError: take() missing 1 required keyword-only argument: 'axis'
>>> xp.take(x, indices, axis=0)
tensor([[ 0.8627, -0.7882, -0.6627],
        [ 0.5549,  0.1046,  0.8453]])
>>> xp.take(x, indices, axis=1)
tensor([[ 0.1046,  0.5549],
        [-0.7882,  0.8627],
        [-0.1424, -1.4976]])
```
The fact that it's mandatory was a bit surprising to me, but apparently the signature is in line with the specification:
```python
>>> xp.take?
Signature: xp.take(x: 'array', indices: 'array', /, *, axis: 'int', **kwargs) -> 'array'
Docstring: <no docstring>
File:      ~/mambaforge/envs/dev/lib/python3.11/site-packages/array_api_compat/torch/_aliases.py
Type:      function
```
take(x: array, indices: array, /, *, axis: int) → array
@ogrisel The axis kwarg is optional only when the input array is one-dimensional. Otherwise, take is equivalent to integer indexing on a multi-dimensional array, where one would explicitly indicate the axis to index.
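A small illustration of that equivalence, sketched with NumPy's np.take rather than the PyTorch wrapper discussed in this thread (NumPy is assumed here only so the example runs anywhere): along axis 0, take picks rows like x[indices]; along axis 1, it picks columns like x[:, indices]; and a 1-D array has only one axis, so leaving axis out loses no information.

```python
import numpy as np

x = np.arange(9).reshape(3, 3)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
indices = np.array([1, 0])

# Along axis 0, take selects rows: the same as integer indexing on the first axis.
rows = np.take(x, indices, axis=0)
assert np.array_equal(rows, x[indices])

# Along axis 1, take selects columns: the same as integer indexing on the second axis.
cols = np.take(x, indices, axis=1)
assert np.array_equal(cols, x[:, indices])

# A 1-D array has a single axis, so axis=0 is the only meaningful choice.
v = np.array([10, 20, 30])
assert np.array_equal(np.take(v, indices, axis=0), v[indices])

print(rows.tolist())  # [[3, 4, 5], [0, 1, 2]]
print(cols.tolist())  # [[1, 0], [4, 3], [7, 6]]
```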
This looks like a bug in the spec. The text says passing it is optional, but the Python signature does not allow for that.
Yes, this is unfortunate. Because axis is optional only for one-dimensional arrays, there is no great option for the signature: setting the default value to None makes it seem as if the argument is optional for multi-dimensional arrays as well. My feeling is that it would be better to require the axis argument for all arrays in order to resolve this ambiguity.
The discussion in https://github.com/data-apis/array-api/pull/416 explicitly states that axis is optional, and that was supported by multiple people. So I'd consider this a bug in that PR and in the spec, and the correct resolution seems to me to be to add the missing = None to the signature. The docs are clear enough, so there isn't much ambiguity.
The ambiguity stems from the type signature: the signature is not able to encode optionality and "requiredness" at the same time (or based on the input array shape).
We made the axis kwarg optional as a convenience; however, at the time, I omitted None as the default to satisfy the general case, in which the kwarg is required for arrays with more than one dimension. Obviously, this doesn't work for the scenario where axis is optional; hence my statement about the unfortunate aspect of the signature.
Yes, I understand what's going on. It seems clear to me that the spec has a bug that has to be resolved one way or the other, and the written agreement is for optionality for 1-D arrays, hence changing the signature to axis: int = None. It's the path of least resistance, and it avoids the surprise that @ogrisel expressed above. We already considered scikit-learn's needs in https://github.com/data-apis/array-api/pull/416#issuecomment-1177628306; it's mostly 1-D arrays.
Yes, this is an issue with the current spec. With the current signature, take(x, indices, *, axis), axis is always required, due to the way the keyword-only syntax works: a keyword-only parameter without a default value must always be passed.
The problem is that people (such as myself) copy the signatures from the spec directly. So if the Python signature doesn't allow axis to be optional, it won't be, because that's the exact signature I used.
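To see why copying the signature verbatim makes axis mandatory, here is a hypothetical function, take_spec, with exactly the spec's current signature and a placeholder body. Python rejects a call that omits a keyword-only parameter with no default before the body ever runs, which is the same kind of TypeError shown at the top of this thread, even when the input is 1-D.

```python
import inspect


def take_spec(x, indices, /, *, axis):
    """Hypothetical stand-in with the spec's current signature (placeholder body)."""
    return axis


axis_param = inspect.signature(take_spec).parameters["axis"]

# axis is keyword-only and has no default, so Python treats it as required.
print(axis_param.kind is inspect.Parameter.KEYWORD_ONLY)  # True
print(axis_param.default is inspect.Parameter.empty)      # True

try:
    take_spec([10, 20, 30], [1])  # axis omitted, even though x is 1-D
except TypeError as exc:
    print(f"TypeError: {exc}")
```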
The alternative signature, take(x, indices, *, axis=None), is easy to make work in the function logic:

```python
def take(x, indices, *, axis=None):
    # axis may only be omitted for one-dimensional input
    if x.ndim > 1 and axis is None:
        raise ValueError("axis must be provided when ndim > 1")
    ...
```
This is somewhat similar to arange, where it's impossible to encode the "true" signature in the Python signature, so we have to get as close as possible.
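A sketch of the arange case for comparison, assuming the spec's form arange(start, /, stop=None, step=1, ...), where start is reinterpreted as the end of the interval when stop is omitted. The helper normalize_arange_args is hypothetical and only normalizes the arguments; as with take, the real rule has to live in the function body because the Python signature cannot express it.

```python
def normalize_arange_args(start, /, stop=None, step=1):
    """Hypothetical helper: map arange-style arguments to (start, stop, step)."""
    if stop is None:
        # A single argument means arange(stop), so the interval starts at 0.
        start, stop = 0, start
    return start, stop, step


print(normalize_arange_args(5))         # (0, 5, 1)
print(normalize_arange_args(2, 5))      # (2, 5, 1)
print(normalize_arange_args(2, 10, 3))  # (2, 10, 3)
```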
I opened a PR correcting the signature for take: https://github.com/data-apis/array-api/pull/644.
I updated numpy.array_api in https://github.com/numpy/numpy/pull/24187.
I'm unclear whether I need to do anything for the compat library for numpy/cupy. The axis argument is not required for ndim > 1 in NumPy (when it is omitted, the array is flattened). This is generally the sort of thing I would expect the strict numpy.array_api implementation to catch, whereas in the compat library we are permissive about behavior outside the spec and do things like pass additional keyword arguments through.
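The NumPy behavior mentioned here can be checked directly (a small sketch using plain NumPy, not the compat layer): numpy.take defaults to axis=None, and in that case it indexes into the flattened array instead of raising.

```python
import numpy as np

x = np.array([[10, 11, 12],
              [20, 21, 22]])

# With axis omitted, NumPy indexes x.ravel() == [10, 11, 12, 20, 21, 22].
flat = np.take(x, [4, 0])
assert np.array_equal(flat, x.ravel()[[4, 0]])
print(flat.tolist())  # [21, 10]

# With axis given explicitly, rows are selected as in the array API semantics.
rows = np.take(x, [1, 0], axis=0)
print(rows.tolist())  # [[20, 21, 22], [10, 11, 12]]
```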
Torch does need to be fixed, though, because it just wraps torch.index_select, which requires the axis argument.
> I'm unclear if I need to do anything for the compat library for numpy/cupy.

Probably not; I think it works as is.
I suspect take's axis argument will be needed at some point. Can we add a simple implementation for PyTorch?
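A minimal sketch of what such an implementation could look like, assuming the corrected signature from the spec PR linked above (axis defaults to None and may be omitted only for one-dimensional input). This is illustrative, not the array-api-compat code; it relies only on torch.index_select(input, dim, index), which always needs an explicit dim and a 1-D integer index tensor.

```python
import torch


def take(x: torch.Tensor, indices: torch.Tensor, /, *, axis=None) -> torch.Tensor:
    """Sketch of array API take() built on torch.index_select."""
    if axis is None:
        if x.ndim != 1:
            raise ValueError("axis may only be omitted when x is one-dimensional")
        # A 1-D tensor has exactly one axis to index along.
        axis = 0
    # index_select needs an explicit dim and an integer (int32/int64) index tensor.
    return torch.index_select(x, axis, indices)


m = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
idx = torch.tensor([1, 0])
print(take(m, idx, axis=1).tolist())                # [[2.0, 1.0], [4.0, 3.0]]
print(take(torch.tensor([5, 6, 7]), idx).tolist())  # [6, 5]
```

Raising for multi-dimensional input when axis is omitted mirrors the check suggested earlier in the thread; falling back to flattening, as NumPy does, would go beyond what the spec requires.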