xarray-contrib / cupy-xarray

Interface for using cupy in xarray, providing convenience accessors.
https://cupy-xarray.readthedocs.io
Apache License 2.0

sum(min_count=1) raises an exception #52

Open yt87 opened 3 weeks ago

yt87 commented 3 weeks ago

The first line works; the second raises an exception:

import numpy as np
import xarray as xr
import cupy_xarray

xr.DataArray([1, 2, np.nan]).chunk(dim_0=1).as_cupy().sum().compute()
xr.DataArray([1, 2, np.nan]).chunk(dim_0=1).as_cupy().sum(min_count=1).compute()

xarray.DataArray 'asarray-75d4a7ce4023e88c4c5563214cb235b4'
array(3.)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[5], line 6
      3 import cupy_xarray
      5 xr.DataArray([1, 2, np.nan]).chunk(dim_0=1).as_cupy().sum().compute()
----> 6 xr.DataArray([1, 2, np.nan]).chunk(dim_0=1).as_cupy().sum(min_count=1).compute()

File ~/mambaforge/envs/cupy-seaice/lib/python3.12/site-packages/xarray/core/dataarray.py:1179, in DataArray.compute(self, **kwargs)
   1154 """Manually trigger loading of this array's data from disk or a
   1155 remote source into memory and return a new array.
   1156 
   (...)
   1176 dask.compute
   1177 """
   1178 new = self.copy(deep=False)
-> 1179 return new.load(**kwargs)

File ~/mambaforge/envs/cupy-seaice/lib/python3.12/site-packages/xarray/core/dataarray.py:1147, in DataArray.load(self, **kwargs)
   1127 def load(self, **kwargs) -> Self:
   1128     """Manually trigger loading of this array's data from disk or a
   1129     remote source into memory and return this array.
   1130 
   (...)
   1145     dask.compute
   1146     """
-> 1147     ds = self._to_temp_dataset().load(**kwargs)
   1148     new = self._from_temp_dataset(ds)
   1149     self._variable = new._variable

File ~/mambaforge/envs/cupy-seaice/lib/python3.12/site-packages/xarray/core/dataset.py:863, in Dataset.load(self, **kwargs)
    860 chunkmanager = get_chunked_array_type(*lazy_data.values())
    862 # evaluate all the chunked arrays simultaneously
--> 863 evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute(
    864     *lazy_data.values(), **kwargs
    865 )
    867 for k, data in zip(lazy_data, evaluated_data):
    868     self.variables[k].data = data

File ~/mambaforge/envs/cupy-seaice/lib/python3.12/site-packages/xarray/namedarray/daskmanager.py:86, in DaskManager.compute(self, *data, **kwargs)
     81 def compute(
     82     self, *data: Any, **kwargs: Any
     83 ) -> tuple[np.ndarray[Any, _DType_co], ...]:
     84     from dask.array import compute
---> 86     return compute(*data, **kwargs)

File ~/mambaforge/envs/cupy-seaice/lib/python3.12/site-packages/dask/base.py:662, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
    659     postcomputes.append(x.__dask_postcompute__())
    661 with shorten_traceback():
--> 662     results = schedule(dsk, keys, **kwargs)
    664 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])

File cupy/_core/core.pyx:1717, in cupy._core.core._ndarray_base.__array_function__()

File ~/mambaforge/envs/cupy-seaice/lib/python3.12/site-packages/cupy/_sorting/search.py:211, in where(condition, x, y)
    209 if fusion._is_fusing():
    210     return fusion._call_ufunc(_where_ufunc, condition, x, y)
--> 211 return _where_ufunc(condition.astype('?'), x, y)

File cupy/_core/_kernel.pyx:1286, in cupy._core._kernel.ufunc.__call__()

File cupy/_core/_kernel.pyx:159, in cupy._core._kernel._preprocess_args()

File cupy/_core/_kernel.pyx:145, in cupy._core._kernel._preprocess_arg()

TypeError: Unsupported type <class 'numpy.ndarray'>

Versions:

xr.__version__           # '2024.6.0'
np.__version__           # '1.26.4'
cupy_xarray.__version__  # '0.1.3+9.g7fc3df5'

The same error occurs with numpy 2.0.0.

weiji14 commented 3 weeks ago

Hi @yt87, thanks for the bug report with the minimal example. I can reproduce the same TypeError on my end locally too.

My initial impression is that this might require some fixes on the dask side. I've seen similar issues before (e.g. https://github.com/dask/dask/issues/9315) that point to some ufunc operations not yet working with a CuPy backend. If I run the following line without dask chunks, it seems to work:

import cupy
import xarray as xr
import cupy_xarray  # registers the .as_cupy() accessor

ds = xr.DataArray([1, 2, cupy.nan]).as_cupy().sum(min_count=1)
print(ds)
# <xarray.DataArray ()> Size: 8B
# array(3.)

Do you need to do the sum(min_count=1) operation using dask chunks? If you put the .compute() before .sum(), this would work:

ds = xr.DataArray([1, 2, np.nan]).chunk(dim_0=1).as_cupy().compute()
ds.sum(min_count=1)

This assumes that your actual array fits in GPU memory. If it doesn't, a workaround is to parallelize the sum computation yourself without dask.
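Here is a minimal sketch of doing the reduction by hand, block by block, without dask. Several things are assumptions. The function name `chunked_sum_min_count` is hypothetical. NumPy stands in for CuPy so the example runs anywhere. Blocks are taken along the first axis. The function keeps a running sum and a running count of non-NaN values and applies the `min_count` rule once at the end. As a result, the NaN fill is a plain Python scalar and never passes through a mixed NumPy/CuPy `where`.

```python
import numpy as np


def chunked_sum_min_count(a, chunk_size, min_count=1):
    """Sum all elements of `a`, skipping NaNs, `chunk_size` rows at a time.

    Returns NaN when fewer than `min_count` non-NaN values were seen, which is
    the documented meaning of xarray's `sum(min_count=...)`.
    """
    total = 0.0
    count = 0
    for start in range(0, a.shape[0], chunk_size):
        # Hypothetically, this is where a block would be moved to the GPU,
        # e.g. with cupy.asarray(block); here it stays a NumPy array.
        block = a[start:start + chunk_size]
        total += float(np.nansum(block))
        count += int(np.count_nonzero(~np.isnan(block)))
    # Apply min_count once, on plain Python scalars.
    return total if count >= min_count else float("nan")


print(chunked_sum_min_count(np.array([1.0, 2.0, np.nan]), chunk_size=1))  # 3.0
print(chunked_sum_min_count(np.array([np.nan, np.nan]), chunk_size=1))    # nan
```

If the blocks are CuPy arrays instead, the same `np.nansum`, `np.isnan` and `np.count_nonzero` calls should dispatch to CuPy through NumPy's dispatch protocols. That is an assumption worth checking on your setup.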

yt87 commented 3 weeks ago

It is np.nan that causes the error:

print(xr.DataArray([1, 2, 3]).chunk(dim_0=1).as_cupy().sum(min_count=1))

<xarray.DataArray 'asarray-60e02971486c2931a91b659a5bdc6e30' ()> Size: 8B
dask.array<sum-aggregate, shape=(), dtype=int64, chunksize=(), chunktype=cupy.ndarray>

print(xr.DataArray([1, 2, np.nan]).chunk(dim_0=1).as_cupy().sum(min_count=1))

<xarray.DataArray 'asarray-75d4a7ce4023e88c4c5563214cb235b4' ()> Size: 8B
dask.array<where, shape=(), dtype=float64, chunksize=(), chunktype=numpy.ndarray>

My use case: I have a large (~12 GB) array with dimensions time, y, x (TYX). For some time steps all the data is missing, and for those I want the sum to return nan; when some data is available, I want the actual value. One option might be to drop the all-missing time frames beforehand.
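For reference, here is a NumPy-only sketch of the per-time-step `min_count=1` behaviour this use case needs, plus the "drop the all-missing frames" alternative. The function names are hypothetical and this is not xarray's implementation. The final masking step does mirror the `where` shown in the lazy repr above. In the reported bug, this kind of `np.nan` fill ended up as a `numpy.ndarray` inside `cupy.where`.

```python
import numpy as np


def sum_over_space_min_count(a, min_count=1):
    """Sum a (time, y, x) array over y and x, skipping NaNs.

    Time steps with fewer than `min_count` valid values become NaN.
    """
    totals = np.nansum(a, axis=(1, 2))
    valid = np.count_nonzero(~np.isnan(a), axis=(1, 2))
    # Masking step: frames with fewer than min_count valid values get the NaN fill.
    return np.where(valid >= min_count, totals, np.nan)


def drop_all_missing_frames(a):
    """Keep only the time steps that contain at least one non-NaN value."""
    has_data = np.any(~np.isnan(a), axis=(1, 2))
    return a[has_data]


a = np.array([
    [[1.0, np.nan], [2.0, 3.0]],           # t=0: some data
    [[np.nan, np.nan], [np.nan, np.nan]],  # t=1: all missing
])
print(sum_over_space_min_count(a))        # [ 6. nan]
print(drop_all_missing_frames(a).shape)   # (1, 2, 2)
```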

yt87 commented 2 weeks ago

This fix seems to work for me, in function as_shared_dtype in duck_array_ops.py:

    # Avoid calling array_type("cupy") repeatedly in the any check
    array_type_cupy = array_type("cupy")
    # GT fix: also treat dask arrays whose chunks (x._meta) are CuPy arrays as CuPy
    import cupy as cp
    # original: if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays):
    if any(
        isinstance(x, array_type_cupy)
        or (is_duck_dask_array(x) and type(x._meta) == cp.ndarray)
        for x in scalars_or_arrays
    ):
        xp = cp
    elif xp is None:
        xp = get_array_namespace(scalars_or_arrays)

What happens is that np.nan is converted to an np.ndarray (see chunktype=numpy.ndarray in my previous message). This causes the failure when compute is called, because cupy.where expects CuPy arrays. This is not the right fix, since it makes xarray depend on cupy. There must be a better way.
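Here is one possible direction that avoids importing cupy, sketched under several assumptions. The helper `fill_like` is hypothetical. It relies on NumPy's `like=` argument (NEP 35, NumPy 1.20 and later) to create the fill value with the same array type as a reference array. It also assumes the chunk library (here CuPy) honours that dispatch. For a dask array it dispatches on `._meta`, the zero-size array of the chunk type that the patch above also inspects. This is not xarray's code.

```python
import numpy as np


def fill_like(value, ref):
    """Hypothetical helper: wrap a scalar fill value in the same array type as `ref`.

    For a dask array, `ref._meta` is a zero-size array of the chunk type
    (numpy.ndarray, cupy.ndarray, ...), so dispatch on that instead of `ref`.
    Relies on NEP 35 `like=` dispatch, so no `import cupy` is needed here.
    """
    meta = getattr(ref, "_meta", ref)
    return np.asarray(value, like=meta)


fill = fill_like(np.nan, np.arange(3.0))
print(type(fill).__name__, fill.ndim)  # ndarray 0
```

With a NumPy reference this is just `np.asarray(np.nan)`. The intent is that a CuPy (or CuPy-backed dask) reference would yield a CuPy scalar array instead. That behaviour was not tested here.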

dcherian commented 2 weeks ago

We do have to handle this in xarray. Can you open an issue there, please?