yt87 opened this issue 3 weeks ago
Hi @yt87, thanks for the bug report with the minimal example. I can reproduce the same TypeError locally too.

My initial impression is that this might require some fixes on the dask side. I've seen some similar issues before, e.g. https://github.com/dask/dask/issues/9315, which might point to some ufunc operations not working with a CuPy backend yet. If I run the following without dask chunks, it seems to work:
```python
ds = xr.DataArray([1, 2, cupy.nan]).as_cupy().sum(min_count=1)
print(ds)
# <xarray.DataArray ()> Size: 8B
# array(3.)
```
Do you need to do the `sum(min_count=1)` operation using dask chunks? If you put the `.compute()` before `.sum()`, this would work:
```python
ds = xr.DataArray([1, 2, np.nan]).chunk(dim_0=1).as_cupy().compute()
ds.sum(min_count=1)
```
Though that assumes your actual array isn't too large to fit in GPU memory. If it is too large, you might need to parallelize the `sum` computation without dask, doing it manually as a workaround.
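The manual approach could be sketched in stdlib-only Python like this (plain lists stand in for the chunked data, and `chunked_nansum` is a hypothetical helper name; with cupy you would move each chunk to the device before accumulating):

```python
import math

def chunked_nansum(chunks, min_count=1):
    """Sum chunk-by-chunk, mimicking xarray's min_count semantics:
    return nan unless at least min_count non-nan values were seen."""
    total = 0.0
    valid = 0
    for chunk in chunks:  # with cupy: chunk = cupy.asarray(chunk) here
        for x in chunk:
            if not math.isnan(x):
                total += x
                valid += 1
    return total if valid >= min_count else math.nan

print(chunked_nansum([[1.0, 2.0], [math.nan]]))  # 3.0
print(chunked_nansum([[math.nan], [math.nan]]))  # nan
```

The point is only that the chunk loop and the `min_count` bookkeeping are simple enough to hand-roll if dask can't be used here.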
It is `np.nan` that causes the error:
```python
print(xr.DataArray([1, 2, 3]).chunk(dim_0=1).as_cupy().sum(min_count=1))
# <xarray.DataArray 'asarray-60e02971486c2931a91b659a5bdc6e30' ()> Size: 8B
# dask.array<sum-aggregate, shape=(), dtype=int64, chunksize=(), chunktype=cupy.ndarray>

print(xr.DataArray([1, 2, np.nan]).chunk(dim_0=1).as_cupy().sum(min_count=1))
# <xarray.DataArray 'asarray-75d4a7ce4023e88c4c5563214cb235b4' ()> Size: 8B
# dask.array<where, shape=(), dtype=float64, chunksize=(), chunktype=numpy.ndarray>
```
My use case: I have a large TYX array (~12 GB). For some time values all the data is missing, and I want the sum to return nan; when there is some data available, I want the actual value. Maybe an option is to drop the all-missing time frames beforehand.
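Dropping the all-missing frames beforehand could look like this stdlib-only sketch (each inner list stands in for one time frame; `drop_all_missing_frames` is a hypothetical name for illustration):

```python
import math

def drop_all_missing_frames(frames):
    # Keep only the time frames that contain at least one non-nan value,
    # so the later sum(min_count=1) never sees an all-missing frame.
    return [f for f in frames if any(not math.isnan(v) for v in f)]

frames = [[1.0, math.nan], [math.nan, math.nan], [2.0, 3.0]]
print(drop_all_missing_frames(frames))  # [[1.0, nan], [2.0, 3.0]]
```

With real data the same filter would be an `isnull().all(...)` mask over the T dimension rather than a Python loop.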
This fix seems to work for me, in file `duck_array_ops.py`, function `as_shared_dtype`:
```python
# Avoid calling array_type("cupy") repeatedly in the any check
array_type_cupy = array_type("cupy")
# GT fix
import cupy as cp
# if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays):
if any(
    isinstance(x, array_type_cupy)
    or (is_duck_dask_array(x) and isinstance(x._meta, cp.ndarray))
    for x in scalars_or_arrays
):
    xp = cp
elif xp is None:
    xp = get_array_namespace(scalars_or_arrays)
```
What happens is that `np.nan` is converted to an `np.ndarray`, see my previous message. This causes a failure when `compute` is called, which expects cupy arrays.
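The distinction the patch relies on can be illustrated with toy stand-ins (`FakeCupyNdarray`, `FakeDaskArray`, and `backend_is_cupy` are all hypothetical names for illustration only): a chunked dask array is not itself a cupy array, so its backend has to be read off `_meta`.

```python
class FakeCupyNdarray:
    """Toy stand-in for cupy.ndarray."""

class FakeDaskArray:
    """Toy stand-in for a dask array: the data lives in chunks,
    while _meta records the underlying array type."""
    def __init__(self, meta):
        self._meta = meta

def backend_is_cupy(x):
    # Mirror the patch: an object counts as cupy-backed if it is a
    # cupy array itself, or a dask array whose _meta is a cupy array.
    if isinstance(x, FakeCupyNdarray):
        return True
    return hasattr(x, "_meta") and isinstance(x._meta, FakeCupyNdarray)

print(backend_is_cupy(FakeDaskArray(FakeCupyNdarray())))  # True
print(backend_is_cupy(FakeDaskArray(object())))           # False
```

An `isinstance(x, array_type_cupy)` check alone misses the first case, which is why the dask-wrapped cupy array falls through to the numpy path.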
This is not the right fix, though: it makes xarray depend on cupy. There must be a better way.
We do have to handle this in xarray. Can you open an issue there, please?
The first line works, the second raises an exception.
Versions:
Same thing with numpy 2.0.0.