makepath / xarray-spatial

Raster-based Spatial Analytics for Python
https://xarray-spatial.readthedocs.io/
MIT License

Unable to use Dask Arrays originating from Xarray Datasets in Crosstab #777

Open GregoryPetrochenkov-NOAA opened 1 year ago

GregoryPetrochenkov-NOAA commented 1 year ago

I am using xrspatial's crosstab and running into an issue when the data originally comes from an Xarray Dataset. For this example I load an extremely small TIFF named 'arb.tif' (any raster would do). When I load the data like so, there is no problem:

import rioxarray as rxr
from xrspatial.zonal import crosstab

# Open the same raster twice as chunked (Dask-backed) DataArrays
c = rxr.open_rasterio('arb.tif', mask_and_scale=True, chunks="auto")
b = rxr.open_rasterio('arb.tif', mask_and_scale=True, chunks="auto")

# Drop the band dimension to get 2D arrays
carr = c.sel(band=1, drop=True)
barr = b.sel(band=1, drop=True)

df = crosstab(carr, barr)

However, when I use band_as_variable=True to load the raster data as an Xarray Dataset, I get a strange indexing problem:

import rioxarray as rxr
from xrspatial.zonal import crosstab

# Open as Datasets, with each band exposed as a separate data variable
c = rxr.open_rasterio('arb.tif', mask_and_scale=True,
                      band_as_variable=True, chunks="auto")
b = rxr.open_rasterio('arb.tif', mask_and_scale=True,
                      band_as_variable=True, chunks="auto")

# Convert back to DataArrays and select the first band
carr = c.to_array().sel(variable='band_1', drop=True)
barr = b.to_array().sel(variable='band_1', drop=True)

df = crosstab(carr, barr)

I get the following stack trace:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[18], line 12
      9 carr = c.to_array().sel(variable='band_1', drop=True)
     10 barr = b.to_array().sel(variable='band_1', drop=True)
---> 12 df = crosstab(carr, barr)

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/xrspatial/zonal.py:1065, in crosstab(zones, values, zone_ids, cat_ids, layer, agg, nodata_values)
   1053 unique_cats, cat_ids = _find_cats(values, cat_ids, nodata_values)
   1055 mapper = ArrayTypeFunctionMapping(
   1056     numpy_func=_crosstab_numpy,
   1057     dask_func=_crosstab_dask_numpy,
   (...)
   1063     ),
   1064 )
-> 1065 crosstab_df = mapper(values)(
   1066     zones.data, values.data,
   1067     zone_ids, unique_cats, cat_ids, nodata_values, agg
   1068 )
   1069 return crosstab_df

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/xrspatial/zonal.py:840, in _crosstab_dask_numpy(zones, values, zone_ids, unique_cats, cat_ids, nodata_values, agg)
    829 crosstab_by_block = [
    830     _single_chunk_crosstab(
    831         z, v, unique_zones, zone_ids,
   (...)
    834     for z, v in zip(zones_blocks, values_blocks)
    835 ]
    837 crosstab_df = _crosstab_df_dask(
    838     crosstab_by_block, zone_ids, cat_ids, agg
    839 )
--> 840 return dd.from_delayed(crosstab_df)

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/dask/dataframe/io/io.py:624, in from_delayed(dfs, meta, divisions, prefix, verify_meta)
    621         raise TypeError("Expected Delayed object, got %s" % type(item).__name__)
    623 if meta is None:
--> 624     meta = delayed(make_meta)(dfs[0]).compute()
    625 else:
    626     meta = make_meta(meta)

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/dask/base.py:314, in DaskMethodsMixin.compute(self, **kwargs)
    290 def compute(self, **kwargs):
    291     """Compute this dask collection
    292 
    293     This turns a lazy Dask collection into its in-memory equivalent.
   (...)
    312     dask.base.compute
    313     """
--> 314     (result,) = compute(self, traverse=False, **kwargs)
    315     return result

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/dask/base.py:599, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
    596     keys.append(x.__dask_keys__())
    597     postcomputes.append(x.__dask_postcompute__())
--> 599 results = schedule(dsk, keys, **kwargs)
    600 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/dask/threaded.py:89, in get(dsk, keys, cache, num_workers, pool, **kwargs)
     86     elif isinstance(pool, multiprocessing.pool.Pool):
     87         pool = MultiprocessingPoolExecutor(pool)
---> 89 results = get_async(
     90     pool.submit,
     91     pool._max_workers,
     92     dsk,
     93     keys,
     94     cache=cache,
     95     get_id=_thread_get_id,
     96     pack_exception=pack_exception,
     97     **kwargs,
     98 )
    100 # Cleanup pools associated to dead threads
    101 with pools_lock:

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/dask/local.py:511, in get_async(submit, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, chunksize, **kwargs)
    509         _execute_task(task, data)  # Re-execute locally
    510     else:
--> 511         raise_exception(exc, tb)
    512 res, worker_id = loads(res_info)
    513 state["cache"][key] = res

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/dask/local.py:319, in reraise(exc, tb)
    317 if exc.__traceback__ is not tb:
    318     raise exc.with_traceback(tb)
--> 319 raise exc

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/dask/local.py:224, in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    222 try:
    223     task, data = loads(task_info)
--> 224     result = _execute_task(task, data)
    225     id = get_id()
    226     result = dumps((result, id))

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/dask/core.py:119, in <genexpr>(.0)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/dask/core.py:119, in <genexpr>(.0)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/dask/core.py:119, in <genexpr>(.0)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/dask/optimization.py:990, in SubgraphCallable.__call__(self, *args)
    988 if not len(args) == len(self.inkeys):
    989     raise ValueError("Expected %d args, got %d" % (len(self.inkeys), len(args)))
--> 990 return core.get(self.dsk, self.outkey, dict(zip(self.inkeys, args)))

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/dask/core.py:149, in get(dsk, out, cache)
    147 for key in toposort(dsk):
    148     task = dsk[key]
--> 149     result = _execute_task(task, cache)
    150     cache[key] = result
    151 result = _execute_task(out, cache)

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/dask/array/core.py:125, in getter(a, b, asarray, lock)
    120     # Below we special-case `np.matrix` to force a conversion to
    121     # `np.ndarray` and preserve original Dask behavior for `getter`,
    122     # as for all purposes `np.matrix` is array-like and thus
    123     # `is_arraylike` evaluates to `True` in that case.
    124     if asarray and (not is_arraylike(c) or isinstance(c, np.matrix)):
--> 125         c = np.asarray(c)
    126 finally:
    127     if lock:

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/xarray/core/indexing.py:464, in ImplicitToExplicitIndexingAdapter.__array__(self, dtype)
    463 def __array__(self, dtype=None):
--> 464     return np.asarray(self.array, dtype=dtype)

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/xarray/core/indexing.py:628, in CopyOnWriteArray.__array__(self, dtype)
    627 def __array__(self, dtype=None):
--> 628     return np.asarray(self.array, dtype=dtype)

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/xarray/core/indexing.py:529, in LazilyIndexedArray.__array__(self, dtype)
    527 def __array__(self, dtype=None):
    528     array = as_indexable(self.array)
--> 529     return np.asarray(array[self.key], dtype=None)

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/rioxarray/_io.py:423, in RasterioArrayWrapper.__getitem__(self, key)
    422 def __getitem__(self, key):
--> 423     return indexing.explicit_indexing_adapter(
    424         key, self.shape, indexing.IndexingSupport.OUTER, self._getitem
    425     )

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/xarray/core/indexing.py:820, in explicit_indexing_adapter(key, shape, indexing_support, raw_indexing_method)
    798 """Support explicit indexing by delegating to a raw indexing method.
    799 
    800 Outer and/or vectorized indexers are supported by indexing a second time
   (...)
    817 Indexing result, in the form of a duck numpy-array.
    818 """
    819 raw_key, numpy_indices = decompose_indexer(key, shape, indexing_support)
--> 820 result = raw_indexing_method(raw_key.tuple)
    821 if numpy_indices.tuple:
    822     # index the loaded np.ndarray
    823     result = NumpyIndexingAdapter(np.asarray(result))[numpy_indices]

File ~/miniconda3/envs/testpython38/lib/python3.8/site-packages/rioxarray/_io.py:420, in RasterioArrayWrapper._getitem(self, key)
    418 if squeeze_axis:
    419     out = np.squeeze(out, axis=squeeze_axis)
--> 420 return out[np_inds]

IndexError: too many indices for array: array is 2-dimensional, but 3 were indexed

Expected behavior: no IndexError; crosstab should return a DataFrame.
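In case it helps narrow things down, a variant I have not verified is to select the data variable directly from the Dataset instead of going through to_array(); I don't know whether it sidesteps the error (the 'band_1' name is what band_as_variable=True produces on my end):

import rioxarray as rxr
from xrspatial.zonal import crosstab

c = rxr.open_rasterio('arb.tif', mask_and_scale=True,
                      band_as_variable=True, chunks="auto")
b = rxr.open_rasterio('arb.tif', mask_and_scale=True,
                      band_as_variable=True, chunks="auto")

# Pull the data variable directly, which stays a 2D Dask-backed DataArray
carr = c['band_1']
barr = b['band_1']

df = crosstab(carr, barr)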

Please advise.

brendancol commented 1 year ago

Hey @GregoryPetrochenkov-NOAA, sorry for the late reply here; we are going to bump up the priority on reproducing and testing this issue.

brendancol commented 1 year ago

@GregoryPetrochenkov-NOAA

A potential solution is to ensure that the arrays carr and barr are 2-dimensional before they are passed to crosstab, for example by calling .squeeze() on them.

Here is a suggested modification to the code:

import rioxarray as rxr
from xrspatial.zonal import crosstab

c = rxr.open_rasterio('arb.tif', mask_and_scale=True,
                      band_as_variable=True, chunks="auto")
b = rxr.open_rasterio('arb.tif', mask_and_scale=True,
                      band_as_variable=True, chunks="auto")

# squeeze() drops any remaining length-1 dimensions so crosstab sees 2D arrays
carr = c.to_array().sel(variable='band_1', drop=True).squeeze()
barr = b.to_array().sel(variable='band_1', drop=True).squeeze()

df = crosstab(carr, barr)

If the problem persists, it would be helpful to know the exact dimensions and shapes of carr and barr after they are created.
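Something along these lines (using the carr/barr from your second snippet) would tell us what crosstab is actually receiving:

# Quick diagnostics on the inputs to crosstab
print(carr.dims, carr.shape)
print(barr.dims, barr.shape)
print(type(carr.data), getattr(carr.data, 'chunks', None))
print(type(barr.data), getattr(barr.data, 'chunks', None))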

brendancol commented 1 year ago

@thuydotm Do we need better shape validation in crosstab? I'll add an issue to check.
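Roughly what I have in mind, as a sketch of a hypothetical guard (not the current implementation):

def _validate_zones_values(zones, values):
    # Hypothetical early check: catch mismatched dimensionality or shapes
    # up front instead of letting the failure surface deep inside the
    # dask graph.
    if zones.ndim != values.ndim:
        raise ValueError(
            f"zones and values must have the same number of dimensions, "
            f"got {zones.ndim} and {values.ndim}"
        )
    if zones.shape != values.shape:
        raise ValueError(
            f"zones and values must have the same shape, "
            f"got {zones.shape} and {values.shape}"
        )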