asmeurer closed this 7 months ago
Oh, we should also probably add a basic test for the device helpers in https://github.com/data-apis/array-api-compat/blob/main/tests/test_common.py
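A basic test could look something like the sketch below. It assumes the top-level `device` and `to_device` helpers exported by array_api_compat and uses NumPy as the backend. The test name is hypothetical, and the early-return dependency check stands in for `pytest.importorskip`, which a real test file would more likely use.

```python
import importlib.util


def test_device_roundtrip():
    """Moving an array to the device it already lives on should be a no-op.

    Hypothetical test name; sketch only. Returns early (a stand-in for a
    skip) when NumPy or array_api_compat is not installed.
    """
    if any(importlib.util.find_spec(m) is None for m in ("numpy", "array_api_compat")):
        return

    import numpy as np
    from array_api_compat import device, to_device

    x = np.asarray([1, 2, 3])
    dev = device(x)          # the device the array currently lives on
    y = to_device(x, dev)    # "move" it to that same device
    assert device(y) == dev
    assert np.array_equal(x, y)
```

Checking a round trip to the array's own device keeps the test independent of what string or object each backend uses to represent its devices.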
Apparently we've been supporting `"cpu"` as a special host device here (https://github.com/data-apis/array-api-compat/pull/40). This is still being discussed for the standard in https://github.com/data-apis/array-api/issues/626.
I'm actually unsure whether it's a good idea for us to support that here, given that we haven't really agreed on it in the standard. But if we did want to support it for JAX, how would we? I don't see how to actually access `jax.CpuDevice`.
`jax.devices('cpu')[0]` would return the first CPU device (if available).
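A small helper along those lines might look like the sketch below. `first_cpu_device` is a hypothetical name, and returning `None` when JAX is not installed or its CPU backend has been disabled (for example through the `JAX_PLATFORMS` environment variable) is an assumption about what "if available" should mean here.

```python
import importlib.util


def first_cpu_device():
    """Return the first JAX CPU device, or None if it is not available.

    Hypothetical helper (sketch). jax.devices("cpu") raises RuntimeError
    when the CPU backend cannot be used in this process, e.g. if
    JAX_PLATFORMS excludes it.
    """
    if importlib.util.find_spec("jax") is None:
        return None  # JAX is not installed at all
    import jax

    try:
        cpus = jax.devices("cpu")
    except RuntimeError:
        return None  # CPU backend unavailable
    return cpus[0] if cpus else None
```

When a device is returned, `jax.device_put(x, dev)` can be used to place an array on it.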
Removed the "cpu" device logic. It looks like that isn't going to be part of the standard. I opened https://github.com/data-apis/array-api-compat/issues/86 about removing it for cupy as well.
Unlike the other libraries, JAX's array API support lives entirely in JAX itself, in the `jax.experimental.array_api` submodule, so the only thing this PR does is add JAX support to the helper functions. This also means that we do not run array-api-tests on JAX.
This PR also makes the various `is_numpy_array`, `is_cupy_array`, etc. functions public, as I noticed someone using them on GitHub and they seem like they could be useful.
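To illustrate why helpers like these are cheap to call, here is a minimal sketch of one pattern such a check can follow. The function names are hypothetical, and this is not claimed to be array_api_compat's exact code. The idea is to consult `sys.modules` first, since an object cannot be a NumPy or JAX array unless that library has already been imported, and only then do the `isinstance` check. That way the helper never imports a heavy library as a side effect.

```python
import sys


def is_numpy_array_sketch(x) -> bool:
    """Hypothetical helper: True if x is a NumPy ndarray.

    If NumPy was never imported, x cannot be a NumPy array, so we answer
    False without importing NumPy.
    """
    if "numpy" not in sys.modules:
        return False
    import numpy as np  # already imported, so this is just a lookup
    return isinstance(x, np.ndarray)


def is_jax_array_sketch(x) -> bool:
    """Hypothetical helper: True if x is a jax.Array (same pattern)."""
    if "jax" not in sys.modules:
        return False
    import jax
    return isinstance(x, jax.Array)
```

The same shape works for other backends, e.g. checking for `cupy.ndarray` or `torch.Tensor` after confirming that `cupy` or `torch` is in `sys.modules`.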