data-apis / array-api-compat

Compatibility layer for common array libraries to support the Array API
https://data-apis.org/array-api-compat/
MIT License

Add basic JAX support #84

Closed by asmeurer 7 months ago

asmeurer commented 7 months ago

Unlike the other supported libraries, JAX needs no wrapper module here: its array API support lives entirely in JAX itself, in the jax.experimental.array_api submodule. So the only change in this PR is adding JAX support to the helper functions. This also means that we do not run array-api-tests on JAX.

This also makes the is_numpy_array, is_cupy_array, etc. helper functions public, since I noticed someone using them on GitHub and they seem generally useful.

asmeurer commented 7 months ago

Oh, we should probably also add a basic test for the device helpers in https://github.com/data-apis/array-api-compat/blob/main/tests/test_common.py

asmeurer commented 7 months ago

Apparently we've been supporting "cpu" as a special host device here (https://github.com/data-apis/array-api-compat/pull/40). This is still being discussed for the standard https://github.com/data-apis/array-api/issues/626.

I'm actually unsure if it's a good idea for us to be supporting that here given that we haven't really agreed about it in the standard. But if we did want to support it for JAX, how would we? I don't see how to actually access jax.CpuDevice.

jakevdp commented 7 months ago

jax.devices('cpu')[0] would return the first CPU device (if available).
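If the "cpu" special case had been kept for JAX, jakevdp's suggestion could be wired up roughly as below (the PR ultimately removed that logic). This is a sketch, not library code: resolve_device is a hypothetical helper, and its list_devices argument is assumed to behave like jax.devices(backend), returning the list of devices for the named backend.

```python
from typing import Callable, Sequence


def resolve_device(spec: object, list_devices: Callable[[str], Sequence[object]]) -> object:
    """Hypothetical helper: map the special string "cpu" to a device object.

    list_devices is assumed to take a backend name and return that
    backend's devices. Any other spec, such as an existing device object,
    is returned unchanged.
    """
    if spec == "cpu":
        devices = list_devices("cpu")
        if not devices:
            raise ValueError("no CPU device available")
        # Same choice as jax.devices('cpu')[0]: take the first CPU device.
        return devices[0]
    return spec
```

With JAX installed, one would call resolve_device("cpu", jax.devices) and pass the result to jax.device_put(x, device); that call is not exercised in this sketch.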

asmeurer commented 7 months ago

Removed the "cpu" device logic. It looks like that isn't going to be part of the standard. I opened https://github.com/data-apis/array-api-compat/issues/86 about removing it for cupy as well.