pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

🐛 compatibility issues with ArrayAPI and SparseAPI Protocols in `namedarray` #8696

Open andersy005 opened 9 months ago

andersy005 commented 9 months ago

What happened?

I'm experiencing compatibility issues when using `_arrayfunction_or_api` and `_sparsearrayfunction_or_api` with sparse arrays of `dtype=object`. Specifically, runtime checks using `isinstance` with these protocols fail with an exception, even though the sparse array object appears to meet the necessary criteria (attributes and methods).

What did you expect to happen?

I expected that, since `COO` arrays from the `sparse` library provide the necessary attributes and methods, they would pass the `isinstance` checks against the defined protocols.
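For context, `isinstance` checks against a `typing.runtime_checkable` protocol are structural. An object passes if it has every member the protocol declares, and it does not need to inherit from the protocol. Signatures and return types are not checked. The following is a minimal sketch that uses a hypothetical, much-reduced protocol `ArrayLike`, not xarray's actual `_arrayfunction_or_api`:

```python
from typing import Any, Protocol, Tuple, runtime_checkable


@runtime_checkable
class ArrayLike(Protocol):
    # Hypothetical, much-reduced stand-in for _arrayfunction_or_api.
    @property
    def shape(self) -> Tuple[int, ...]: ...

    @property
    def dtype(self) -> Any: ...

    def __array_namespace__(self) -> Any: ...


class Duck:
    # Does not inherit from ArrayLike; it only has the right member names.
    shape = (2,)
    dtype = "float64"

    def __array_namespace__(self) -> Any:
        return None


class NotADuck:
    # Missing dtype and __array_namespace__.
    shape = (2,)


print(isinstance(Duck(), ArrayLike))  # True
print(isinstance(NotADuck(), ArrayLike))  # False
```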

In [56]: from xarray.namedarray._typing import _arrayfunction_or_api, _sparsearrayfunction_or_api

In [57]: import xarray as xr, sparse, numpy as np, sparse, pandas as pd
In [58]: x = np.random.random((10))

In [59]: x[x < 0.9] = 0

In [60]: s = sparse.COO(x)

In [61]: isinstance(s, _arrayfunction_or_api)
Out[61]: True

In [62]: s
Out[62]: <COO: shape=(10,), dtype=float64, nnz=0, fill_value=0.0>
In [63]: p = sparse.COO(np.array(['a', 'b']))

In [64]: p
Out[64]: <COO: shape=(2,), dtype=<U1, nnz=2, fill_value=>

In [65]: isinstance(s, _arrayfunction_or_api)
Out[65]: True
In [66]: q = sparse.COO(np.array(['a', 'b']).astype(object))

In [67]: isinstance(s, _arrayfunction_or_api)
Out[67]: True

In [68]: isinstance(q, _arrayfunction_or_api)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/mambaforge/envs/xarray-tests/lib/python3.9/site-packages/sparse/_umath.py:606, in _Elemwise._get_func_coords_data(self, mask)
    605 try:
--> 606     func_data = self.func(*func_args, dtype=self.dtype, **self.kwargs)
    607 except TypeError:

TypeError: real() got an unexpected keyword argument 'dtype'

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
File ~/mambaforge/envs/xarray-tests/lib/python3.9/site-packages/sparse/_umath.py:611, in _Elemwise._get_func_coords_data(self, mask)
    610     out = np.empty(func_args[0].shape, dtype=self.dtype)
--> 611     func_data = self.func(*func_args, out=out, **self.kwargs)
    612 except TypeError:

TypeError: real() got an unexpected keyword argument 'out'

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Cell In[68], line 1
----> 1 isinstance(q, _arrayfunction_or_api)

File ~/mambaforge/envs/xarray-tests/lib/python3.9/typing.py:1149, in _ProtocolMeta.__instancecheck__(cls, instance)
   1147     return True
   1148 if cls._is_protocol:
-> 1149     if all(hasattr(instance, attr) and
   1150             # All *methods* can be blocked by setting them to None.
   1151             (not callable(getattr(cls, attr, None)) or
   1152              getattr(instance, attr) is not None)
   1153             for attr in _get_protocol_attrs(cls)):
   1154         return True
   1155 return super().__instancecheck__(instance)

File ~/mambaforge/envs/xarray-tests/lib/python3.9/typing.py:1149, in <genexpr>(.0)
   1147     return True
   1148 if cls._is_protocol:
-> 1149     if all(hasattr(instance, attr) and
   1150             # All *methods* can be blocked by setting them to None.
   1151             (not callable(getattr(cls, attr, None)) or
   1152              getattr(instance, attr) is not None)
   1153             for attr in _get_protocol_attrs(cls)):
   1154         return True
   1155 return super().__instancecheck__(instance)

File ~/mambaforge/envs/xarray-tests/lib/python3.9/site-packages/sparse/_sparse_array.py:900, in SparseArray.real(self)
    875 @property
    876 def real(self):
    877     """The real part of the array.
    878 
    879     Examples
   (...)
    898     numpy.real : NumPy equivalent function.
    899     """
--> 900     return self.__array_ufunc__(np.real, "__call__", self)

File ~/mambaforge/envs/xarray-tests/lib/python3.9/site-packages/sparse/_sparse_array.py:340, in SparseArray.__array_ufunc__(self, ufunc, method, *inputs, **kwargs)
    337     inputs = tuple(reversed(inputs_transformed))
    339 if method == "__call__":
--> 340     result = elemwise(ufunc, *inputs, **kwargs)
    341 elif method == "reduce":
    342     result = SparseArray._reduce(ufunc, *inputs, **kwargs)

File ~/mambaforge/envs/xarray-tests/lib/python3.9/site-packages/sparse/_umath.py:49, in elemwise(func, *args, **kwargs)
     12 def elemwise(func, *args, **kwargs):
     13     """
     14     Apply a function to any number of arguments.
     15 
   (...)
     46     it is necessary to convert Numpy arrays to :obj:`COO` objects.
     47     """
---> 49     return _Elemwise(func, *args, **kwargs).get_result()

File ~/mambaforge/envs/xarray-tests/lib/python3.9/site-packages/sparse/_umath.py:480, in _Elemwise.get_result(self)
    477 if not any(mask):
    478     continue
--> 480 r = self._get_func_coords_data(mask)
    482 if r is not None:
    483     coords_list.append(r[0])

File ~/mambaforge/envs/xarray-tests/lib/python3.9/site-packages/sparse/_umath.py:613, in _Elemwise._get_func_coords_data(self, mask)
    611         func_data = self.func(*func_args, out=out, **self.kwargs)
    612     except TypeError:
--> 613         func_data = self.func(*func_args, **self.kwargs).astype(self.dtype)
    615 unmatched_mask = ~equivalent(func_data, self.fill_value)
    617 if not unmatched_mask.any():

ValueError: invalid literal for int() with base 10: 'a'

In [69]: q
Out[69]: <COO: shape=(2,), dtype=object, nnz=2, fill_value=0>

The failing case appears to be a well-known issue.
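The failure can be reproduced without `sparse` or xarray. Before Python 3.12, `isinstance` against a runtime-checkable protocol calls `hasattr` on each protocol member, and that call evaluates properties. `hasattr` only suppresses `AttributeError`, so any other exception raised by a property escapes from the `isinstance` call. In the traceback above, that exception is the `ValueError` from `COO.real`. Below is a minimal sketch in which `HasReal` and `BrokenReal` are hypothetical stand-ins for `_arrayfunction_or_api` and an object-dtype `COO` array:

```python
import sys
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class HasReal(Protocol):
    # Hypothetical protocol with a single property member.
    @property
    def real(self) -> Any: ...


class BrokenReal:
    # Hypothetical stand-in for sparse.COO with dtype=object: the attribute
    # exists on the class, but evaluating it raises a non-AttributeError.
    @property
    def real(self) -> Any:
        raise ValueError("invalid literal for int() with base 10: 'a'")


def check(obj: object) -> str:
    try:
        return f"isinstance -> {isinstance(obj, HasReal)}"
    except ValueError as exc:
        # Python < 3.12: hasattr() evaluated the property and only
        # AttributeError is swallowed, so the ValueError propagates.
        return f"isinstance raised ValueError: {exc}"


print(f"Python {sys.version_info[0]}.{sys.version_info[1]}: {check(BrokenReal())}")
```

On Python 3.12 and later, the same script reports `isinstance -> True`, because the property is located without being evaluated.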

Minimal Complete Verifiable Example

In [69]: q
Out[69]: <COO: shape=(2,), dtype=object, nnz=2, fill_value=0>

In [70]: n = xr.NamedArray(data=q, dims=['x'])


Relevant log output

In [71]: n.data
Out[71]: <COO: shape=(2,), dtype=object, nnz=2, fill_value=0>

In [72]: n
Out[72]: ---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/mambaforge/envs/xarray-tests/lib/python3.9/site-packages/sparse/_umath.py:606, in _Elemwise._get_func_coords_data(self, mask)
    605 try:
--> 606     func_data = self.func(*func_args, dtype=self.dtype, **self.kwargs)
    607 except TypeError:

TypeError: real() got an unexpected keyword argument 'dtype'

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
File ~/mambaforge/envs/xarray-tests/lib/python3.9/site-packages/sparse/_umath.py:611, in _Elemwise._get_func_coords_data(self, mask)
    610     out = np.empty(func_args[0].shape, dtype=self.dtype)
--> 611     func_data = self.func(*func_args, out=out, **self.kwargs)
    612 except TypeError:

TypeError: real() got an unexpected keyword argument 'out'

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
File ~/mambaforge/envs/xarray-tests/lib/python3.9/site-packages/IPython/core/formatters.py:708, in PlainTextFormatter.__call__(self, obj)
    701 stream = StringIO()
    702 printer = pretty.RepresentationPrinter(stream, self.verbose,
    703     self.max_width, self.newline,
    704     max_seq_length=self.max_seq_length,
    705     singleton_pprinters=self.singleton_printers,
    706     type_pprinters=self.type_printers,
    707     deferred_pprinters=self.deferred_printers)
--> 708 printer.pretty(obj)
    709 printer.flush()
    710 return stream.getvalue()

File ~/mambaforge/envs/xarray-tests/lib/python3.9/site-packages/IPython/lib/pretty.py:410, in RepresentationPrinter.pretty(self, obj)
    407                         return meth(obj, self, cycle)
    408                 if cls is not object \
    409                         and callable(cls.__dict__.get('__repr__')):
--> 410                     return _repr_pprint(obj, self, cycle)
    412     return _default_pprint(obj, self, cycle)
    413 finally:

File ~/mambaforge/envs/xarray-tests/lib/python3.9/site-packages/IPython/lib/pretty.py:778, in _repr_pprint(obj, p, cycle)
    776 """A pprint that just redirects to the normal repr function."""
    777 # Find newlines and replace them with p.break_()
--> 778 output = repr(obj)
    779 lines = output.splitlines()
    780 with p.group():

File ~/devel/pydata/xarray/xarray/namedarray/core.py:987, in NamedArray.__repr__(self)
    986 def __repr__(self) -> str:
--> 987     return formatting.array_repr(self)

File ~/mambaforge/envs/xarray-tests/lib/python3.9/reprlib.py:21, in recursive_repr.<locals>.decorating_function.<locals>.wrapper(self)
     19 repr_running.add(key)
     20 try:
---> 21     result = user_function(self)
     22 finally:
     23     repr_running.discard(key)

File ~/devel/pydata/xarray/xarray/core/formatting.py:665, in array_repr(arr)
    658     name_str = ""
    660 if (
    661     isinstance(arr, Variable)
    662     or _get_boolean_with_default("display_expand_data", default=True)
    663     or isinstance(arr.variable._data, MemoryCachedArray)
    664 ):
--> 665     data_repr = short_data_repr(arr)
    666 else:
    667     data_repr = inline_variable_array_repr(arr.variable, OPTIONS["display_width"])

File ~/devel/pydata/xarray/xarray/core/formatting.py:633, in short_data_repr(array)
    631 if isinstance(array, np.ndarray):
    632     return short_array_repr(array)
--> 633 elif isinstance(internal_data, _arrayfunction_or_api):
    634     return limit_lines(repr(array.data), limit=40)
    635 elif getattr(array, "_in_memory", None):

File ~/mambaforge/envs/xarray-tests/lib/python3.9/typing.py:1149, in _ProtocolMeta.__instancecheck__(cls, instance)
   1147     return True
   1148 if cls._is_protocol:
-> 1149     if all(hasattr(instance, attr) and
   1150             # All *methods* can be blocked by setting them to None.
   1151             (not callable(getattr(cls, attr, None)) or
   1152              getattr(instance, attr) is not None)
   1153             for attr in _get_protocol_attrs(cls)):
   1154         return True
   1155 return super().__instancecheck__(instance)

File ~/mambaforge/envs/xarray-tests/lib/python3.9/typing.py:1149, in <genexpr>(.0)
   1147     return True
   1148 if cls._is_protocol:
-> 1149     if all(hasattr(instance, attr) and
   1150             # All *methods* can be blocked by setting them to None.
   1151             (not callable(getattr(cls, attr, None)) or
   1152              getattr(instance, attr) is not None)
   1153             for attr in _get_protocol_attrs(cls)):
   1154         return True
   1155 return super().__instancecheck__(instance)

File ~/mambaforge/envs/xarray-tests/lib/python3.9/site-packages/sparse/_sparse_array.py:900, in SparseArray.real(self)
    875 @property
    876 def real(self):
    877     """The real part of the array.
    878 
    879     Examples
   (...)
    898     numpy.real : NumPy equivalent function.
    899     """
--> 900     return self.__array_ufunc__(np.real, "__call__", self)

File ~/mambaforge/envs/xarray-tests/lib/python3.9/site-packages/sparse/_sparse_array.py:340, in SparseArray.__array_ufunc__(self, ufunc, method, *inputs, **kwargs)
    337     inputs = tuple(reversed(inputs_transformed))
    339 if method == "__call__":
--> 340     result = elemwise(ufunc, *inputs, **kwargs)
    341 elif method == "reduce":
    342     result = SparseArray._reduce(ufunc, *inputs, **kwargs)

File ~/mambaforge/envs/xarray-tests/lib/python3.9/site-packages/sparse/_umath.py:49, in elemwise(func, *args, **kwargs)
     12 def elemwise(func, *args, **kwargs):
     13     """
     14     Apply a function to any number of arguments.
     15 
   (...)
     46     it is necessary to convert Numpy arrays to :obj:`COO` objects.
     47     """
---> 49     return _Elemwise(func, *args, **kwargs).get_result()

File ~/mambaforge/envs/xarray-tests/lib/python3.9/site-packages/sparse/_umath.py:480, in _Elemwise.get_result(self)
    477 if not any(mask):
    478     continue
--> 480 r = self._get_func_coords_data(mask)
    482 if r is not None:
    483     coords_list.append(r[0])

File ~/mambaforge/envs/xarray-tests/lib/python3.9/site-packages/sparse/_umath.py:613, in _Elemwise._get_func_coords_data(self, mask)
    611         func_data = self.func(*func_args, out=out, **self.kwargs)
    612     except TypeError:
--> 613         func_data = self.func(*func_args, **self.kwargs).astype(self.dtype)
    615 unmatched_mask = ~equivalent(func_data, self.fill_value)
    617 if not unmatched_mask.any():

ValueError: invalid literal for int() with base 10: 'a'

Anything else we need to know?

I was trying to replace uses of `is_duck_array` with the protocol runtime checks (as part of https://github.com/pydata/xarray/pull/8319). I've come to realize that these runtime checks are too rigid to accommodate the diverse behaviors of different array types, and that the function-based approach of `is_duck_array()` might be more manageable.

@Illviljan, are there any changes that could be made to both protocols without making them too complex?
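For comparison, a function-based check only probes a few cheap attributes, so it never touches `.real`. The sketch below is a simplified, hypothetical function in the spirit of `is_duck_array()`. It is not xarray's actual code, and `is_duck_array_sketch` and `ObjectDtypeLike` are illustrative names:

```python
from typing import Any


def is_duck_array_sketch(value: Any) -> bool:
    # Hypothetical, simplified duck-array test: only a few cheap attributes
    # are probed, so expensive properties such as .real are never evaluated.
    return (
        hasattr(value, "ndim")
        and hasattr(value, "shape")
        and hasattr(value, "dtype")
        and (
            (hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__"))
            or hasattr(value, "__array_namespace__")
        )
    )


class ObjectDtypeLike:
    # Hypothetical stand-in for a sparse.COO array holding dtype=object data.
    ndim = 1
    shape = (2,)
    dtype = "object"

    def __array_function__(self, func, types, args, kwargs):
        return NotImplemented

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        return NotImplemented

    @property
    def real(self) -> Any:
        # Mimics the ValueError seen in the tracebacks above.
        raise ValueError("invalid literal for int() with base 10: 'a'")


print(is_duck_array_sketch(ObjectDtypeLike()))  # True: .real is never touched
print(is_duck_array_sketch([1, 2]))  # False: a list has no .ndim
```

The flip side is that such a check says nothing about which other attributes (`.real`, `.imag`, ...) the calling code goes on to use.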

Environment

```
INSTALLED VERSIONS
------------------
commit: 541049f45edeb518a767cb3b23fa53f6045aa508
python: 3.9.18 | packaged by conda-forge | (main, Dec 23 2023, 16:35:41) [Clang 16.0.6 ]
python-bits: 64
OS: Darwin
OS-release: 23.2.0
machine: arm64
processor: arm
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: ('en_US', 'UTF-8')
libhdf5: 1.14.3
libnetcdf: 4.9.2
xarray: 2024.1.2.dev50+g78dec61f
pandas: 2.2.0
numpy: 1.26.3
scipy: 1.12.0
netCDF4: 1.6.5
pydap: installed
h5netcdf: 1.3.0
h5py: 3.10.0
Nio: None
zarr: 2.16.1
cftime: 1.6.3
nc_time_axis: 1.4.1
iris: 3.7.0
bottleneck: 1.3.7
dask: 2024.1.1
distributed: 2024.1.1
matplotlib: 3.8.2
cartopy: 0.22.0
seaborn: 0.13.2
numbagg: 0.7.1
fsspec: 2023.12.2
cupy: None
pint: 0.23
sparse: 0.15.1
flox: 0.9.0
numpy_groupies: 0.9.22
setuptools: 67.7.2
pip: 23.3.2
conda: None
pytest: 8.0.0
mypy: 1.8.0
IPython: 8.14.0
sphinx: None
```

keewis commented 9 months ago

Note that on Python 3.12 and higher, this will actually start to work. There, `isinstance` checks against runtime-checkable protocols use `inspect.getattr_static` instead of `hasattr`. `hasattr` calls `getattr`, which evaluates properties and descriptors.

See also #8605
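The difference between the two lookups can be seen directly with `inspect.getattr_static`, which finds an attribute without running descriptors. The following minimal sketch shows a presence check built on the same lookup that Python 3.12+ uses for protocol checks. The `BrokenReal` class and the `has_attr_static` helper are hypothetical:

```python
import inspect
from typing import Any


class BrokenReal:
    # Hypothetical stand-in for a sparse.COO array holding dtype=object data.
    @property
    def real(self) -> Any:
        raise ValueError("invalid literal for int() with base 10: 'a'")


def has_attr_static(obj: object, name: str) -> bool:
    """Report whether `name` exists on `obj` without running descriptors."""
    try:
        inspect.getattr_static(obj, name)
    except AttributeError:
        return False
    return True


obj = BrokenReal()
print(has_attr_static(obj, "real"))  # True: the property object is found, not called
print(has_attr_static(obj, "imag"))  # False
try:
    hasattr(obj, "real")
except ValueError:
    print("hasattr() evaluated the property and the ValueError escaped")
```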

Illviljan commented 9 months ago

Isn't this a bug in sparse?

import numpy as np
import numpy.array_api as xp  # experimental in NumPy 1.x, removed in NumPy 2.0

import sparse

a_np = np.array(['a', 'b']).astype(object)
a_sparse = sparse.COO(a_np)
print(a_np.real) # works
print(a_sparse.real) # ValueError: invalid literal for int() with base 10: 'a'

a_xp = xp.asarray(a_np) # TypeError: The array_api namespace does not support the dtype 'object'

`is_duck_array` looks like this: https://github.com/pydata/xarray/blob/c9ba2be2690564594a89eb93fb5d5c4ae7a9253c/xarray/core/utils.py#L262-L273
It is simply not true that namedarray only needs these attributes for everything to work, since namedarray uses `.real`, `.imag`, etc. The true minimum requirement is `_arrayfunction_or_api`.

Using `_arrayfunction_or_api` guarantees that mypy finds places where namedarray uses more methods than the requirement. This happens quite often, as there are many places where we assume (incorrectly?) that numpy arrays are used.
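The sketch below illustrates that argument with a hypothetical `MinimalArray` protocol standing in for `_arrayfunction_or_api`. At runtime, nothing stops `real_part` from reaching for `.real`. A static checker such as mypy, however, reports the access because `.real` is not declared on the protocol:

```python
from typing import Any, Protocol, Tuple


class MinimalArray(Protocol):
    # Hypothetical minimal requirement; stands in for _arrayfunction_or_api.
    @property
    def shape(self) -> Tuple[int, ...]: ...

    @property
    def dtype(self) -> Any: ...


def describe(data: MinimalArray) -> str:
    # Only members declared on the protocol are used, so mypy accepts this.
    return f"shape={data.shape}, dtype={data.dtype}"


def real_part(data: MinimalArray) -> Any:
    # mypy reports an attr-defined error here: .real is not part of
    # MinimalArray, so this function demands more than the protocol promises.
    return data.real


class Dummy:
    # Hypothetical object that satisfies MinimalArray and nothing more.
    shape = (2,)
    dtype = "object"


print(describe(Dummy()))  # shape=(2,), dtype=object
```

Calling `real_part(Dummy())` fails at runtime with `AttributeError`. The type checker surfaces that gap before the code runs.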