pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

Support for Scipy Sparse Arrays #7280

Open mangecoeur opened 1 year ago

mangecoeur commented 1 year ago

What did you expect to happen?

Now that SciPy is moving to support sparse n-D arrays, we would expect Xarray to work with them like any other array-like data.

What happened?

It doesn't work. When a SciPy sparse array is used as the data, Xarray wraps it in a 0-D dense array of dtype object. (There are likely more issues after this, but this was the first hurdle.)

With sparse array s:

print(s)
<4x4 sparse array of type '<class 'numpy.float64'>'
    with 4 stored elements in COOrdinate format>
print(xr.DataArray(s).data)
array(<4x4 sparse array of type '<class 'numpy.float64'>'
    with 4 stored elements in COOrdinate format>, dtype=object)
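Here is a rough sketch of why the output looks like this. It assumes that data Xarray does not recognise as an array ends up going through np.asarray, which is what the 0-D object array above suggests. NumPy cannot see inside an object that implements none of its conversion protocols, so it stores the whole object as the single element of a 0-D object array. The Opaque class is a hypothetical stand-in for the sparse array.

```python
import numpy as np


class Opaque:
    """Hypothetical stand-in for an array-like object that implements none of
    NumPy's conversion protocols (__array__, buffer protocol, sequence protocol)."""

    shape = (4, 4)
    ndim = 2


obj = Opaque()

# NumPy cannot convert the object element by element, so it wraps it whole.
wrapped = np.asarray(obj)

print(wrapped.shape)       # ()
print(wrapped.dtype)       # object
print(wrapped[()] is obj)  # True: the single element is the original object
```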

Minimal Complete Verifiable Example

import numpy as np
import xarray as xr
from scipy.sparse import coo_array

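# COO-format row/column coordinates and values of the 4 stored entries
row = np.array([0, 3, 1, 0])
col = np.array([0, 3, 1, 2])
data = np.array([4, 5.4, 7, 9.2])

s = coo_array((data, (row, col)), shape=(4, 4))
da = xr.DataArray(s)  # the sparse array ends up wrapped in a 0-D object array
print(da._repr_html_())  # raises AttributeError: item not found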

MVCE confirmation

Relevant log output

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Input In [4], in <cell line: 13>()
     11 s= coo_array((data, (row, col)), shape=(4, 4))
     12 da = xr.DataArray(s)
---> 13 print(da._repr_html_())

File ~/Scratch/.conda/envs/tessa-1/lib/python3.10/site-packages/xarray/core/common.py:167, in AbstractArray._repr_html_(self)
    165 if OPTIONS["display_style"] == "text":
    166     return f"<pre>{escape(repr(self))}</pre>"
--> 167 return formatting_html.array_repr(self)

File ~/Scratch/.conda/envs/tessa-1/lib/python3.10/site-packages/xarray/core/formatting_html.py:311, in array_repr(arr)
    303 arr_name = f"'{arr.name}'" if getattr(arr, "name", None) else ""
    305 header_components = [
    306     f"<div class='xr-obj-type'>{obj_type}</div>",
    307     f"<div class='xr-array-name'>{arr_name}</div>",
    308     format_dims(dims, indexed_dims),
    309 ]
--> 311 sections = [array_section(arr)]
    313 if hasattr(arr, "coords"):
    314     sections.append(coord_section(arr.coords))

File ~/Scratch/.conda/envs/tessa-1/lib/python3.10/site-packages/xarray/core/formatting_html.py:219, in array_section(obj)
    213 collapsed = (
    214     "checked"
    215     if _get_boolean_with_default("display_expand_data", default=True)
    216     else ""
    217 )
    218 variable = getattr(obj, "variable", obj)
--> 219 preview = escape(inline_variable_array_repr(variable, max_width=70))
    220 data_repr = short_data_repr_html(obj)
    221 data_icon = _icon("icon-database")

File ~/Scratch/.conda/envs/tessa-1/lib/python3.10/site-packages/xarray/core/formatting.py:274, in inline_variable_array_repr(var, max_width)
    272     return var._data._repr_inline_(max_width)
    273 if var._in_memory:
--> 274     return format_array_flat(var, max_width)
    275 dask_array_type = array_type("dask")
    276 if isinstance(var._data, dask_array_type):

File ~/Scratch/.conda/envs/tessa-1/lib/python3.10/site-packages/xarray/core/formatting.py:191, in format_array_flat(array, max_width)
    188 # every item will take up at least two characters, but we always want to
    189 # print at least first and last items
    190 max_possibly_relevant = min(max(array.size, 1), max(math.ceil(max_width / 2.0), 2))
--> 191 relevant_front_items = format_items(
    192     first_n_items(array, (max_possibly_relevant + 1) // 2)
    193 )
    194 relevant_back_items = format_items(last_n_items(array, max_possibly_relevant // 2))
    195 # interleave relevant front and back items:
    196 #     [a, b, c] and [y, z] -> [a, z, b, y, c]

File ~/Scratch/.conda/envs/tessa-1/lib/python3.10/site-packages/xarray/core/formatting.py:180, in format_items(x)
    177     elif np.logical_not(time_needed).all():
    178         timedelta_format = "date"
--> 180 formatted = [format_item(xi, timedelta_format) for xi in x]
    181 return formatted

File ~/Scratch/.conda/envs/tessa-1/lib/python3.10/site-packages/xarray/core/formatting.py:180, in <listcomp>(.0)
    177     elif np.logical_not(time_needed).all():
    178         timedelta_format = "date"
--> 180 formatted = [format_item(xi, timedelta_format) for xi in x]
    181 return formatted

File ~/Scratch/.conda/envs/tessa-1/lib/python3.10/site-packages/xarray/core/formatting.py:161, in format_item(x, timedelta_format, quote_strings)
    159     return repr(x) if quote_strings else x
    160 elif hasattr(x, "dtype") and np.issubdtype(x.dtype, np.floating):
--> 161     return f"{x.item():.4}"
    162 else:
    163     return str(x)

File ~/Scratch/.conda/envs/tessa-1/lib/python3.10/site-packages/scipy/sparse/_base.py:771, in spmatrix.__getattr__(self, attr)
    769     return self.getnnz()
    770 else:
--> 771     raise AttributeError(attr + " not found")

AttributeError: item not found

Anything else we need to know?

No response

Environment

INSTALLED VERSIONS
------------------
commit: None
python: 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:35:26) [GCC 10.4.0]
python-bits: 64
OS: Linux
OS-release: 5.13.0-41-generic
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: ('en_US', 'UTF-8')
libhdf5: 1.12.2
libnetcdf: 4.8.1

xarray: 2022.11.0
pandas: 1.4.3
numpy: 1.22.4
scipy: 1.9.0
netCDF4: 1.6.0
pydap: None
h5netcdf: None
h5py: 3.7.0
Nio: None
zarr: 2.12.0
cftime: 1.6.1
nc_time_axis: None
PseudoNetCDF: None
rasterio: 1.3.2
cfgrib: 0.9.10.1
iris: None
bottleneck: 1.3.5
dask: 2022.8.1
distributed: 2022.8.1
matplotlib: 3.5.3
cartopy: 0.20.3
seaborn: 0.11.2
numbagg: None
fsspec: 2022.7.1
cupy: None
pint: 0.19.2
sparse: 0.13.0
flox: None
numpy_groupies: None
setuptools: 65.2.0
pip: 22.2.2
conda: 4.14.0
pytest: 7.1.2
IPython: 8.4.0
sphinx: None

keewis commented 1 year ago

The reason that's failing is that scipy.sparse.coo_array does not implement __array_function__ or __array_namespace__. As soon as it does, I'd expect it to work.

In fact, setting s.__array_namespace__ = lambda x: x will result in xr.DataArray(s) printing as

<xarray.DataArray (dim_0: 4, dim_1: 4)>
<4x4 sparse array of type '<class 'numpy.float64'>'
    with 4 stored elements in COOrdinate format>
Dimensions without coordinates: dim_0, dim_1

which I think is what you were after? Though of course that does not mean that we can actually use it...
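To illustrate why the s.__array_namespace__ = lambda x: x trick above changes the result, here is a minimal sketch of a duck-array check of the kind described in the comment above. Both looks_like_duck_array and FakeSparse are hypothetical names for illustration; Xarray's actual check lives in its internals and may differ between versions. The sketch treats an object as an array if it has ndim, shape, and dtype plus either the NumPy protocols (__array_function__ and __array_ufunc__) or the array API hook (__array_namespace__).

```python
def looks_like_duck_array(value) -> bool:
    """Hypothetical approximation of a duck-array check (not Xarray's exact code)."""
    has_basic_attrs = all(hasattr(value, name) for name in ("ndim", "shape", "dtype"))
    implements_numpy_protocols = hasattr(value, "__array_function__") and hasattr(
        value, "__array_ufunc__"
    )
    implements_array_api = hasattr(value, "__array_namespace__")
    return has_basic_attrs and (implements_numpy_protocols or implements_array_api)


class FakeSparse:
    """Hypothetical stand-in for scipy.sparse.coo_array as of SciPy 1.9:
    it has ndim/shape/dtype but neither protocol method."""

    def __init__(self):
        self.ndim = 2
        self.shape = (4, 4)
        self.dtype = "float64"


s = FakeSparse()
print(looks_like_duck_array(s))  # False

# The workaround: give this one instance an __array_namespace__ attribute.
s.__array_namespace__ = lambda x: x
print(looks_like_duck_array(s))  # True
```

Note that the check only tests for the attribute's presence, which is why the hack makes the object pass even though the lambda is not a real array API namespace.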

mangecoeur commented 1 year ago

OK, I had assumed that SciPy already implemented the array interface directly; I will check whether an issue is already open there. Then we can gradually see what else does and doesn't work.

mangecoeur commented 1 year ago

@keewis Using your workaround, things more or less work, except that every operation of course 'loses' the __array_namespace__ attribute, so anything like slicing only half works. In addition, many indexing operations are not implemented on SciPy sparse arrays.
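A minimal sketch of why the attribute gets lost, using a hypothetical FakeSparse class in place of scipy.sparse.coo_array: an attribute assigned on one instance lives only in that instance's __dict__, and an operation such as slicing builds a brand-new object that never receives it.

```python
class FakeSparse:
    """Hypothetical stand-in for a sparse array whose operations return new objects."""

    def __init__(self, rows):
        self.rows = rows

    def __getitem__(self, key):
        # Like most array libraries, slicing constructs a new object.
        return FakeSparse(self.rows[key])


s = FakeSparse([[4.0, 0.0], [0.0, 7.0]])
s.__array_namespace__ = lambda x: x  # patches this one instance only

sliced = s[:1]
print(hasattr(s, "__array_namespace__"))       # True
print(hasattr(sliced, "__array_namespace__"))  # False: the new object never got it
```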

keewis commented 1 year ago

changing the assignment to:

s.__class__.__array_namespace__ = ...

should fix this, but it is indeed better to just wait for scipy.sparse to implement the array API (NEP 47).
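A sketch of why assigning on the class instead helps, again with a hypothetical FakeSparse class. Attribute lookup falls back to the class, so every instance, including new objects produced by slicing, now passes a hasattr check. The placeholder_namespace function is a hypothetical stand-in for the right-hand side that the comment above leaves as "..."; it provides no real array API namespace and only makes the attribute present.

```python
class FakeSparse:
    """Hypothetical stand-in for a sparse array whose operations return new objects."""

    def __init__(self, rows):
        self.rows = rows

    def __getitem__(self, key):
        return FakeSparse(self.rows[key])


def placeholder_namespace(self, *, api_version=None):
    # Hypothetical placeholder: only the attribute's presence matters in this
    # sketch; calling it does not return a real array API namespace.
    raise NotImplementedError("no real array API namespace behind this patch")


s = FakeSparse([[4.0, 0.0], [0.0, 7.0]])
s.__class__.__array_namespace__ = placeholder_namespace  # patches the class itself

sliced = s[:1]
print(hasattr(sliced, "__array_namespace__"))          # True
print(hasattr(FakeSparse([]), "__array_namespace__"))  # True: every instance is affected
```

The flip side is that this monkey-patches a third-party class globally for the whole process, which is one reason waiting for native support is preferable.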