pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0
3.64k stars 1.09k forks source link

`as_shared_dtype` converts scalars to 0d `numpy` arrays if chunked `cupy` is involved #7721

Open keewis opened 1 year ago

keewis commented 1 year ago

I tried to run `where` with chunked cupy arrays:

In [1]: import xarray as xr
   ...: import cupy
   ...: import dask.array as da
   ...: 
   ...: arr = xr.DataArray(cupy.arange(4), dims="x")
   ...: mask = xr.DataArray(cupy.array([False, True, True, False]), dims="x")

this works:

In [2]: arr.where(mask)
Out[2]: 
<xarray.DataArray (x: 4)>
array([nan,  1.,  2., nan])
Dimensions without coordinates: x

this fails:

In [4]: arr.chunk().where(mask).compute()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[4], line 1
----> 1 arr.chunk().where(mask).compute()

File ~/repos/xarray/xarray/core/dataarray.py:1095, in DataArray.compute(self, **kwargs)
   1076 """Manually trigger loading of this array's data from disk or a
   1077 remote source into memory and return a new array. The original is
   1078 left unaltered.
   (...)
   1092 dask.compute
   1093 """
   1094 new = self.copy(deep=False)
-> 1095 return new.load(**kwargs)

File ~/repos/xarray/xarray/core/dataarray.py:1069, in DataArray.load(self, **kwargs)
   1051 def load(self: T_DataArray, **kwargs) -> T_DataArray:
   1052     """Manually trigger loading of this array's data from disk or a
   1053     remote source into memory and return this array.
   1054 
   (...)
   1067     dask.compute
   1068     """
-> 1069     ds = self._to_temp_dataset().load(**kwargs)
   1070     new = self._from_temp_dataset(ds)
   1071     self._variable = new._variable

File ~/repos/xarray/xarray/core/dataset.py:752, in Dataset.load(self, **kwargs)
    749 import dask.array as da
    751 # evaluate all the dask arrays simultaneously
--> 752 evaluated_data = da.compute(*lazy_data.values(), **kwargs)
    754 for k, data in zip(lazy_data, evaluated_data):
    755     self.variables[k].data = data

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/base.py:600, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
    597     keys.append(x.__dask_keys__())
    598     postcomputes.append(x.__dask_postcompute__())
--> 600 results = schedule(dsk, keys, **kwargs)
    601 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/threaded.py:89, in get(dsk, keys, cache, num_workers, pool, **kwargs)
     86     elif isinstance(pool, multiprocessing.pool.Pool):
     87         pool = MultiprocessingPoolExecutor(pool)
---> 89 results = get_async(
     90     pool.submit,
     91     pool._max_workers,
     92     dsk,
     93     keys,
     94     cache=cache,
     95     get_id=_thread_get_id,
     96     pack_exception=pack_exception,
     97     **kwargs,
     98 )
    100 # Cleanup pools associated to dead threads
    101 with pools_lock:

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/local.py:511, in get_async(submit, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, chunksize, **kwargs)
    509         _execute_task(task, data)  # Re-execute locally
    510     else:
--> 511         raise_exception(exc, tb)
    512 res, worker_id = loads(res_info)
    513 state["cache"][key] = res

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/local.py:319, in reraise(exc, tb)
    317 if exc.__traceback__ is not tb:
    318     raise exc.with_traceback(tb)
--> 319 raise exc

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/local.py:224, in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    222 try:
    223     task, data = loads(task_info)
--> 224     result = _execute_task(task, data)
    225     id = get_id()
    226     result = dumps((result, id))

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/optimization.py:990, in SubgraphCallable.__call__(self, *args)
    988 if not len(args) == len(self.inkeys):
    989     raise ValueError("Expected %d args, got %d" % (len(self.inkeys), len(args)))
--> 990 return core.get(self.dsk, self.outkey, dict(zip(self.inkeys, args)))

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/core.py:149, in get(dsk, out, cache)
    147 for key in toposort(dsk):
    148     task = dsk[key]
--> 149     result = _execute_task(task, cache)
    150     cache[key] = result
    151 result = _execute_task(out, cache)

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File <__array_function__ internals>:180, in where(*args, **kwargs)

File cupy/_core/core.pyx:1723, in cupy._core.core._ndarray_base.__array_function__()

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/cupy/_sorting/search.py:211, in where(condition, x, y)
    209 if fusion._is_fusing():
    210     return fusion._call_ufunc(_where_ufunc, condition, x, y)
--> 211 return _where_ufunc(condition.astype('?'), x, y)

File cupy/_core/_kernel.pyx:1287, in cupy._core._kernel.ufunc.__call__()

File cupy/_core/_kernel.pyx:160, in cupy._core._kernel._preprocess_args()

File cupy/_core/_kernel.pyx:146, in cupy._core._kernel._preprocess_arg()

TypeError: Unsupported type <class 'numpy.ndarray'>

this works again:

In [7]: arr.chunk().where(mask.chunk(), cupy.array(cupy.nan)).compute()
Out[7]: 
<xarray.DataArray (x: 4)>
array([nan,  1.,  2., nan])
Dimensions without coordinates: x

And other methods like fillna show similar behavior.

I think the reason is that this: https://github.com/pydata/xarray/blob/d4db16699f30ad1dc3e6861601247abf4ac96567/xarray/core/duck_array_ops.py#L195 is not sufficient to detect cupy beneath other layers of duckarrays (most commonly dask, pint, or both). In this specific case we could extend the condition to also match chunked cupy arrays (like arr.cupy.is_cupy does, but using is_duck_dask_array), but this will still break for other duckarray layers or if dask is not involved, and we're also in the process of moving away from special-casing dask. So short of asking cupy to treat 0d arrays like scalars I'm not sure how to fix this.

cc @jacobtomlinson

jacobtomlinson commented 1 year ago

Ping @leofang in case you have thoughts?

leofang commented 1 year ago

Sorry that I missed the ping, Jacob, but I'd need more context for making any suggestions/answers 😅 Is the question about why CuPy wouldn't return scalars?

keewis commented 1 year ago

The issue is that here: https://github.com/pydata/xarray/blob/d4db16699f30ad1dc3e6861601247abf4ac96567/xarray/core/duck_array_ops.py#L193-L206 we try to convert everything to the same dtype, casting numpy and python scalars to an array. The latter is important, because e.g. numpy.array_api.where only accepts arrays as input.

However, detecting cupy beneath (multiple) layers of duckarrays is not easy, which means that for example passing a pint(dask(cupy)) array together with scalars will currently cast the scalars to 0-d numpy arrays, while passing a cupy array instead will result in 0-d cupy arrays.

My naive suggestion was to treat np.int64(0) and np.array(0, dtype="int64") the same, where at the moment the latter would fail for the same reason as np.array([0], dtype="int64").

leofang commented 1 year ago

Thanks, Justus, for expanding on this. It sounds to me the question is "how do we cast dtypes when multiple array libraries are participating in the same computation?" and I am not sure I am knowledgeable enough to make any comment.

From the array API point of view, long long ago we decided that this is UB (undefined behavior), meaning it's completely up to each library to decide what to do. You can raise or come up with a special rule that you can make sense of.

It sounds like Xarray has some machinery to deal with this situation, but you'd rather prefer to not keep special-casing for a certain array library? Am I understanding it right?

keewis commented 1 year ago

There are two things that happen in as_shared_dtype (which may not be good design, and we should probably consider splitting it into as_shared_dtype and as_compatible_arrays or something): first, we cast everything to an array, then decide on a common dtype and cast everything to that.

The latter could easily be done by using numpy scalars, which as far as I can tell would be supported by most array libraries, including cupy. However, the reason we need to cast to arrays is that the array API (i.e. __array_namespace__) does not allow using scalars of any type, e.g. np.array_api.where (this is important for libraries that don't implement __array_ufunc__ / __array_function__). To clarify, what we're trying to support is something like

import numpy.array_api as np
np.where(cond, cupy_array, python_scalar)

which (intentionally?) does not work.

At the moment, as_shared_dtype (or, really, the hypothetical as_compatible_arrays) correctly casts python_scalar to a 0-d cupy.array for the example above, but if we were to replace cupy_array with chunked_cupy_array or chunked_cupy_array_with_units, the special casing for cupy stops working and scalars will be cast to 0-d numpy.array. Conceptually, I tend to think of 0-d arrays as equivalent to scalars, hence the suggestion to have cupy treat numpy scalars and 0-d numpy.array the same way (I don't follow the array api closely enough to know whether that was already discussed and rejected).

So really, my question is: how do we support python scalars for libraries that only implement __array_namespace__, given that ceasing to do so would be a major breaking change?

Of course, I would prefer removing the special casing for specific libraries, but I wouldn't be opposed to keeping the existing one. I guess as a short-term fix we could just pull _meta out of duck dask arrays and determine the common array type for that (the downside is that we'd add another special case for dask, which in another PR we're actually trying to remove).

As a long-term fix I guess we'd need to revive the stalled nested duck array discussion.

rgommers commented 1 year ago

So really, my question is: how do we support python scalars for libraries that only implement __array_namespace__, given that stopping to do so would be a major breaking change?

I was considering this question for SciPy (xref scipy#18286) this week, and I think I'm happy with this strategy:

  1. Cast all "array-like" inputs like Python scalars, lists/sequences, and generators, to numpy.ndarray.
  2. Require "same array type" input, forbid mixing numpy-cupy, numpy-pytorch, cupy-pytorch, etc. - this will raise an exception
  3. As a result, cupy-pyscalar and pytorch-pyscalar will also raise an exception.

What that results in is an API that's backwards-compatible for numpy and array-like usage, and much stricter when using other array libraries. That strictness to me is a good thing, because:

keewis commented 12 months ago

So, after thinking about this for (quite) some time, it appears that one way or another we need to figure out the appropriate base array type of the nested array (regardless of whether or not we disallow passing python scalars to the xarray API... though since it is a breaking change I don't think we will do that).

I've come up with a (recursive) way of extracting the nesting structure in keewis/nested-duck-arrays, which we should be able to use to figure out the leaf array type and keep the current hack until we figure out how to resolve the issue without it.

yt87 commented 5 months ago

Would this be an acceptable, if temporary, fix for #9195? Modified code in as_shared_dtype:

     array_type_cupy = array_type("cupy")
     # temporary fix
     import nested_duck_arrays.dask

     def _maybe_cupy(seq):
         return any(isinstance(x, array_type_cupy) or 
             is_duck_dask_array(x) and x.__duck_arrays__()[-1].__module__ == 'cupy'
             for x in seq)

     # if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays):
     if _maybe_cupy(scalars_or_arrays):
         # end of fix
         import cupy as cp
keewis commented 5 months ago

I'd go with something like

import nested_duck_arrays.dask
import nested_duck_arrays
...

if any(nested_duck_arrays.first_layer(x) is array_type_cupy for x in scalars_or_arrays):
    import cupy as cp

and add nested_duck_arrays.first_layer (with maybe a better name?) which would have a fallback of returning a 1-tuple containing type of x in case x is not a duck array (I'd be happy to relatively quickly release that to PyPI / conda-forge).

We'll need to think about what to do if nested_duck_arrays is not installed, though... something like this, maybe?

try:
    from nested_duck_arrays import first_layer
except ImportError:
    def first_layer(x):
        return type(x)

Also, we'll probably want to push the contents of nested_duck_arrays.dask to dask.array.