data-apis / array-api-compat

Compatibility layer for common array libraries to support the Array API
https://data-apis.org/array-api-compat/
MIT License

Implement torch.take with `axis` argument #34

Closed by thomasjpfan 1 year ago

thomasjpfan commented 1 year ago

I suspect take's axis argument will be needed at some point. Can we add a simple implementation for PyTorch?

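def take(array, indices, *, axis):
    # Full slices on every dimension except `axis`, which gets the indices
    key = [slice(None)] * array.ndim
    key[axis] = indices
    # Use a tuple key: recent NumPy versions no longer accept a list of slices here
    return array[tuple(key)]
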
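As a quick sanity check (assuming NumPy is installed, and using NumPy only because it ships its own np.take to compare against), the same indexing-key approach can be verified on every axis, including a negative one:

```python
import numpy as np

def take(array, indices, *, axis):
    # Full slices on every dimension except `axis`, which gets the indices
    key = [slice(None)] * array.ndim
    key[axis] = indices
    # A tuple key means "one entry per dimension"
    return array[tuple(key)]

x = np.arange(12).reshape(3, 4)
idx = np.array([2, 0])

# Compare against NumPy's own take on each axis
for ax in (0, 1, -1):
    assert np.array_equal(take(x, idx, axis=ax), np.take(x, idx, axis=ax))
```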
asmeurer commented 1 year ago

Is there an upstream PyTorch issue for adding this to PyTorch itself? It seems impossible to search for in the GitHub issue tracker. @lezcano

lezcano commented 1 year ago

https://pytorch.org/docs/stable/generated/torch.index_select.html

ogrisel commented 1 year ago

Is this issue still up to date? It seems that the latest release of array-api-compat (1.3) makes xp.take accept an axis parameter but also makes it mandatory:

>>> import array_api_compat
>>> from array_api_compat import get_namespace
>>> array_api_compat.__version__
'1.3'
>>> import torch
>>> x = torch.randn(3, 3)
>>> x
tensor([[ 0.5549,  0.1046,  0.8453],
        [ 0.8627, -0.7882, -0.6627],
        [-1.4976, -0.1424,  0.4018]])
>>> indices = torch.tensor([1, 0])
>>> xp = get_namespace(x)
>>> xp.take(x, indices)
Traceback (most recent call last):
  Cell In[15], line 1
    xp.take(x, indices)
TypeError: take() missing 1 required keyword-only argument: 'axis'
>>> xp.take(x, indices, axis=0)
tensor([[ 0.8627, -0.7882, -0.6627],
        [ 0.5549,  0.1046,  0.8453]])
>>> xp.take(x, indices, axis=1)
tensor([[ 0.1046,  0.5549],
        [-0.7882,  0.8627],
        [-0.1424, -1.4976]])

The fact that it's mandatory was a bit surprising to me, but apparently the signature is in line with the specification:

>>> xp.take?
Signature: xp.take(x: 'array', indices: 'array', /, *, axis: 'int', **kwargs) -> 'array'
Docstring: <no docstring>
File:      ~/mambaforge/envs/dev/lib/python3.11/site-packages/array_api_compat/torch/_aliases.py
Type:      function
The specification's signature reads:

take(x: array, indices: array, /, *, axis: int) → array

kgryte commented 1 year ago

@ogrisel Only when the input array is one-dimensional is the axis kwarg optional. Otherwise, take is equivalent to integer indexing on a multi-dimensional array, in which one would explicitly indicate the axis to index.
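A small NumPy illustration of this point (NumPy is used here only as a convenient reference implementation, assuming it is installed): for a 1-D array there is exactly one axis, so omitting it is unambiguous, while for a 2-D array, take along an axis matches plain integer indexing on that axis.

```python
import numpy as np

v = np.array([10, 20, 30, 40])
x = np.arange(12).reshape(3, 4)
idx = np.array([2, 0])

# 1-D: only one axis exists, so omitting it cannot be ambiguous
assert np.array_equal(np.take(v, idx), np.take(v, idx, axis=0))
assert np.array_equal(np.take(v, idx), v[idx])

# 2-D: take along an axis is integer indexing on that axis
assert np.array_equal(np.take(x, idx, axis=0), x[idx, :])
assert np.array_equal(np.take(x, idx, axis=1), x[:, idx])
```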

asmeurer commented 1 year ago

This looks like a bug in the spec. The text says passing it is optional, but the Python signature does not allow for that.

kgryte commented 1 year ago

Yes, this is unfortunate. Because the argument is optional only for one-dimensional arrays, there is no great option for the signature: setting the default value to None makes it seem as if the argument is also optional for multi-dimensional arrays. My feeling is that it would be better to require the axis argument for all arrays in order to resolve this ambiguity.

rgommers commented 1 year ago

The discussion in https://github.com/data-apis/array-api/pull/416 explicitly states that axis is optional, and that was supported by multiple people. So I'd consider this a bug in that PR and in the spec, and the correct resolution seems to me to be to add the missing = None to the signature. The docs are clear enough, so there isn't much of an ambiguity.

kgryte commented 1 year ago

The ambiguity stems from the type signature: the signature is not able to encode optionality and "requiredness" at the same time (or based on the input array shape).

We made the axis kwarg optional as a convenience; however, at the time, I had omitted using None as the default to satisfy the general case in which the kwarg is required for >1D. Obviously, this doesn't work for the scenario where axis is optional; hence, my statement concerning the unfortunate aspect of the signature.

rgommers commented 1 year ago

Yes, I understand what's going on - it seems clear to me that the spec has a bug that has to be resolved one way or the other, and the written agreement is for optionality for 1-D arrays, hence changing the signature to axis : int = None. It's the path of least resistance, and avoids the surprise that @ogrisel expressed above. We already considered scikit-learn's needs in https://github.com/data-apis/array-api/pull/416#issuecomment-1177628306; it's mostly 1-D arrays.

asmeurer commented 1 year ago

Yes, this is an issue with the current spec. With the current signature, take(x, indices, *, axis), axis is always required, due to the way the keyword-only syntax works.
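The keyword-only behavior described here can be seen with two throwaway functions (the names take_required and take_optional are made up for illustration): a keyword-only parameter without a default is mandatory, while one declared with = None may be omitted.

```python
import inspect

def take_required(x, indices, /, *, axis):
    # Keyword-only parameter with no default: callers must pass it
    return axis

def take_optional(x, indices, /, *, axis=None):
    # Keyword-only parameter with a default: callers may omit it
    return axis

try:
    take_required([1, 2], [0])
    omitted_ok = True
except TypeError:
    omitted_ok = False

print(omitted_ok)                   # False
print(take_optional([1, 2], [0]))  # None
# The signature records "no default" as inspect.Parameter.empty
print(inspect.signature(take_required).parameters["axis"].default is inspect.Parameter.empty)  # True
```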

The problem is that people (such as myself) copy the signatures from the spec directly. So if the Python signature doesn't allow axis to be optional, it won't be, because that's the exact signature I used.

The alternative signature, take(x, indices, *, axis=None), is easy to support in the function logic:

def take(x, indices, *, axis=None):
    # axis may be omitted only for 1-D input
    if x.ndim > 1 and axis is None:
        raise ValueError("axis must be provided when ndim > 1")
    ...

This is somewhat similar to arange, where it's impossible to encode the "true" signature in the Python signature, so we have to get as close as possible.
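For comparison, here is a sketch of the arange situation. The spec's signature is arange(start, /, stop=None, step=1, *, dtype=None, device=None), but with a single argument the value passed as start actually acts as the stop. The helper arange_bounds below is hypothetical and only resolves the arguments; it does not create an array.

```python
def arange_bounds(start, /, stop=None, step=1):
    # Hypothetical helper: resolves arguments the way the spec's
    # arange(start, /, stop=None, step=1, ...) describes them.
    # With one argument, `start` is really the (exclusive) end.
    if stop is None:
        start, stop = 0, start
    return start, stop, step

print(list(range(*arange_bounds(4))))        # [0, 1, 2, 3]
print(list(range(*arange_bounds(2, 8, 3))))  # [2, 5]
```

As with take, the Python signature cannot express the rule by itself, so the function body has to reinterpret the arguments.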

kgryte commented 1 year ago

I opened a PR correcting the signature for take: https://github.com/data-apis/array-api/pull/644.

asmeurer commented 1 year ago

I updated numpy.array_api here https://github.com/numpy/numpy/pull/24187

I'm unclear whether I need to do anything for the compat library for numpy/cupy. In NumPy, the axis argument is optional even for ndim > 1 (omitting it flattens the array). This is generally the sort of thing I would expect the strict numpy.array_api implementation to catch, whereas in the compat library we allow things that aren't strictly disallowed and do things like pass additional keyword arguments through.
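For reference, this is the NumPy behavior in question (assuming NumPy is installed): when axis is omitted on an n-D input, np.take indexes the flattened array, so the result differs from selecting along axis 0.

```python
import numpy as np

x = np.arange(12).reshape(3, 4)
idx = np.array([2, 0])

flat = np.take(x, idx)          # axis omitted: indexes the flattened array
rows = np.take(x, idx, axis=0)  # explicit axis: selects whole rows

assert np.array_equal(flat, x.ravel()[idx])
print(flat.shape, rows.shape)   # (2,) (2, 4)
```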

Torch does need to be fixed, though, because it just wraps torch.index_select, which requires the axis argument.
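A minimal sketch of what a spec-conforming PyTorch wrapper could look like, assuming PyTorch is installed. This is a hypothetical illustration, not the code that landed in array-api-compat: it defaults to axis 0 only for 1-D input and otherwise delegates to torch.index_select(input, dim, index), which always needs an explicit dim.

```python
import torch

def take(x, indices, /, *, axis=None):
    # Hypothetical sketch of a spec-style wrapper; not array-api-compat's code
    if axis is None:
        if x.ndim != 1:
            # The spec only allows omitting axis for 1-D input
            raise ValueError("axis must be specified when ndim > 1")
        axis = 0
    # torch.index_select(input, dim, index) requires an explicit dim
    return torch.index_select(x, axis, indices)

v = torch.tensor([10, 20, 30])
print(take(v, torch.tensor([2, 0])))  # tensor([30, 10])
```

Raising instead of silently flattening keeps the wrapper aligned with the spec text, at the cost of being stricter than NumPy.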

rgommers commented 1 year ago

> I'm unclear if I need to do anything for the compat library for numpy/cupy.

Probably not; I think it works as is.