Regridding xarray dataset with chunked dask-backed arrays #222

zoj613 commented 1 year ago

Is there a way to reliably regrid an xarray.Dataset to a lower or higher resolution when it has variables backed by chunked dask arrays? Every single time I apply the regridder returned by xesmf.Regridder to the input data, I get a

ValueError: Dimension 1 has 9 blocks, adjust_chunks specified with 1 blocks

exception. To get it to work, I have to force the datasets into a single chunk with .chunk(-1). This can make tasks fail when the dask graph is computed, since a single chunk of a large dataset can consume a lot of memory. Is there any workaround that avoids using a single chunk?
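To put a number on the memory concern, here is a quick back-of-the-envelope sketch. The grid (a year of daily float64 data on a 0.25-degree global grid) is a hypothetical example, not the poster's actual data.

```python
import math

import numpy as np


def chunk_nbytes(chunk_shape, dtype="float64"):
    """Bytes needed to hold one chunk of the given shape in memory."""
    return math.prod(chunk_shape) * np.dtype(dtype).itemsize


# Hypothetical (time, lat, lon) field: 365 days on a 0.25-degree grid
shape = (365, 720, 1440)

single_chunk = chunk_nbytes(shape)            # what .chunk(-1) produces
per_30_days = chunk_nbytes((30, 720, 1440))   # spatial dims whole, time split

print(f"single chunk: {single_chunk / 2**30:.2f} GiB")
print(f"30-day chunk: {per_30_days / 2**20:.0f} MiB")
```

This prints roughly 2.82 GiB for the single chunk versus about 237 MiB per 30-day chunk, and every worker task touching the single chunk must hold all of it at once.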

aulemahal commented 1 year ago

Sadly, not for now. If your data has non-spatial dimensions, like time, I would suggest rechunking so that all spatial chunks are merged and the time dimension is split instead, keeping the chunk sizes reasonable. However, the regridding weights are currently stored in a single-chunk sparse matrix, and distributing that matrix across the chunks of your data is a complex problem...
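A minimal sketch of that rechunking with xarray, assuming dask is installed. The variable name tas and the toy sizes are hypothetical; -1 means "one chunk spanning the whole dimension".

```python
import numpy as np
import xarray as xr

# Toy (time, lat, lon) dataset, initially chunked along the spatial
# dimensions (the layout that triggers the adjust_chunks error).
ds = xr.Dataset(
    {"tas": (("time", "lat", "lon"), np.zeros((90, 18, 36)))},
).chunk({"time": 90, "lat": 6, "lon": 12})

# Merge all spatial chunks and split along time instead,
# so each chunk stays a manageable size.
rechunked = ds.chunk({"time": 30, "lat": -1, "lon": -1})

print(dict(rechunked.chunks))
```

Each chunk now spans the full lat/lon grid, which is what the single-chunk weights need, while memory per task is bounded by the time chunk size.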

huard commented 1 year ago

If anyone has some ideas to solve this, that would be a great contribution.

dcherian commented 1 year ago

Here's how to do it:

Read (or convert) the weights as a pydata/sparse array

def read_xesmf_weights_file(filename):
    import numpy as np
    import sparse
    import xarray as xr

    weights = xr.open_dataset(filename)

    # input variable shape
    in_shape = weights.src_grid_dims.load().data

    # output variable shape; grid dims are stored as (lon, lat), hence the reversal
    out_shape = weights.dst_grid_dims.load().data.tolist()[::-1]

    print(f"Regridding from {in_shape} to {out_shape}")

    rows = weights["row"] - 1  # convert 1-based row indices to 0-based
    cols = weights["col"] - 1  # convert 1-based col indices to 0-based

    # construct a sparse array,
    # reshape to 3D : lat, lon, ncol
    # This reshaping should allow optional chunking along
    # lat, lon later
    sparse_array_data = sparse.COO(
        coords=np.stack([rows.data, cols.data]),
        data=weights["S"].data,
        shape=(weights.sizes["n_b"], weights.sizes["n_a"]),
        fill_value=0,
    ).reshape((*out_shape, -1))

    # Create a DataArray with sparse weights and the output coordinates
    xsparse_wgts = xr.DataArray(
        sparse_array_data,
        dims=("lat", "lon", "ncol"),
        # Add useful coordinate information, this will get propagated to the output
        coords={
            "lat": ("lat", weights["yc_b"].data.reshape(out_shape)[:, 0]),
            "lon": ("lon", weights["xc_b"].data.reshape(out_shape)[0, :]),
        },
        # propagate useful information like regridding algorithm
        attrs=weights.attrs,
    )

    return xsparse_wgts

xsparse_wgts = read_xesmf_weights_file(map_path + map_file)
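The reshape((*out_shape, -1)) step in read_xesmf_weights_file turns the flat (n_b, n_a) weight matrix into (lat, lon, ncol) weights. Here is a toy NumPy check of why that is valid, with a dense array standing in for sparse.COO and made-up triplets in the ESMF row/col/S layout (1-based indices). It assumes the destination cells are numbered with lon varying fastest, which is what the reversed dst_grid_dims implies.

```python
import numpy as np

nlat, nlon, ncol = 2, 3, 4

# Made-up ESMF-style triplets: row = destination cell, col = source column
row = np.array([1, 2, 2, 3, 4, 5, 6, 6])
col = np.array([1, 1, 2, 2, 3, 4, 3, 4])
S = np.array([1.0, 0.5, 0.5, 1.0, 1.0, 1.0, 0.25, 0.75])

# Dense stand-in for sparse.COO with shape (n_b, n_a);
# np.add.at sums duplicate entries the way COO does.
W = np.zeros((nlat * nlon, ncol))
np.add.at(W, (row - 1, col - 1), S)  # shift 1-based indices to 0-based

# Row-major reshape: destination cell k -> (k // nlon, k % nlon)
W3 = W.reshape(nlat, nlon, -1)

x = np.arange(1.0, ncol + 1)  # one source field on the unstructured grid

# Contracting the 3D weights over ncol ...
out = np.einsum("yxc,c->yx", W3, x)
# ... matches the flat sparse-matrix product reshaped afterwards.
assert np.allclose(out, (W @ x).reshape(nlat, nlon))
print(out)
```

Because only the leading (destination) axis is split, chunking the weights along lat or lon later never separates the entries that belong to one output cell.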

Apply the weights using opt_einsum

import numpy as np
import opt_einsum
import xarray as xr


def apply_weights(dataset, weights):

    def _apply(da):
        # 🐵 🔧 : swap xarray's einsum for opt_einsum.contract
        original_einsum = xr.core.duck_array_ops.einsum
        xr.core.duck_array_ops.einsum = opt_einsum.contract
        try:
            ans = xr.dot(
                da,
                weights,
                # This dimension will be "contracted"
                # or summed over after multiplying by the weights
                dim="ncol",  # spelled `dims` in older xarray versions
            )
        finally:
            # 🐵 🔧 : restore the original, even if xr.dot raised
            xr.core.duck_array_ops.einsum = original_einsum

        return ans

    vars_with_ncol = [
        name for name, array in dataset.variables.items()
        if "ncol" in array.dims and name not in weights.coords
    ]
    regridded = dataset[vars_with_ncol].map(_apply)

    # merge in other variables, but skip those that are already set
    # like lat, lon
    return xr.merge([dataset.drop_vars(regridded.variables, errors="ignore"), regridded])


apply_weights(psfile, xsparse_wgts.chunk())  # psfile: the dataset to regrid
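To see the chunked contraction at the heart of _apply in isolation, here is a small sketch assuming xarray and dask are installed. Dense random weights stand in for the sparse ones, and the sizes are toy values. Without a dim argument, xr.dot sums over all dimensions its inputs share, which here is only ncol.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
ntime, nlat, nlon, ncol = 6, 4, 5, 8

# Source data on an unstructured grid, chunked along time and ncol
da = xr.DataArray(rng.random((ntime, ncol)), dims=("time", "ncol")).chunk(
    {"time": 3, "ncol": 4}
)

# Dense stand-in for the sparse weights, chunked along lat and ncol
weights = xr.DataArray(
    rng.random((nlat, nlon, ncol)), dims=("lat", "lon", "ncol")
).chunk({"lat": 2, "ncol": 4})

# Lazily builds a dask graph; nothing is forced into a single chunk
regridded = xr.dot(da, weights).transpose("time", "lat", "lon")

expected = np.einsum("tc,ylc->tyl", da.values, weights.values)
assert np.allclose(regridded.values, expected)
```

The output keeps the time chunks of the data and the lat chunks of the weights, which is what lets this approach scale where the single-chunk workaround does not.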


  1. It works with chunked inputs (both the weights and da). The core piece is _apply.
  2. You can delete :O
  3. I wrote this for 1D unstructured -> 2D regridding, but it should also work for structured 2D -> 2D regridding.
  4. Directly using np.einsum, as xarray does by default, doesn't work so well with chunked weights (bug report), and in my testing I also found opt_einsum to be a lot faster for plain NumPy data x sparse weights.
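For the structured 2D -> 2D case in point 3, the source (y, x) dimensions first have to be collapsed into a single ncol dimension. A minimal sketch, assuming the source is ordered (y, x) and the weight file's col index runs with x fastest (consistent with the reversed grid dims above, but worth verifying for your file); the names and toy weights are hypothetical:

```python
import numpy as np
import xarray as xr

ntime, ny, nx = 2, 3, 4
nlat, nlon = 2, 2

# Structured 2D source field with dims (time, y, x)
src = xr.DataArray(
    np.arange(ntime * ny * nx, dtype=float).reshape(ntime, ny, nx),
    dims=("time", "y", "x"),
)

# Collapse (y, x) into "ncol" with x varying fastest (C-order reshape).
# For dask-backed data this also works, but may merge chunks along x.
flat = xr.DataArray(src.data.reshape(ntime, ny * nx), dims=("time", "ncol"))

# Toy weights: every destination cell is the plain mean of all source cells
weights = xr.DataArray(
    np.full((nlat, nlon, ny * nx), 1.0 / (ny * nx)),
    dims=("lat", "lon", "ncol"),
)

out = xr.dot(flat, weights).transpose("time", "lat", "lon")
print(out.values[:, 0, 0])
```

Reshaping the raw array rather than using stack avoids creating a MultiIndex on ncol, which the weights do not carry.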

See for the upstream issue to avoid the monkey-patch
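Until the monkey-patch can go away, a slightly safer way to scope it is unittest.mock.patch.object, which always restores the original attribute on exit. The sketch below patches a stand-in namespace rather than xarray itself so it does not depend on xarray internals; in the code above the target would be xr.core.duck_array_ops and the replacement opt_einsum.contract. fake_ops and fast_einsum are hypothetical names.

```python
import types
from unittest import mock

import numpy as np

# Stand-in for the module whose einsum gets swapped out
fake_ops = types.SimpleNamespace(einsum=np.einsum)

calls = []


def fast_einsum(*args, **kwargs):
    """Hypothetical replacement standing in for opt_einsum.contract."""
    calls.append(args[0])  # record the subscripts string
    return np.einsum(*args, **kwargs)


with mock.patch.object(fake_ops, "einsum", fast_einsum):
    result = fake_ops.einsum("ij,j->i", np.eye(2), np.array([3.0, 4.0]))

# The original attribute is back, whatever it was before patching
assert fake_ops.einsum is np.einsum
```

Restoring the saved original, instead of hard-coding np.einsum, matters because xarray's own einsum wrapper may dispatch differently than plain np.einsum.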

aulemahal commented 1 year ago

