pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

Improved duck array wrapping #9798

Open slevang opened 3 days ago

slevang commented 3 days ago

Companion to #9776.

My attempts to use xarray-wrapped jax arrays turned up a bunch of limitations in our duck array wrapping. Jax is probably the worst case among the major duck array types, because it implements neither __array_function__ nor __array_ufunc__ to intercept numpy calls. Cupy, sparse, and probably others do, so calling np.func(cp.ndarray) generally works fine, but with jax the same call usually converts your data to numpy.
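A toy illustration of the difference, assuming only NumPy. ConvertsToNumpy is a hypothetical stand-in for an array that offers __array__ but no dispatch hooks (the situation described above for jax), while Dispatches implements __array_function__ from NEP 18. Both class names are made up for this sketch.

```python
import numpy as np


class ConvertsToNumpy:
    """Toy duck array with __array__ but no __array_function__/__array_ufunc__."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array__(self, dtype=None, copy=None):
        # NumPy calls this to coerce the object into a plain ndarray
        return self.data if dtype is None else self.data.astype(dtype)


class Dispatches:
    """Toy duck array that intercepts NumPy functions via __array_function__."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_function__(self, func, types, args, kwargs):
        # Unwrap our own instances, run NumPy's implementation, re-wrap the result
        unwrapped = [a.data if isinstance(a, Dispatches) else a for a in args]
        return Dispatches(func(*unwrapped, **kwargs))


x = [1.0, 2.0, 3.0]
lost = np.pad(ConvertsToNumpy(x), 1)  # coerced via __array__: a plain np.ndarray
kept = np.pad(Dispatches(x), 1)       # routed through __array_function__: still Dispatches
```

np.pad is used here because it coerces its input with np.asarray; functions that first look for a same-named method on the object (np.cumsum, for example) can behave differently for arrays that happen to define that method.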

A lot of this was just grepping around for hard-coded np. calls that we can easily dispatch to their duck_array_ops versions. I imagine some of these changes could be controversial, because a number of them (notably nanmean/std/var, pad, quantile, einsum, cross) aren't technically part of the array API standard. See https://github.com/pydata/xarray/issues/8834#issuecomment-1998340481 for discussion about the nan-skipping aggregations.
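For backends that lack the nan-skipping aggregations, one possible fallback is to compose them from functions the array API standard does define. A minimal sketch, assuming xp is any namespace providing isnan, where, zeros_like, ones_like and sum (NumPy is used to exercise it); nanmean_fallback is a hypothetical name, not an xarray function.

```python
import numpy as np


def nanmean_fallback(x, xp, axis=None):
    """Hypothetical nanmean built only from array API standard functions.

    NaNs are replaced by zero in the sum and excluded from the count;
    an all-NaN input yields nan (0/0).
    """
    valid = ~xp.isnan(x)
    total = xp.sum(xp.where(valid, x, xp.zeros_like(x)), axis=axis)
    count = xp.sum(xp.where(valid, xp.ones_like(x), xp.zeros_like(x)), axis=axis)
    return total / count


result = nanmean_fallback(np.array([1.0, np.nan, 3.0]), np)
```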

It feels like a much better user experience, though, to try our best to dispatch to the correct backend and error if the function isn't implemented, rather than blindly calling numpy. And in practice, all major array backends that are feasible to wrap today (cupy, sparse, jax, cubed, arkouda, ...?) implement all of these functions.
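The dispatch-or-error behaviour described above can be sketched as follows. get_namespace and dispatch are hypothetical helpers, not xarray's duck_array_ops API; the sketch assumes the __array_namespace__ protocol from the array API standard and falls back to NumPy for arrays that don't define it.

```python
import types

import numpy as np


def get_namespace(x):
    """Return the array's own namespace if it advertises one, else NumPy."""
    ns_getter = getattr(x, "__array_namespace__", None)
    return ns_getter() if ns_getter is not None else np


def dispatch(name, x, *args, **kwargs):
    """Hypothetical helper: call `name` from x's namespace, or fail loudly.

    Raising here, instead of silently calling np.<name>, avoids converting
    a non-NumPy duck array to NumPy behind the user's back.
    """
    xp = get_namespace(x)
    func = getattr(xp, name, None)
    if func is None:
        module = getattr(xp, "__name__", repr(xp))
        raise NotImplementedError(f"{module} does not implement {name!r}")
    return func(x, *args, **kwargs)


class EmptyNamespaceArray:
    """Toy array whose namespace implements nothing at all."""

    def __array_namespace__(self, api_version=None):
        return types.ModuleType("emptylib")
```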

To test, I went down the API list and ran most functions to check that we maintain proper wrapping. Before the changes here, I had 28 jax failures and 9 cupy failures; all (non-xfailed) ones now pass.
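A harness for this kind of sweep might look like the following minimal sketch. find_unwrapped is a hypothetical name; it assumes each call is a one-argument callable and that wrapper results expose their underlying array as .data, as xarray objects do.

```python
from collections.abc import Callable
from typing import Any


def find_unwrapped(obj: Any, calls: dict[str, Callable[[Any], Any]], array_type: type) -> list[str]:
    """Hypothetical harness: run each call on obj and report which ones lose array_type.

    Calls that raise are reported too, tagged with the exception type.
    """
    failures = []
    for name, call in calls.items():
        try:
            result = call(obj)
        except Exception as exc:
            failures.append(f"{name} ({type(exc).__name__})")
            continue
        # Check the result itself first: np.ndarray also has a .data attribute
        # (a memoryview), so only unwrap .data when the result is not array_type.
        data = result if isinstance(result, array_type) else getattr(result, "data", result)
        if not isinstance(data, array_type):
            failures.append(name)
    return failures
```

With xarray and jax installed, a call might look like find_unwrapped(da, {"mean": lambda d: d.mean(), "cumsum": lambda d: d.cumsum("x")}, jax.Array), where da is a DataArray wrapping a jax array with a dimension named "x"; that usage is illustrative.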

Basically everything works except the interp/missing methods, which have a lot of specialized code, and a few odds and ends like polyfit and rank.

lucascolley commented 3 days ago

See also https://github.com/data-apis/array-api/issues/621 for the higher-level discussion of nan reductions.

slevang commented 1 day ago

I looked back at cupy-xarray. Now that this duck array stuff all works pretty well, I'm wondering how people feel about adding official DataArray/Dataset methods analogous to this, but in a generalized way. Something like:

from collections.abc import Callable
from typing import Self

def as_array_type(self, asarray: Callable, **kwargs) -> Self:
    # e.g. ds.as_array_type(jnp.asarray, device="gpu")
    ...

def is_array_type(self, array: type) -> bool:
    # e.g. ds.is_array_type(cp.ndarray)
    ...
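One way these could be implemented, written as free functions rather than methods so the sketch runs on its own. It assumes the object exposes .data and .copy(data=...), as xarray.DataArray does; ToyDataArray is a hypothetical stand-in so the example runs without xarray, and a Dataset version would need to apply the conversion to each variable.

```python
from collections.abc import Callable
from typing import Any

import numpy as np


def as_array_type(obj: Any, asarray: Callable[..., Any], **kwargs: Any) -> Any:
    """Sketch: return a copy of obj whose underlying array was converted by asarray."""
    return obj.copy(data=asarray(obj.data, **kwargs))


def is_array_type(obj: Any, array_type: type) -> bool:
    """Sketch: True if the wrapped array is an instance of array_type."""
    return isinstance(obj.data, array_type)


class ToyDataArray:
    """Minimal stand-in for xarray.DataArray (only .data, .dims and copy(data=...))."""

    def __init__(self, data: Any, dims: tuple[str, ...]):
        self.data = data
        self.dims = dims

    def copy(self, data: Any = None) -> "ToyDataArray":
        return ToyDataArray(self.data if data is None else data, self.dims)


da = ToyDataArray([[1, 2], [3, 4]], dims=("x", "y"))
converted = as_array_type(da, np.asarray, dtype=np.float32)
```

Reusing copy(data=...) keeps dims (and, for a real DataArray, coords and attrs) attached while swapping only the underlying array.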
dcherian commented 1 day ago

> methods analogous to this, but in a generalized way.

:+1: I thought I'd seen an array API method for converting between compliant array types, but I can't find it now.

slevang commented 1 day ago

Compliant namespaces should now implement from_dlpack, which is generally the recommended conversion protocol. So I suppose we could instead pass the namespace and hard-code it to use that:

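from types import ModuleType
from typing import Self

def to_namespace(self, xp: ModuleType, **kwargs) -> Self:
    # Rewrap the underlying array in the target namespace via DLPack
    return self.copy(data=xp.from_dlpack(self.data, **kwargs))

# e.g. da.to_namespace(cp)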
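DLPack only requires the producer to expose __dlpack__ and __dlpack_device__, after which a consumer's from_dlpack can import it, typically without a copy. A minimal CPU-only sketch with NumPy as the consumer; Producer is a hypothetical class standing in for a non-NumPy array.

```python
import numpy as np


class Producer:
    """Hypothetical producer array exposing only the two DLPack dunder methods."""

    def __init__(self, data):
        self._data = np.ascontiguousarray(data, dtype=np.float64)

    def __dlpack__(self, stream=None, **kwargs):
        # Delegate capsule creation (stream, max_version, copy, ...) to NumPy
        return self._data.__dlpack__(stream=stream, **kwargs)

    def __dlpack_device__(self):
        # (device_type, device_id); for a NumPy array this is the CPU device
        return self._data.__dlpack_device__()


src = Producer([1.0, 2.0, 3.0])
out = np.from_dlpack(src)  # consumer side, i.e. xp.from_dlpack(...) with xp = numpy
```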

But this actually doesn't work for cupy, since they're quite stringent about implicit device transfers:

TypeError: CPU arrays cannot be directly imported to CuPy. Use `cupy.array(numpy.from_dlpack(input))` instead.
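Following the error message's own suggestion, a to_namespace-style helper could fall back to an explicit host copy when a namespace refuses the direct import. A minimal sketch under assumptions: from_dlpack_or_stage is a hypothetical name, it treats TypeError as the refusal signal because that is what the quoted CuPy error raises, and a SimpleNamespace stands in for CuPy so the example runs on CPU.

```python
import types

import numpy as np


def from_dlpack_or_stage(x, xp):
    """Hypothetical helper: import x into namespace xp, staging through NumPy if refused.

    Mirrors the advice cupy.array(numpy.from_dlpack(input)) from the error above.
    """
    try:
        return xp.from_dlpack(x)
    except TypeError:
        # Make the host copy explicit, then let the target namespace convert it
        return xp.asarray(np.from_dlpack(x))


def _refuse_cpu(x):
    raise TypeError("CPU arrays cannot be directly imported")


# Stand-in for a strict namespace such as CuPy
strict = types.SimpleNamespace(from_dlpack=_refuse_cpu, asarray=np.asarray)
staged = from_dlpack_or_stage(np.arange(3.0), strict)
```

Catching every TypeError could also hide unrelated errors, so a real implementation would probably want a narrower check.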

Also, sparse doesn't implement from_dlpack at all, and it wouldn't be clear whether you want a COO, DOK, or other format.

slevang commented 12 hours ago

I think this is in pretty good shape now, except for the question of whether to attempt any of this integration testing in CI. That could also be punted to xarray-array-testing (@keewis @TomNicholas).