ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License
17.19k stars 994 forks

[RFC] Implement the Python array API standard #48

Open lucascolley opened 11 months ago

lucascolley commented 11 months ago

The Python array API standard standardises common functionality across Python array/tensor libraries. NumPy, PyTorch and CuPy are planning to have full implementations, and Dask and JAX also have implementations in progress. You could implement this in your main namespace or a separate namespace.

Why should you do this? As well as making it easier for users to convert existing NumPy/PyTorch/CuPy code to MLX, there is potential for interoperability with other libraries. For example, from the NumPy ecosystem, SciPy and scikit-learn have partial experimental support for arrays which comply with the standard.

If you are interested in this, the consortium would love to hear feedback over at https://github.com/data-apis/consortium-feedback/. Some potential pain points, such as missing float64 support, have already been discussed very briefly in https://github.com/data-apis/array-api/issues/719.

lucascolley commented 11 months ago

quoting @awni in gh-12:

We had the luxury of picking the best from all the frameworks we've used and worked on in the past and combining them into something new.

In terms of API design (implementation aside), this is the same process as what the consortium did in creation of the standard - I imagine it would be fruitful to discuss where the differences are and why one might be preferable.

asmeurer commented 11 months ago

My talk at SciPy 2023 is useful if you want to know more about the array API.

arpan-dhatt commented 11 months ago

+1 Implementing this standard would help libraries like einops "just work" when dealing with MLX arrays. That particular library has been cropping up a lot for array reshape/transpose/stack, etc. in the HyenaDNA model I'm trying to port.

It can also help clear up ambiguities in API design such as #113 and create a "checklist" of basic ops that should be implemented but are not yet, like linspace.

Edit: and just a thought, if there's any time to break backwards compatibility to make the Array API "first class" for MLX, it's now, at version 0.0.4 🙃

awni commented 11 months ago

Are there any differences between the Python array API standard and NumPy or is the standard basically a subset of NumPy? If it's the latter, then I would say we are already on track to implement the standard.

Either way though, we will definitely take it into consideration as we continue to update the API!

lucascolley commented 11 months ago

Are there any differences between the Python array API standard and NumPy or is the standard basically a subset of NumPy?

Here is the tracking issue for support in the main NumPy namespace (making the NumPy API a superset of the standard). Some decisions were made to differ from NumPy where other libraries seemed to have improved upon NumPy, but it is quite similar. @mtsokol and @rgommers are working on a proposal to continue the work into making NumPy a superset.

Either way though, we will definitely take it into consideration as we continue to update the API!

Great, I'd definitely be keen to try to get MLX arrays working in SciPy! As @arpan-dhatt mentioned, if you think that compliance seems plausible and a good idea, now is probably the time to check for BC breaking changes.

rgommers commented 11 months ago

For context: in numpy 1.2x.y, there are differences. In numpy 2.0 (branching within a month, tentative release date end of Feb'24) there are a lot of API changes, additions and a number of backwards compatibility breaks to ensure that NumPy's main namespace and the fft and linalg modules will be compliant with the array API standard. The most important bc-breaking change (type promotion rules, NEP 50) was planned anyway independent of array API support - and makes NumPy more consistent and better aligned with JAX/PyTorch behavior.

Right now, for numpy 1.2x.y, the differences are bridged by the array-api-compat package. Longer-term, the need for that should go away. Having array API standard support in MLX would be quite nice - in particular because it'd then be possible to add support for MLX in SciPy, scikit-learn & co.
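To make the interoperability point concrete, here is a minimal sketch of the array-agnostic pattern that a consuming library can use once an array type complies with the standard. `__array_namespace__`, `mean` and `subtract` are names from the standard; `ToyArray` and `toy_xp` are hypothetical stand-ins invented here so the snippet runs without any array library installed. It illustrates the pattern only and is not how SciPy, scikit-learn or array-api-compat are implemented.

```python
import math
from types import SimpleNamespace


class ToyArray:
    """Hypothetical 1-D float array; stands in for a compliant library's array."""

    def __init__(self, data):
        self.data = [float(v) for v in data]

    def __array_namespace__(self, api_version=None):
        # Entry point defined by the standard: returns the namespace
        # holding the functions that operate on this array type.
        return toy_xp


def _mean(x):
    # Toy simplification: returns a one-element array
    # (a real implementation would return a 0-d array).
    return ToyArray([math.fsum(x.data) / len(x.data)])


def _subtract(x, y):
    # Toy broadcasting: a one-element operand is stretched to match the other.
    n = max(len(x.data), len(y.data))
    xs = x.data * n if len(x.data) == 1 else x.data
    ys = y.data * n if len(y.data) == 1 else y.data
    return ToyArray([a - b for a, b in zip(xs, ys)])


# Hypothetical namespace exposing two standard-named functions.
toy_xp = SimpleNamespace(mean=_mean, subtract=_subtract)


def center(x):
    """Array-agnostic consumer: never names a concrete array library."""
    xp = x.__array_namespace__()
    return xp.subtract(x, xp.mean(x))


centered = center(ToyArray([1.0, 2.0, 3.0]))
print(centered.data)
```

The `center` function only touches the array through its namespace, so in principle any array type that exposes `__array_namespace__` with these two functions could be passed in unchanged.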

EwoutH commented 8 months ago

@awni thanks for looking into this! Do you have anything to share as of now?

This is currently the second highest upvoted open issue by the way!

awni commented 8 months ago

thanks for looking into this! Do you have anything to share as of now?

I'm sorry we haven't spent much time on this. I would say the following, though: we intend to follow NumPy. So by transitivity if NumPy follows the Python API standard so will MLX. So we'd happily take PRs that move MLX to be more in line with NumPy, and we will continue to work towards that ourselves.

lucascolley commented 8 months ago

if NumPy follows the Python API standard so will MLX

Watch this space: https://numpy.org/neps/nep-0056-array-api-main-namespace.html

ogrisel commented 8 months ago

we intend to follow NumPy. So by transitivity if NumPy follows the Python API standard so will MLX

That's good to know, but there are things in MLX that do not exist in NumPy, such as the stream parameter.

This is related to the device concept in the Array API.

Note, however, that the concept of stream/queue control was deemed out-of-scope for the Array API. Instead, array constructors typically accept a device= kwarg and arrays expose a .device attribute.
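As a rough illustration of that convention, the sketch below shows a creation function taking `device=` and an array exposing `.device`. `ToyArray`, this `asarray`, and the `"cpu"`/`"gpu"` labels are hypothetical and only mimic the shape of the API; nothing here is MLX or NumPy code, and no real device placement happens.

```python
class ToyArray:
    """Hypothetical array type that records where its data lives."""

    def __init__(self, data, device):
        self.data = list(data)
        self.device = device  # attribute name used by the standard


def asarray(obj, *, device=None):
    # Creation functions take an optional `device=` keyword;
    # None means "use the library's default device" ("cpu" in this toy).
    return ToyArray(obj, device if device is not None else "cpu")


def zeros_on_same_device(x):
    # Array-agnostic helper: allocate new data where `x` already lives,
    # without any notion of streams or queues.
    return asarray([0.0] * len(x.data), device=x.device)


a = asarray([1.0, 2.0, 3.0], device="gpu")  # "gpu" is just a label here
b = zeros_on_same_device(a)
print(b.device, b.data)
```

Passing `device=x.device` through is the usual way for library-agnostic code to keep results next to their inputs; stream selection, as in MLX, has no counterpart in this pattern.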

Also note that NumPy will remain a CPU-only library for the foreseeable future, but it might expose device keywords/attributes for the sake of compatibility.

ogrisel commented 8 months ago

Similarly, NumPy is fundamentally an eager-evaluation library, while the MLX library is lazy by default and exposes an explicit mx.eval function that can be used as a synchronization primitive.

On the other hand, the Array API supports laziness but does not (yet) specify a library-agnostic evaluation function as part of the standard. The only standard ways to explicitly trigger evaluation are dunder methods such as __float__, __bool__ and the like.
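A toy model may help show what "evaluation triggered by a dunder method" means. `LazyScalar` below is a hypothetical class written for this illustration; it is not MLX's implementation and does not use `mx.eval`. Operations only record work, and the computation runs when Python calls `__float__` or `__bool__`.

```python
class LazyScalar:
    """Hypothetical lazy value: records a computation and runs it on demand."""

    def __init__(self, thunk):
        self._thunk = thunk        # zero-argument callable producing the value
        self._value = None
        self._evaluated = False

    def __add__(self, other):
        # Only builds a new deferred computation; nothing is computed yet.
        return LazyScalar(lambda: float(self) + float(other))

    def __mul__(self, other):
        return LazyScalar(lambda: float(self) * float(other))

    def _force(self):
        # Run the recorded computation once and cache the result.
        if not self._evaluated:
            self._value = self._thunk()
            self._evaluated = True
        return self._value

    def __float__(self):
        # Converting to a Python float is an implicit synchronization point.
        return float(self._force())

    def __bool__(self):
        return bool(self._force())


x = LazyScalar(lambda: 2.0)
y = (x + 3.0) * x                 # graph built, still unevaluated
was_lazy = not y._evaluated
result = float(y)                 # __float__ triggers evaluation: (2 + 3) * 2
print(was_lazy, result, y._evaluated)
```

In array-agnostic code this means that something as innocuous as `if x > 0:` or `float(x)` can force a lazy library to compute, which is why the lack of an explicit, library-neutral evaluation function matters.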