keewis opened this issue 1 year ago
Ping @leofang in case you have thoughts?
Sorry that I missed the ping, Jacob, but I'd need more context for making any suggestions/answers 😅 Is the question about why CuPy wouldn't return scalars?
The issue is that here: https://github.com/pydata/xarray/blob/d4db16699f30ad1dc3e6861601247abf4ac96567/xarray/core/duck_array_ops.py#L193-L206 we try to convert everything to the same dtype, casting numpy and python scalars to arrays. The latter is important because e.g. `numpy.array_api.where` only accepts arrays as input.

However, detecting `cupy` beneath (multiple) layers of duck arrays is not easy, which means that, for example, passing a `pint(dask(cupy))` array together with scalars will currently cast the scalars to 0-d `numpy` arrays, while passing a bare `cupy` array instead will result in 0-d `cupy` arrays.
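For example (a sketch assuming `cupy` is importable; both `pint` and `dask` support wrapping other duck arrays):

```python
import cupy as cp
import dask.array as da
import pint

ureg = pint.UnitRegistry()

# pint(dask(cupy)): the cupy array sits two layers down
nested = ureg.Quantity(da.from_array(cp.arange(4), chunks=2), "m")

# a plain isinstance check only sees the outermost pint layer
isinstance(nested, cp.ndarray)  # False
```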
My naive suggestion was to treat `np.int64(0)` and `np.array(0, dtype="int64")` the same, where at the moment the latter would fail for the same reason as `np.array([0], dtype="int64")`.
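To illustrate the asymmetry (a sketch assuming `cupy` is available; this reflects my understanding of current `cupy` behavior):

```python
import numpy as np
import cupy as cp

cond = cp.asarray([True, False])
x = cp.asarray([1, 2])

cp.where(cond, x, np.int64(0))  # works: numpy scalars behave like python scalars
cp.where(cond, x, np.array(0))  # TypeError: cupy refuses implicit host->device transfer
```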
Thanks, Justus, for expanding on this. It sounds to me like the question is "how do we cast dtypes when multiple array libraries participate in the same computation?", and I am not sure I am knowledgeable enough to comment.

From the array API point of view, long long ago we decided that this is UB (undefined behavior), meaning it's completely up to each library to decide what to do. You can raise, or come up with a special rule that you can make sense of.

It sounds like Xarray has some machinery to deal with this situation, but you'd rather not keep special-casing certain array libraries? Am I understanding it right?
There are two things that happen in `as_shared_dtype` (which may not be good design, and we should probably consider splitting it into `as_shared_dtype` and `as_compatible_arrays` or something): first, we cast everything to an array, then we decide on a common dtype and cast everything to that. The latter could easily be done using `numpy` scalars, which as far as I can tell would be supported by most array libraries, including `cupy`. However, the reason we need to cast to arrays is that the array API (i.e. `__array_namespace__`) does not allow using scalars of any type, e.g. in `np.array_api.where` (this is important for libraries that don't implement `__array_ufunc__` / `__array_function__`). To clarify, what we're trying to support is something like
```python
import numpy.array_api as np

np.where(cond, cupy_array, python_scalar)
```
which (intentionally?) does not work.
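For context, a minimal sketch of what that hypothetical split could look like (the names come from this discussion, not from actual xarray code, and the `xp` parameter glosses over the hard part, namely picking the right namespace for nested arrays):

```python
import numpy as np

def as_compatible_arrays(*scalars_or_arrays, xp=np):
    # wrap python/numpy scalars as 0-d arrays, since array API functions
    # such as xp.where() only accept arrays
    return tuple(xp.asarray(x) for x in scalars_or_arrays)

def as_shared_dtype(*scalars_or_arrays, xp=np):
    # cast the converted arrays to a common dtype
    arrays = as_compatible_arrays(*scalars_or_arrays, xp=xp)
    dtype = np.result_type(*arrays)
    return tuple(x.astype(dtype) for x in arrays)
```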
At the moment, `as_shared_dtype` (or, really, the hypothetical `as_compatible_arrays`) correctly casts `python_scalar` to a 0-d `cupy.array` for the example above, but if we were to replace `cupy_array` with `chunked_cupy_array` or `chunked_cupy_array_with_units`, the special casing for `cupy` stops working and scalars will be cast to 0-d `numpy.array`. Conceptually, I tend to think of 0-d arrays as equivalent to scalars, hence the suggestion to have `cupy` treat `numpy` scalars and 0-d `numpy.array` the same way (I don't follow the array API closely enough to know whether that was already discussed and rejected).
So really, my question is: how do we support python scalars for libraries that only implement `__array_namespace__`, given that no longer doing so would be a major breaking change?
Of course, I would prefer removing the special casing for specific libraries, but I wouldn't be opposed to keeping the existing one. I guess as a short-term fix we could just pull `_meta` out of duck dask arrays and determine the common array type from that (the downside is that we'd add another special case for `dask`, which in another PR we're actually trying to remove).

As a long-term fix I guess we'd need to revive the stalled nested duck array discussion.
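A sketch of that short-term fix (`_underlying_array_type` is a made-up name; `is_duck_dask_array` is xarray's existing helper):

```python
def _underlying_array_type(x):
    # dask arrays carry a zero-size example of the wrapped array in ._meta,
    # which exposes e.g. cupy hiding beneath the dask layer
    if is_duck_dask_array(x):
        x = x._meta
    return type(x)
```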
> So really, my question is: how do we support python scalars for libraries that only implement `__array_namespace__`, given that no longer doing so would be a major breaking change?
I was considering this question for SciPy (xref scipy#18286) this week, and I think I'm happy with this strategy: accept python scalars only in combination with `numpy.ndarray` inputs. What that results in is an API that's backwards-compatible for numpy and array-like usage, and much stricter when using other array libraries. That strictness to me is a good thing, because it's easy for users to be explicit, with `xp.asarray(a_scalar)` giving you a 0-D array of the correct type (add `dtype=x.dtype` to make sure dtypes match, if that matters).
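From the caller's side that strategy looks like this (using `numpy.array_api` as a stand-in for any array API namespace; it shipped as an experimental module in numpy 1.22-1.26):

```python
import numpy.array_api as xp

x = xp.asarray([1.0, 2.0])

# xp.where(x > 1.5, x, 0.5) would raise: python scalars are not accepted,
# so wrap the scalar explicitly instead
s = xp.asarray(0.5, dtype=x.dtype)
xp.where(x > 1.5, x, s)
```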
So, after thinking about this for (quite) some time, it appears that one way or another we need to figure out the appropriate base array type of the nested array (regardless of whether or not we disallow passing python scalars to the xarray API... though since that is a breaking change, I don't think we will do it).
I've come up with a (recursive) way of extracting the nesting structure in keewis/nested-duck-arrays, which we should be able to use to figure out the leaf array type and keep the current hack until we figure out how to resolve the issue without it.
Would this be an acceptable, if temporary, fix for #9195? Modified code in `as_shared_dtype`:
```python
array_type_cupy = array_type("cupy")

# temporary fix
import nested_duck_arrays.dask

def _maybe_cupy(seq):
    return any(
        isinstance(x, array_type_cupy)
        or (is_duck_dask_array(x) and x.__duck_arrays__()[-1].__module__ == "cupy")
        for x in seq
    )

# if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays):
if _maybe_cupy(scalars_or_arrays):
    # end of fix
    import cupy as cp
```
I'd go with something like

```python
import nested_duck_arrays.dask
import nested_duck_arrays

...

if any(nested_duck_arrays.first_layer(x) is array_type_cupy for x in scalars_or_arrays):
    import cupy as cp
```

and add `nested_duck_arrays.first_layer` (with maybe a better name?), which would fall back to returning the type of `x` in case `x` is not a duck array (I'd be happy to relatively quickly release that to PyPI / conda-forge).
We'll need to think about what to do if `nested_duck_arrays` is not installed, though... something like this, maybe?
```python
try:
    from nested_duck_arrays import first_layer
except ImportError:

    def first_layer(x):
        return type(x)
```
Also, we'll probably want to push the contents of `nested_duck_arrays.dask` into `dask.array` itself.
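If dask adopted that, the hook could look roughly like this (a sketch of the `__duck_arrays__` protocol used in the snippet above, not actual dask API):

```python
# hypothetical method on dask.array.Array
def __duck_arrays__(self):
    # report this layer first, then recurse into the wrapped ._meta array
    inner = getattr(self._meta, "__duck_arrays__", lambda: (type(self._meta),))()
    return (type(self), *inner)
```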
I tried to run `where` with chunked `cupy` arrays:

this works:

this fails:

this works again:

And other methods like `fillna` show similar behavior.

I think the reason is that this: https://github.com/pydata/xarray/blob/d4db16699f30ad1dc3e6861601247abf4ac96567/xarray/core/duck_array_ops.py#L195 is not sufficient to detect `cupy` beneath other layers of duck arrays (most commonly `dask`, `pint`, or both). In this specific case we could extend the condition to also match chunked `cupy` arrays (like `arr.cupy.is_cupy` does, but using `is_duck_dask_array`), but this will still break for other duckarray layers or if `dask` is not involved, and we're also in the process of moving away from special-casing `dask`. So short of asking `cupy` to treat 0-d arrays like scalars, I'm not sure how to fix this.

cc @jacobtomlinson
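For illustration, a minimal reproduction of the detection problem (assumes a GPU with `cupy` installed; `dask` supports `cupy`-backed chunks):

```python
import cupy as cp
import dask.array as da

chunked = da.from_array(cp.arange(4), chunks=2)  # dask wrapping cupy

# the isinstance check in as_shared_dtype only sees the dask layer
isinstance(chunked, cp.ndarray)    # False
# but dask keeps a zero-size example of the inner array in ._meta
type(chunked._meta) is cp.ndarray  # True
```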