pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

Incompatibilities with bfloat16 after update to numpy 2 #9568

Open alvarosg opened 1 month ago

alvarosg commented 1 month ago

What happened?

Calling max or isnull on a DataArray with bfloat16 values raises TypeError: dtype argument must be a NumPy dtype, but it is a <class 'numpy.dtype[bfloat16]'>.

This worked fine before updating numpy to version 2. The main difference in the code seems to be that with numpy < 2, xarray uses its own implementation of isdtype, while with numpy >= 2 it relies on np.isdtype. Running import numpy as np; del np.isdtype before importing xarray makes the problem go away, which confirms this.

What did you expect to happen?

I expected the computation to succeed, just as it did prior to numpy 2.

Minimal Complete Verifiable Example

import numpy as np
# del np.isdtype  # Uncommenting this line fixes it.

import xarray
import ml_dtypes

da = xarray.DataArray(np.zeros([2], dtype=ml_dtypes.bfloat16), dims=("dim",))
da.isnull() # Or da.max("dim")


Relevant log output

TypeError                                 Traceback (most recent call last)
Cell In[1], line 5
      3 import numpy as np
      4 da = xarray.DataArray(np.zeros([2], dtype=jnp.bfloat16), dims=("dim",))
----> 5 da.isnull()

File ~/dev/xarray/xarray/core/common.py:1293, in DataWithCoords.isnull(self, keep_attrs)
   1290 if keep_attrs is None:
   1291     keep_attrs = _get_keep_attrs(default=False)
-> 1293 return apply_ufunc(
   1294     duck_array_ops.isnull,
   1295     self,
   1296     dask="allowed",
   1297     keep_attrs=keep_attrs,
   1298 )

File ~/dev/xarray/xarray/core/computation.py:1278, in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, on_missing_core_dim, *args)
   1276 # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc
   1277 elif any(isinstance(a, DataArray) for a in args):
-> 1278     return apply_dataarray_vfunc(
   1279         variables_vfunc,
   1280         *args,
   1281         signature=signature,
   1282         join=join,
   1283         exclude_dims=exclude_dims,
   1284         keep_attrs=keep_attrs,
   1285     )
   1286 # feed Variables directly through apply_variable_ufunc
   1287 elif any(isinstance(a, Variable) for a in args):

File ~/dev/xarray/xarray/core/computation.py:320, in apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args)
    315 result_coords, result_indexes = build_output_coords_and_indexes(
    316     args, signature, exclude_dims, combine_attrs=keep_attrs
    317 )
    319 data_vars = [getattr(a, "variable", a) for a in args]
--> 320 result_var = func(*data_vars)
    322 out: tuple[DataArray, ...] | DataArray
    323 if signature.num_outputs > 1:

File ~/dev/xarray/xarray/core/computation.py:831, in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args)
    826     if vectorize:
    827         func = _vectorize(
    828             func, signature, output_dtypes=output_dtypes, exclude_dims=exclude_dims
    829         )
--> 831 result_data = func(*input_data)
    833 if signature.num_outputs == 1:
    834     result_data = (result_data,)

File ~/dev/xarray/xarray/core/duck_array_ops.py:144, in isnull(data)
    139 if dtypes.is_datetime_like(scalar_type):
    140     # datetime types use NaT for null
    141     # note: must check timedelta64 before integers, because currently
    142     # timedelta64 inherits from np.integer
    143     return isnat(data)
--> 144 elif dtypes.isdtype(scalar_type, ("real floating", "complex floating"), xp=xp):
    145     # float types use NaN for null
    146     xp = get_array_namespace(data)
    147     return xp.isnan(data)

File ~/dev/xarray/xarray/core/dtypes.py:208, in isdtype(dtype, kind, xp)
    205     raise TypeError(f"kind must be a string or a tuple of strings: {repr(kind)}")
    207 if isinstance(dtype, np.dtype):
--> 208     return npcompat.isdtype(dtype, kind)
    209 elif is_extension_array_dtype(dtype):
    210     # we never want to match pandas extension array dtypes
    211     return False

File ~/miniconda3/envs/xarray-py312/lib/python3.12/site-packages/numpy/_core/numerictypes.py:425, in isdtype(dtype, kind)
    423     dtype = _preprocess_dtype(dtype)
    424 except _PreprocessDTypeError:
--> 425     raise TypeError(
    426         "dtype argument must be a NumPy dtype, "
    427         f"but it is a {type(dtype)}."
    428     ) from None
    430 input_kinds = kind if isinstance(kind, tuple) else (kind,)
    432 processed_kinds = set()

TypeError: dtype argument must be a NumPy dtype, but it is a <class 'numpy.dtype[bfloat16]'>.

Anything else we need to know?

Here's a different reproducer showing the inconsistency between np.isdtype and npcompat.isdtype:

import importlib
from xarray.core import npcompat
import ml_dtypes
import numpy as np

# This call raises on numpy >= 2.
try:
    npcompat.isdtype(ml_dtypes.bfloat16.dtype, 'real floating')  # `AttributeError: 'module' object has no attribute 'isdtype'`
except Exception as e:
    print(e)

# Hide np.isdtype while reloading npcompat so that it falls back to xarray's
# own implementation, then restore numpy's function.
numpy_isdtype = np.isdtype
del np.isdtype
importlib.reload(npcompat)
np.isdtype = numpy_isdtype

npcompat.isdtype(ml_dtypes.bfloat16.dtype, 'real floating')  # No error, but returns False.

Environment

In [5]: xarray.show_versions()

INSTALLED VERSIONS
------------------
commit: 03d3e0b5992051901c777cbf2c481abe2201facd
python: 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:35:20) [Clang 16.0.6 ]
python-bits: 64
OS: Darwin
OS-release: 23.6.0
machine: arm64
processor: arm
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: ('en_US', 'UTF-8')
libhdf5: 1.14.3
libnetcdf: 4.9.2

xarray: 2024.7.1.dev73+g781877cb
pandas: 2.2.2
numpy: 2.1.1
scipy: 1.13.1
netCDF4: 1.6.5
pydap: None
h5netcdf: None
h5py: None
zarr: 2.18.2
cftime: 1.6.4
nc_time_axis: None
iris: None
bottleneck: None
dask: 2024.8.2
distributed: 2024.5.2
matplotlib: 3.9.0
cartopy: None
seaborn: None
numbagg: None
fsspec: 2024.6.0
cupy: None
pint: None
sparse: None
flox: None
numpy_groupies: 0.11.1
setuptools: 70.0.0
pip: 24.0
conda: 24.7.1
pytest: 8.2.2
mypy: 1.10.0
IPython: 8.25.0

keewis commented 1 month ago

The difference here is that npcompat.isdtype translates the kind string into a NumPy type superclass and then uses isinstance to perform the check, while np.isdtype explicitly raises if it receives anything other than np.dtype subclasses or the string categories.
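The superclass-based check described above can be illustrated with the two kinds that isnull asks about in the traceback, ("real floating", "complex floating"). This is a minimal sketch of the idea, not xarray's actual npcompat code; is_float_like is a hypothetical name. Because it only asks whether the dtype's scalar type derives from the abstract NumPy types, an unfamiliar dtype yields False instead of an error, which is consistent with the second reproducer returning False.

```python
import numpy as np


def is_float_like(dtype) -> bool:
    """Hypothetical helper: superclass-based check for the kinds isnull uses.

    Asks whether the dtype's scalar type derives from numpy.floating or
    numpy.complexfloating. A dtype whose scalar type is outside this part of
    the hierarchy yields False rather than raising.
    """
    scalar_type = np.dtype(dtype).type
    return issubclass(scalar_type, (np.floating, np.complexfloating))
```

For example, is_float_like(np.float16) and is_float_like("complex128") are True, while is_float_like(np.int32) is False.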

I don't think we can do a lot here (correct me if I'm wrong, @shoyer), so it might make more sense to take this up with the numpy devs.

cc @rgommers, @seberg for awareness

seberg commented 1 month ago

The quick fix is to use np.dtype() to convert the dtype (i.e. also in your code). I suspect np.isdtype (and maybe other "array api" functions) should do this explicitly. (I am not sure why it ends up where it does with the DType class.)

EDIT: To be clear, since this tries to use the array API, I don't know whether it is possible to work around this easily.
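The np.dtype() normalization suggested above can be sketched as a thin wrapper that coerces scalar types, strings, and dtype instances to a single np.dtype before the kind check. This assumes numpy >= 2 (where np.isdtype exists), and normalized_isdtype is a hypothetical name. As the following comment in the thread points out, this normalization does not help for dtypes outside NumPy's built-in list, such as ml_dtypes.bfloat16.

```python
import numpy as np


def normalized_isdtype(dtype_like, kind):
    """Hypothetical helper: coerce dtype_like to a np.dtype, then call np.isdtype.

    Accepts anything np.dtype() accepts (e.g. "float32", np.int8, a dtype
    instance). Requires numpy >= 2.
    """
    return np.isdtype(np.dtype(dtype_like), kind)
```

With this, a string such as "float32" is turned into a dtype instance before np.isdtype sees it.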

keewis commented 1 month ago

For reference, this error is raised because numpy._core._type_aliases.allTypes contains an explicit list of allowed dtypes, so any new dtype that is not in that list (like ml_dtypes.bfloat16 or even the new numpy.dtypes.StringDType) triggers this error when passed to numpy.isdtype.

This means that wrapping the dtype in numpy.dtype does not help, unfortunately.
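Because the rejection surfaces as a TypeError, whether a particular dtype is affected on the installed NumPy can be probed by catching it. This is a hedged sketch: numpy_isdtype_accepts is a hypothetical helper, and its result for dtypes such as numpy.dtypes.StringDType() or ml_dtypes.bfloat16 depends on the NumPy version and on how numpy/numpy#27545 is resolved.

```python
import numpy as np


def numpy_isdtype_accepts(dtype) -> bool:
    """Hypothetical probe: does numpy.isdtype accept this dtype without raising?

    Returns False when numpy.isdtype is unavailable (numpy < 2) or when it
    raises TypeError, which is how the failure in this issue surfaces.
    "numeric" is a valid kind, so a TypeError here points at the dtype.
    """
    isdtype = getattr(np, "isdtype", None)
    if isdtype is None:
        return False
    try:
        isdtype(dtype, "numeric")
    except TypeError:
        return False
    return True
```

On numpy >= 2, numpy_isdtype_accepts(np.dtype("float64")) is True; on releases with the allowlist described above, a bfloat16 dtype would be expected to return False.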

seberg commented 1 month ago

contains an explicit list of allowed dtypes

Can you open a NumPy issue about it? I know that there is always this knee-jerk reaction to focus on the Array API-blessed dtypes only, but honestly, that is just wrong. This is NumPy API, and while some guarantees may be missing, it shouldn't be artificially limited here.

seberg commented 1 month ago

I have to look more closely at what is going on. Maybe using this list was just a case of cargo-culting from the wrong place. Translating arbitrary objects to dtype instances is tricky.

alvarosg commented 1 month ago

In the meantime, would it make sense to simply keep falling back to the xarray implementation npcompat.isdtype, even when np.isdtype exists (instead of this try/except)? After all, this is failing at an xarray call site.

keewis commented 1 month ago

Can you open a NumPy issue about it?

See numpy/numpy#27545

In the meantime, would it make sense to simply continue falling back into the xarray implementation npcompat.isdtype

As npcompat is compatibility code that's supposed to go away as soon as we can require a specific numpy version, I'd prefer waiting until the numpy team has reached a decision. However, we don't really have to wait until that change in numpy has been released to write the compat code.
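A hedged sketch of what such compat code could look like, assuming the goal is only to avoid the crash: prefer numpy.isdtype, and fall back to a scalar-type-hierarchy check when it is missing (numpy < 2) or raises TypeError. compat_isdtype and _FALLBACK_KINDS are hypothetical names, not xarray's API. Note that this would only restore the pre-numpy-2 behaviour: as the second reproducer shows for xarray's own implementation, a hierarchy-based check is likely to report bfloat16 as not "real floating" rather than raising.

```python
import numpy as np

# Hypothetical mapping from array API kind strings to abstract NumPy scalar
# types, used only on the fallback path.
_FALLBACK_KINDS = {
    "bool": np.bool_,
    "signed integer": np.signedinteger,
    "unsigned integer": np.unsignedinteger,
    "integral": np.integer,
    "real floating": np.floating,
    "complex floating": np.complexfloating,
    "numeric": np.number,
}


def compat_isdtype(dtype, kind):
    """Prefer numpy.isdtype, falling back to a scalar-type-hierarchy check.

    The fallback runs when numpy.isdtype does not exist (numpy < 2) or raises
    TypeError (e.g. for a dtype it does not recognise). Only string kinds are
    supported on the fallback path.
    """
    isdtype = getattr(np, "isdtype", None)
    if isdtype is not None:
        try:
            return isdtype(dtype, kind)
        except TypeError:
            pass  # numpy.isdtype refused the input; use the fallback below
    kinds = kind if isinstance(kind, tuple) else (kind,)
    unknown = [k for k in kinds if k not in _FALLBACK_KINDS]
    if unknown:
        raise ValueError(f"unsupported kind(s) on fallback path: {unknown}")
    scalar_type = np.dtype(dtype).type
    return any(issubclass(scalar_type, _FALLBACK_KINDS[k]) for k in kinds)
```

Keeping numpy.isdtype as the first choice means the fallback only kicks in for the inputs it rejects, so the compat code can be removed once a fixed numpy release is the minimum requirement.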