pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

`ds.interp()` breaks if (non-interpolating) dimension is not numeric #8058

Open ks905383 opened 1 year ago

ks905383 commented 1 year ago

What happened?

I'm running ds.interp() with multi-dimensional new coordinates, relying on xarray's broadcasting to expand the original dataset to the new dimensions. In this case I'm only interpolating along one dimension, but broadcasting out to others.

If the dimensions are all numeric (or, presumably, can be coerced to numeric), this works without an issue. However, if one of the other dimensions is populated with, e.g., string indices (weather station names, model run IDs, etc.), the process fails, even if the dimension along which the interpolation is performed is purely numeric.

What did you expect to happen?

Here is an example with only numeric dimensions that works as expected:

import xarray as xr
import numpy as np

da1 = xr.DataArray(np.reshape(np.arange(0,12),(3,4)),
                   coords = {'dim0':np.arange(0,3),
                             'dim1':np.arange(0,4)})

# da2 shares 'dim1' with da1 and introduces a new dimension 'dim2'
da2 = xr.DataArray(np.random.normal(loc=1,size=(2,4),scale=0.5),
                   coords = {'dim2':np.arange(0,2),
                             'dim1':np.arange(0,4)})

da1.interp(dim0=da2)

this produces something like:

[screenshot of the resulting DataArray, broadcast out to dimensions dim2 × dim1]

as expected.

Minimal Complete Verifiable Example

import xarray as xr
import numpy as np

# identical to the working example above, except 'dim1' now has string labels
da1 = xr.DataArray(np.reshape(np.arange(0,12),(3,4)),
                   coords = {'dim0':np.arange(0,3),
                             'dim1':np.arange(0,4).astype(str)})

da2 = xr.DataArray(np.random.normal(loc=1,size=(2,4),scale=0.5),
                   coords = {'dim2':np.arange(0,2),
                             'dim1':np.arange(0,4).astype(str)})

da1.interp(dim0=da2)

MVCE confirmation

Relevant log output

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[48], line 9
      1 da1 = xr.DataArray(np.reshape(np.arange(0,12),(3,4)),
      2                   coords = {'dim0':np.arange(0,3),
      3                             'dim1':np.arange(0,4).astype(str)})
      5 da2 = xr.DataArray(np.random.normal(loc=1,size=(2,4),scale=0.5),
      6                    coords = {'dim2':np.arange(0,2),
      7                              'dim1':np.arange(0,4).astype(str)})
----> 9 da1.interp(dim0=da2)

File ~/.conda/envs/climate/lib/python3.10/site-packages/xarray/core/dataarray.py:2204, in DataArray.interp(self, coords, method, assume_sorted, kwargs, **coords_kwargs)
   2199 if self.dtype.kind not in "uifc":
   2200     raise TypeError(
   2201         "interp only works for a numeric type array. "
   2202         "Given {}.".format(self.dtype)
   2203     )
-> 2204 ds = self._to_temp_dataset().interp(
   2205     coords,
   2206     method=method,
   2207     kwargs=kwargs,
   2208     assume_sorted=assume_sorted,
   2209     **coords_kwargs,
   2210 )
   2211 return self._from_temp_dataset(ds)

File ~/.conda/envs/climate/lib/python3.10/site-packages/xarray/core/dataset.py:3666, in Dataset.interp(self, coords, method, assume_sorted, kwargs, method_non_numeric, **coords_kwargs)
   3664 if method in ["linear", "nearest"]:
   3665     for k, v in validated_indexers.items():
-> 3666         obj, newidx = missing._localize(obj, {k: v})
   3667         validated_indexers[k] = newidx[k]
   3669 # optimization: create dask coordinate arrays once per Dataset
   3670 # rather than once per Variable when dask.array.unify_chunks is called later
   3671 # GH4739

File ~/.conda/envs/climate/lib/python3.10/site-packages/xarray/core/missing.py:562, in _localize(var, indexes_coords)
    560 indexes = {}
    561 for dim, [x, new_x] in indexes_coords.items():
--> 562     minval = np.nanmin(new_x.values)
    563     maxval = np.nanmax(new_x.values)
    564     index = x.to_index()

File <__array_function__ internals>:5, in nanmin(*args, **kwargs)

File ~/.conda/envs/climate/lib/python3.10/site-packages/numpy/lib/nanfunctions.py:319, in nanmin(a, axis, out, keepdims)
    315     kwargs['keepdims'] = keepdims
    316 if type(a) is np.ndarray and a.dtype != np.object_:
    317     # Fast, but not safe for subclasses of ndarray, or object arrays,
    318     # which do not implement isnan (gh-9009), or fmin correctly (gh-8975)
--> 319     res = np.fmin.reduce(a, axis=axis, out=out, **kwargs)
    320     if np.isnan(res).any():
    321         warnings.warn("All-NaN slice encountered", RuntimeWarning,
    322                       stacklevel=3)

TypeError: cannot perform reduce with flexible type

Anything else we need to know?

I'm pretty sure the issue is in the optimization step in Dataset.interp() that appears in the traceback above (the call into missing._localize() for linear/nearest methods).

That step calls _localize() from missing.py, which calls np.nanmin() and np.nanmax() on all of the indexer coordinates, including the ones that aren't used in the interpolation itself, only in the broadcasting.
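For what it's worth, the failing call can be reproduced with numpy alone; calling np.nanmin() on a fixed-width string array (like the dim1 coordinate here) raises the same error as in the traceback:

import numpy as np

np.nanmin(np.arange(0,4).astype(str))
# TypeError: cannot perform reduce with flexible type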

Perhaps a way to fix this would be to have _localize() check for numeric indices and only subset the numeric dimensions? (I could see that generalizing _localize() to other data types may be more trouble than it's worth, especially for unsorted string dimensions...) Alternatively, it could only subset the dimensions actually used in the interpolation itself, or there could be a way to turn off optimizations like this.
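In the meantime, a workaround that seems to work on my end is to swap the string labels on the broadcast-only dimension for integer positions before interpolating, and to restore them afterwards (using da1 and da2 from the MVCE above):

import numpy as np

# replace the string labels on 'dim1' with integer positions
dim1_labels = da1['dim1'].values
da1_num = da1.assign_coords(dim1=np.arange(dim1_labels.size))
da2_num = da2.assign_coords(dim1=np.arange(dim1_labels.size))

# interpolate on all-numeric coordinates, then put the original labels back
result = da1_num.interp(dim0=da2_num).assign_coords(dim1=dim1_labels)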

Environment

INSTALLED VERSIONS
------------------
commit: None
python: 3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:23:14) [GCC 10.4.0]
python-bits: 64
OS: Linux
OS-release: 3.10.0-1160.76.1.el7.x86_64
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: (None, None)
libhdf5: 1.12.1
libnetcdf: 4.8.1

xarray: 2023.7.0
pandas: 1.4.1
numpy: 1.21.6
scipy: 1.11.1
netCDF4: 1.5.8
pydap: None
h5netcdf: None
h5py: None
Nio: 1.5.5
zarr: 2.13.2
cftime: 1.6.2
nc_time_axis: 1.4.1
PseudoNetCDF: None
iris: None
bottleneck: 1.3.7
dask: 2023.3.0
distributed: 2023.3.0
matplotlib: 3.5.1
cartopy: 0.20.2
seaborn: 0.11.2
numbagg: None
fsspec: 2022.5.0
cupy: None
pint: 0.22
sparse: 0.14.0
flox: None
numpy_groupies: None
setuptools: 68.0.0
pip: 23.2.1
conda: None
pytest: 7.0.1
mypy: None
IPython: 8.14.0
sphinx: None
welcome[bot] commented 1 year ago

Thanks for opening your first issue here at xarray! Be sure to follow the issue template! If you have an idea for a solution, we would really welcome a Pull Request with proposed changes. See the Contributing Guide for more. It may take us a while to respond here, but we really value your contribution. Contributors like you help make xarray better. Thank you!