pangeo-data / xESMF

Universal Regridder for Geospatial Data
http://xesmf.readthedocs.io/
MIT License
195 stars 35 forks source link

Spatial chunking no longer works #283

Closed yt87 closed 1 year ago

yt87 commented 1 year ago

The example in the documentation (https://xesmf.readthedocs.io/en/latest/notebooks/Dask.html) fails when the input is chunked spatially. I combined the code from all the IPython cells into a single snippet:

import os

prefix = f"{os.path.dirname(os.environ['CONDA_PREFIX'])}/icec"
os.environ["ESMFMKFILE"] = f"{prefix}/lib/esmf.mk"
os.environ["PROJ_DATA"] = f"{prefix}/share/proj"

import numpy as np
import dask.array as da  # need to have dask.array installed, although not directly using it here.
import xarray as xr
import xesmf as xe

ds = xr.tutorial.open_dataset("air_temperature", chunks={"time": 500})
ds
ds_out = xr.Dataset(
    {
        "lat": (["lat"], np.arange(16, 75, 1.0)),
        "lon": (["lon"], np.arange(200, 330, 1.5)),
    }
)

regridder = xe.Regridder(ds, ds_out, "bilinear")
regridder
%time ds_out = regridder(ds)
ds_out
ds_out["air"].data
%time result = ds_out['air'].compute()
'Spatial rechunking'
ds_spatial = ds.chunk({"lat": 25, "lon": 25, "time": -1})
ds_spatial
# Fails here
ds_spatial_out = regridder(ds_spatial)  # Regridding ds_spatial
ds_spatial_out["air"].data
ds_spatial_out = regridder(ds_spatial, output_chunks={"lat": 10, "lon": 10})
ds_spatial_out["air"].data

The result is an exception:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[3], line 31
     29 ds_spatial
     30 # Fails here
---> 31 ds_spatial_out = regridder(ds_spatial)  # Regridding ds_spatial
     32 ds_spatial_out["air"].data
     33 ds_spatial_out = regridder(ds_spatial, output_chunks={"lat": 10, "lon": 10})

File ~/mambaforge/envs/icec/lib/python3.10/site-packages/xesmf/frontend.py:463, in BaseRegridder.__call__(self, indata, keep_attrs, skipna, na_thres)
    459     return self.regrid_dataarray(
    460         indata, keep_attrs=keep_attrs, skipna=skipna, na_thres=na_thres
    461     )
    462 elif isinstance(indata, xr.Dataset):
--> 463     return self.regrid_dataset(
    464         indata, keep_attrs=keep_attrs, skipna=skipna, na_thres=na_thres
    465     )
    466 else:
    467     raise TypeError('input must be numpy array, dask array, xarray DataArray or Dataset!')

File ~/mambaforge/envs/icec/lib/python3.10/site-packages/xesmf/frontend.py:567, in BaseRegridder.regrid_dataset(self, ds_in, keep_attrs, skipna, na_thres)
    560 non_regriddable = [
    561     name
    562     for name, data in ds_in.data_vars.items()
    563     if not set(input_horiz_dims).issubset(data.dims)
    564 ]
    565 ds_in = ds_in.drop_vars(non_regriddable)
--> 567 ds_out = xr.apply_ufunc(
    568     self.regrid_array,
    569     ds_in,
    570     self.weights,
    571     kwargs=kwargs,
    572     input_core_dims=[input_horiz_dims, ('out_dim', 'in_dim')],
    573     output_core_dims=[temp_horiz_dims],
    574     dask='allowed',
    575     keep_attrs=keep_attrs,
    576 )
    578 return self._format_xroutput(ds_out, temp_horiz_dims)

File ~/mambaforge/envs/icec/lib/python3.10/site-packages/xarray/core/computation.py:1185, in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, *args)
   1183 # feed datasets apply_variable_ufunc through apply_dataset_vfunc
   1184 elif any(is_dict_like(a) for a in args):
-> 1185     return apply_dataset_vfunc(
   1186         variables_vfunc,
   1187         *args,
   1188         signature=signature,
   1189         join=join,
   1190         exclude_dims=exclude_dims,
   1191         dataset_join=dataset_join,
   1192         fill_value=dataset_fill_value,
   1193         keep_attrs=keep_attrs,
   1194     )
   1195 # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc
   1196 elif any(isinstance(a, DataArray) for a in args):

File ~/mambaforge/envs/icec/lib/python3.10/site-packages/xarray/core/computation.py:469, in apply_dataset_vfunc(func, signature, join, dataset_join, fill_value, exclude_dims, keep_attrs, *args)
    464 list_of_coords, list_of_indexes = build_output_coords_and_indexes(
    465     args, signature, exclude_dims, combine_attrs=keep_attrs
    466 )
    467 args = tuple(getattr(arg, "data_vars", arg) for arg in args)
--> 469 result_vars = apply_dict_of_variables_vfunc(
    470     func, *args, signature=signature, join=dataset_join, fill_value=fill_value
    471 )
    473 out: Dataset | tuple[Dataset, ...]
    474 if signature.num_outputs > 1:

File ~/mambaforge/envs/icec/lib/python3.10/site-packages/xarray/core/computation.py:411, in apply_dict_of_variables_vfunc(func, signature, join, fill_value, *args)
    409 result_vars = {}
    410 for name, variable_args in zip(names, grouped_by_name):
--> 411     result_vars[name] = func(*variable_args)
    413 if signature.num_outputs > 1:
    414     return _unpack_dict_tuples(result_vars, signature.num_outputs)

File ~/mambaforge/envs/icec/lib/python3.10/site-packages/xarray/core/computation.py:761, in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args)
    756     if vectorize:
    757         func = _vectorize(
    758             func, signature, output_dtypes=output_dtypes, exclude_dims=exclude_dims
    759         )
--> 761 result_data = func(*input_data)
    763 if signature.num_outputs == 1:
    764     result_data = (result_data,)

File ~/mambaforge/envs/icec/lib/python3.10/site-packages/xesmf/frontend.py:507, in BaseRegridder.regrid_array(self, indata, weights, skipna, na_thres)
    504     weights = da.from_array(weights, chunks=weights.shape)
    505     output_chunks = indata.chunks[:-2] + ((self.shape_out[0],), (self.shape_out[1],))
--> 507     outdata = da.map_blocks(
    508         self._regrid,
    509         indata,
    510         weights,
    511         dtype=indata.dtype,
    512         chunks=output_chunks,
    513         meta=np.array((), dtype=indata.dtype),
    514         **kwargs,
    515     )
    516 else:  # numpy
    517     outdata = self._regrid(indata, weights, **kwargs)

File ~/mambaforge/envs/icec/lib/python3.10/site-packages/dask/array/core.py:872, in map_blocks(func, name, token, dtype, chunks, drop_axis, new_axis, enforce_ndim, meta, *args, **kwargs)
    856     out = blockwise(
    857         apply_and_enforce,
    858         out_ind,
   (...)
    869         **kwargs,
    870     )
    871 else:
--> 872     out = blockwise(
    873         func,
    874         out_ind,
    875         *concat(argpairs),
    876         name=name,
    877         new_axes=new_axes,
    878         dtype=dtype,
    879         concatenate=True,
    880         align_arrays=False,
    881         adjust_chunks=adjust_chunks,
    882         meta=meta,
    883         **kwargs,
    884     )
    886 extra_argpairs = []
    887 extra_names = []

File ~/mambaforge/envs/icec/lib/python3.10/site-packages/dask/array/blockwise.py:272, in blockwise(func, out_ind, name, token, dtype, adjust_chunks, new_axes, align_arrays, concatenate, meta, *args, **kwargs)
    270 elif isinstance(adjust_chunks[ind], (tuple, list)):
    271     if len(adjust_chunks[ind]) != len(chunks[i]):
--> 272         raise ValueError(
    273             f"Dimension {i} has {len(chunks[i])} blocks, adjust_chunks "
    274             f"specified with {len(adjust_chunks[ind])} blocks"
    275         )
    276     chunks[i] = tuple(adjust_chunks[ind])
    277 else:

ValueError: Dimension 2 has 3 blocks, adjust_chunks specified with 1 blocks

This issue seems to be related to dask/dask#4299 (https://github.com/dask/dask/issues/4299).

My environment:

python: 3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:40:32) [GCC 12.3.0]
dask: 2023.7.0
numpy: 1.24.4
xarray: 2023.6.0
xesmf: 0.7.1
aulemahal commented 1 year ago

Hi @yt87 !

I think this is because you are using xESMF 0.7.1, which doesn't support spatially chunked input when regridding. I believe the confusion comes from ReadTheDocs: the link you provided points to the documentation for the "latest" xESMF, i.e. the master branch here on GitHub, which already has this new feature.

Our documentation setup might be broken, because the "stable" version (0.7.1) seems to be inaccessible. I'll open an issue.

One solution is to install xESMF from source, but beware that the development version may contain bugs that haven't been fixed yet.

yt87 commented 1 year ago

Thanks. Another solution, which I chose, is to rechunk the dataset before invoking the regridder, so that each horizontal dimension (lat, lon) forms a single chunk. Regarding the documentation, maybe you could follow the NumPy and Python practice of adding a note that a particular feature was introduced in version x.y.z.