data-apis / array-api-compat

Compatibility layer for common array libraries to support the Array API
https://data-apis.org/array-api-compat/
MIT License

Add functions such as `take` to existing Array API spec if not implemented yet? #23

Closed: thomasjpfan closed this issue 1 year ago

thomasjpfan commented 1 year ago

Should array-api-compat "update" the namespace for existing Array API arrays to the most recent spec? For example:

# Assume that `xp.take` is not implemented in the installed CuPy version
import cupy.array_api as xp
import array_api_compat

X = xp.asarray([1.0, 2.0])
xp = array_api_compat.get_namespace(X)

# Should this always be true?
assert hasattr(xp, "take")
rgommers commented 1 year ago

I'd say yes, that should be a goal. Taking over the implementation from https://github.com/cupy/cupy/pull/7432 would be good, I think.

asmeurer commented 1 year ago

I'd say changes like this are definitely in scope. We will add it eventually once we add the 2023 version of the spec, but we can add it sooner if it would help.

And as Ralf pointed out, we also still need to update numpy.array_api and cupy.array_api with 2023 functionality.

rgommers commented 1 year ago

Let's add it now, indeed. scikit-learn needed take, which is why we prioritized adding it to the standard. By the way, it's not a draft feature; it was in the 2022 standard: https://data-apis.org/array-api/latest/changelog.html#v2022-12

asmeurer commented 1 year ago

Yes, sorry I meant to type 2022 above, not 2023.

asmeurer commented 1 year ago

The main challenge here is that there aren't any tests for take in the test suite yet. @honno, if you could work on adding take to https://github.com/data-apis/array-api-tests/pull/165 (or a new PR), that would help. I noticed that torch.take will need wrapping because it doesn't have an axis keyword (if you don't use the axis keyword, the take function in the existing torch wrapper should already work).
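Here is a minimal sketch of such a wrapper, assuming PyTorch is installed. The hypothetical `take` below calls `torch.take` when no axis is given, which indexes the tensor as if it were flattened. For an explicit axis it switches to `torch.index_select`. Using `index_select` is a choice made for illustration, not necessarily what the torch wrapper ended up doing.

```python
import torch


def take(x, indices, /, *, axis=None):
    """Sketch of a take() with an `axis` keyword for torch tensors.

    Hypothetical wrapper for illustration, not the actual array-api-compat code.
    """
    if axis is None:
        # torch.take has no axis argument: it indexes into the input as if it
        # were flattened to 1-D, so without `axis` it can be used directly.
        return torch.take(x, indices)
    # With an explicit axis, torch.index_select picks entries along that dimension.
    return torch.index_select(x, axis, indices)


x = torch.tensor([[1, 2, 3], [4, 5, 6]])
idx = torch.tensor([2, 0])

flat = take(x, idx).tolist()          # [3, 1]: indices into the flattened tensor
cols = take(x, idx, axis=1).tolist()  # [[3, 1], [6, 4]]
```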

NumPy and CuPy already have take. Adding it to numpy.array_api and cupy.array_api is a different story. We would need to upstream it there.

Presently this compat library doesn't wrap or otherwise modify existing array API compatible libraries like numpy.array_api; it just returns them as-is. So there are a few options here:

The difference between these two really comes down to using array_api_compat.take vs. xp.take.
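The distinction can be sketched with toy objects. This assumes a hypothetical `ToyArray` that exposes `__array_namespace__` and a hypothetical `compat_take` standing in for a module-level array_api_compat.take. Because `get_namespace` passes a compliant namespace through untouched, only the module-level function gets a chance to fall back when that namespace lacks `take`.

```python
import types


class ToyArray:
    """Stand-in for an array from a library that already implements the
    standard (hypothetical; only what this sketch needs)."""

    def __init__(self, data, namespace):
        self.data = list(data)
        self._namespace = namespace

    def __array_namespace__(self, /, *, api_version=None):
        return self._namespace


def get_namespace(x):
    # Pass-through: a compliant array's own namespace is returned unchanged.
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    raise TypeError(f"unsupported array type: {type(x).__name__}")


def compat_take(x, indices, /, *, axis=None):
    # The "array_api_compat.take" option: a module-level function that can
    # fall back when the pass-through namespace has no take of its own.
    xp = get_namespace(x)
    if hasattr(xp, "take"):
        return xp.take(x, indices, axis=axis)
    # Toy 1-D fallback on ToyArray's data; real code would build a new array.
    return [x.data[i] for i in indices]


old_ns = types.SimpleNamespace()  # compliant, but predates take
a = ToyArray([10, 20, 30], old_ns)

print(compat_take(a, [2, 0]))             # [30, 10]
print(hasattr(get_namespace(a), "take"))  # False: xp.take would raise AttributeError
```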

rgommers commented 1 year ago

We will eventually upstream 2022 spec support to numpy.array_api (and cupy.array_api will presumably follow suit). However, this isn't being worked on yet, and even once it happens, it will require a numpy release to be usable.

This is already being done for take:

Should array-api-compat "update" the namespace for existing Array API arrays to the most recent spec?

I'm not clear on the need for this request for xp.take though. The point of array-api-compat is to extend the main namespaces and avoid the need to use the separate submodules (numpy.array_api and cupy.array_api). @thomasjpfan can you comment on this?

thomasjpfan commented 1 year ago

I'm not clear on the need for this request for xp.take though.

If the goal is to remove numpy.array_api and cupy.array_api, then take does not need to be added to existing Array API namespaces. Libraries that support the Array API and need take will then require the array library to be updated to the v2022 spec.

It does raise the question of how versioning works here. Consider:

# custom_array_library implements v2021 spec, but not v2022.
import custom_array_library.array_api as xp
import array_api_compat

X = xp.asarray([1.0, 2.0])
xp = array_api_compat.get_namespace(X)

# Which spec is `xp` supporting? v2021 or v2022?
xp

The easiest solution is for array_api_compat.get_namespace to raise an error, because the array library does not support the latest spec (v2022). The less appealing answer is to update the namespace to the v2022 spec (which is what this feature request asks for).
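The "error" option could be implemented by feature-detecting the namespace before returning it. Below is a minimal sketch. It assumes a hypothetical `require_latest_spec` helper and a deliberately incomplete list of required names: only `take`, the 2022.12 addition discussed here. This thread does not specify how get_namespace would actually detect an outdated namespace.

```python
import types

# Hypothetical subset: `take` is the only 2022.12 addition discussed here;
# a real check would list every function the latest spec requires.
LATEST_SPEC_FUNCTIONS = ("take",)


def require_latest_spec(xp, required=LATEST_SPEC_FUNCTIONS):
    """Return `xp` unchanged if it defines every name in `required`,
    otherwise raise instead of silently handing back an older namespace."""
    missing = [name for name in required if not hasattr(xp, name)]
    if missing:
        raise TypeError(
            "array namespace does not implement the latest array API spec; "
            f"missing: {', '.join(missing)}"
        )
    return xp


v2021_xp = types.SimpleNamespace(asarray=list)
v2022_xp = types.SimpleNamespace(asarray=list, take=lambda *a, **k: None)

require_latest_spec(v2022_xp)  # returned as-is
try:
    require_latest_spec(v2021_xp)
except TypeError as exc:
    print(exc)  # message ends with "missing: take"
```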

thomasjpfan commented 1 year ago

From the weekly call, we think it's best for array_api_compat.get_namespace to raise an error when the input array does not support the latest spec.

asmeurer commented 1 year ago

So for now (i.e., in #25) I am not going to do anything about take, since it is already present in numpy, cupy, and torch. The torch take does not have the axis keyword argument, but @thomasjpfan confirmed today that this is not an issue for him. Going forward we will:

Regarding the version, I think its primary benefit is a better error message when using an array API namespace that is too old. Virtually every change in the spec is additive, meaning there's no reason for something like get_namespace(np.array(...), version='2021.12') to return anything different from get_namespace(np.array(...), version='2022.12'). The 2022 compliant namespace will also be 2021 compliant.
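Because the changes are additive, deciding whether a namespace satisfies a requested version reduces to comparing versions. Below is a minimal sketch, assuming the `YYYY.MM` version strings used above and the hypothetical helper names `parse_api_version` and `covers`.

```python
def parse_api_version(version):
    """Split a 'YYYY.MM' spec version string into (year, month) integers."""
    year, month = version.split(".")
    return int(year), int(month)


def covers(provided, requested):
    """Whether a namespace implementing spec version `provided` also covers
    `requested`, assuming spec changes are purely additive, so a newer
    version is a superset of an older one."""
    return parse_api_version(provided) >= parse_api_version(requested)


print(covers("2022.12", "2021.12"))  # True
print(covers("2021.12", "2022.12"))  # False
```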

The bigger question here is about numpy.array_api. Due to its strictness, we might want to make it possible to get a strictly 2021-compliant version of it that doesn't include any 2022 additions. As I said, no work has really happened yet on adding 2022 support, except for a few offshoot PRs like the one mentioned above adding take. I think it would be nice to have this behavior there, but I'll have to think about how easy it is to implement. But this really should be discussed on the NumPy repo, not here.

For this compat library, we are taking a much more pragmatic approach.

asmeurer commented 1 year ago

Added an api_version argument (currently only supporting '2021.12') to get_namespace() in #25.
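The validation step can be sketched as below. The names (`SUPPORTED_API_VERSIONS`, `check_api_version`) are hypothetical, the error message is illustrative, and api_version is assumed to default to `None`. This is not the code from #25, only the behaviour it describes: `'2021.12'` is accepted and any other explicit version raises.

```python
# Hypothetical names; only '2021.12' is accepted, matching the comment above.
SUPPORTED_API_VERSIONS = ("2021.12",)


def check_api_version(api_version):
    """Raise ValueError for an api_version the compat layer cannot provide.
    None (assumed default) means "no specific version requested"."""
    if api_version is not None and api_version not in SUPPORTED_API_VERSIONS:
        raise ValueError(
            f"api_version={api_version!r} is not supported; "
            f"supported: {', '.join(SUPPORTED_API_VERSIONS)}"
        )


check_api_version(None)       # accepted
check_api_version("2021.12")  # accepted
try:
    check_api_version("2022.12")
except ValueError as exc:
    print(exc)
```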

asmeurer commented 1 year ago

Based on the discussion, I don't think there's anything to do here, since take is already implemented in numpy, cupy, and torch, and has been implemented in the numpy and cupy array_api submodules in git (with full 2022.12 support coming later). I've added the api_version flag to get_namespace; it only accepts '2021.12' for now.