pangeo-data / xESMF

Universal Regridder for Geospatial Data
http://xesmf.readthedocs.io/
MIT License

Regridding xarray dataset with chunked dask-backed arrays #222

Closed: zoj613 closed this 11 months ago

zoj613 commented 1 year ago

Is there a way to reliably regrid an xarray.Dataset to a lower or higher resolution when its variables are chunked, dask-backed arrays? Every time I apply the regridder returned by xesmf.Regridder to the input data, I get a

ValueError: Dimension 1 has 9 blocks, adjust_chunks specified with 1 blocks

exception. To get it to work, I have to force the datasets to have only a single chunk with .chunk(-1). This can make tasks fail when the dask graph is computed, since a single chunk of a large dataset can consume a lot of memory. Is there any workaround that doesn't require a single chunk?
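To see why a single chunk becomes a problem, here is a rough size estimate. It uses a hypothetical daily float64 field on a 0.25° global grid. The sizes are assumptions for illustration and do not come from this issue.

```python
import numpy as np

# Assumed sizes for illustration: 10 years of daily data on a 0.25-degree grid.
ntime, nlat, nlon = 3650, 721, 1440
itemsize = np.dtype("float64").itemsize  # bytes per value

# Everything in one chunk, as with .chunk(-1)
single_chunk_bytes = ntime * nlat * nlon * itemsize

# One time step with the full spatial extent
per_step_bytes = nlat * nlon * itemsize

# A chunk holding 100 time steps with the full spatial extent
time_chunk = 100
chunk_bytes = time_chunk * per_step_bytes

print(f"single chunk:        {single_chunk_bytes / 1e9:.1f} GB")
print(f"one time step:       {per_step_bytes / 1e6:.1f} MB")
print(f"100-time-step chunk: {chunk_bytes / 1e6:.1f} MB")
```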

aulemahal commented 1 year ago

Sadly, not for now. If your data has non-spatial dimensions, like time, I would suggest rechunking: merge all spatial chunks into one and split along the time dimension instead, so that chunk sizes stay reasonable. However, the regridding weights are currently stored in a single-chunk sparse matrix, and distributing them across the chunks of your data is a complex problem...
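The suggested rechunking works because each time block that keeps the whole spatial extent can be regridded on its own. In xarray this would be roughly `ds.chunk({"time": 100, "lat": -1, "lon": -1})`, where the dimension names and chunk size are assumptions. The NumPy sketch below checks the idea. A small dense random matrix stands in for the real (sparse) regridding weights, and the spatial dimensions are already flattened into one axis.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: 12 time steps, 20 source cells, 6 destination cells.
ntime, n_in, n_out = 12, 20, 6
data = rng.random((ntime, n_in))     # source field, spatial dims flattened
weights = rng.random((n_out, n_in))  # dense stand-in for the weight matrix

# Regrid all time steps at once (what a single-chunk array amounts to).
full = data @ weights.T  # shape (ntime, n_out)

# Regrid block by block along time. Each block keeps the whole spatial
# extent, so it needs the full weight matrix but nothing from other blocks.
time_chunk = 5
blocks = [
    data[start:start + time_chunk] @ weights.T
    for start in range(0, ntime, time_chunk)
]
blockwise = np.concatenate(blocks, axis=0)

print(np.allclose(full, blockwise))
```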

huard commented 1 year ago

If anyone has some ideas to solve this, that would be a great contribution.

huard commented 1 year ago

See https://discourse.pangeo.io/t/conservative-region-aggregation-with-xarray-geopandas-and-sparse/2715 for a possible solution.

huard commented 1 year ago

We're hoping to have an intern work on this next summer. If anyone has tips to share, please leave them here.

dcherian commented 1 year ago

Here's how to do it

Read (or convert) the weights as a pydata/sparse array

def read_xesmf_weights_file(filename):
    import numpy as np
    import sparse
    import xarray as xr

    weights = xr.open_dataset(filename)

    # input variable shape
    in_shape = weights.src_grid_dims.load().data

    # output variable shape
    out_shape = weights.dst_grid_dims.load().data.tolist()[::-1]

    print(f"Regridding from {in_shape} to {out_shape}")

    # convert ESMF's 1-based indices to 0-based
    rows = weights['row'] - 1  # destination (row) indices
    cols = weights['col'] - 1  # source (col) indices

    # construct a sparse array,
    # reshape to 3D : lat, lon, ncol
    # This reshaping should allow optional chunking along
    # lat, lon later
    sparse_array_data = sparse.COO(
        coords=np.stack([rows.data, cols.data]),
        data=weights.S.data,
        shape=(weights.sizes["n_b"], weights.sizes["n_a"]),
        fill_value=0,
    ).reshape((*out_shape, -1))

    # Create a DataArray with sparse weights and the output coordinates
    xsparse_wgts = xr.DataArray(
        sparse_array_data,
        dims=("lat", "lon", "ncol"),
        # Add useful coordinate information, this will get propagated to the output
        coords={
            "lat": ("lat", weights.yc_b.data.reshape(out_shape)[:, 0]),
            "lon": ("lon", weights.xc_b.data.reshape(out_shape)[0, :]),
        },
        # propagate useful information like regridding algorithm
        attrs=weights.attrs,
    )

    return xsparse_wgts

xsparse_wgts = read_xesmf_weights_file(map_path + map_file)
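To make the index handling above concrete, here is a small dense NumPy analogue of what `read_xesmf_weights_file` builds. The triplets and grid sizes are made up for illustration. It shows three steps: the 1-based `row`/`col` indices are shifted to 0-based, the `(n_b, n_a)` matrix is reshaped to `(nlat, nlon, ncol)`, and contracting over `ncol` matches the plain matrix-vector product.

```python
import numpy as np

# Hypothetical ESMF-style triplets: 1-based destination (row) and
# source (col) indices plus the weight S of each nonzero entry.
row = np.array([1, 1, 2, 3, 4, 4])
col = np.array([1, 2, 2, 3, 3, 4])
S = np.array([0.5, 0.5, 1.0, 1.0, 0.25, 0.75])

n_b, n_a = 4, 4     # destination and source cell counts
out_shape = (2, 2)  # destination grid as (nlat, nlon)

# Shift to 0-based indices and scatter the weights into a dense matrix
# (np.add.at accumulates if an index pair appears more than once).
W = np.zeros((n_b, n_a))
np.add.at(W, (row - 1, col - 1), S)

# Same reshape as above: (n_b, n_a) -> (nlat, nlon, ncol)
W3 = W.reshape((*out_shape, -1))

src = np.array([10.0, 20.0, 30.0, 40.0])  # source field along ncol
dst = np.einsum("ijk,k->ij", W3, src)     # contract over ncol
print(dst)
```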

Apply the weights using opt_einsum (https://dgasmith.github.io/opt_einsum/)

import opt_einsum
import xarray as xr


def apply_weights(dataset, weights):

    def _apply(da):
        # 🐵 🔧 : swap xarray's internal einsum for opt_einsum.contract
        original_einsum = xr.core.duck_array_ops.einsum
        xr.core.duck_array_ops.einsum = opt_einsum.contract
        try:
            ans = xr.dot(
                da,
                weights,
                # This dimension will be "contracted",
                # i.e. summed over after multiplying by the weights
                dims="ncol",
            )
        finally:
            # 🐵 🔧 : restore whatever xarray had before, even on error
            xr.core.duck_array_ops.einsum = original_einsum

        return ans

    vars_with_ncol = [
        name for name, array in dataset.variables.items()
        if "ncol" in array.dims and name not in weights.coords
    ]
    regridded = dataset[vars_with_ncol].map(_apply)

    # merge in other variables, but skip those that are already set
    # like lat, lon (errors="ignore" skips names absent from dataset)
    return xr.merge(
        [dataset.drop_vars(regridded.variables, errors="ignore"), regridded]
    )

apply_weights(psfile, xsparse_wgts.chunk())
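`xr.dot(da, weights, dims="ncol")` is an einsum-style contraction: multiply, then sum over `ncol`. The contraction is linear, so splitting `ncol` into chunks and adding the per-chunk partial results gives the same answer. Chunks of the weights along `lat`/`lon` are fully independent. These two properties are what let a chunked backend parallelize the work. The NumPy sketch below checks both, using small dense random arrays as stand-ins for the data and the sparse weights (sizes are hypothetical).

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical sizes: data (time, ncol), weights (lat, lon, ncol).
ntime, ncol, nlat, nlon = 4, 30, 3, 5
da = rng.random((ntime, ncol))
w = rng.random((nlat, nlon, ncol))

# What the dot over "ncol" computes.
full = np.einsum("tk,ijk->tij", da, w)

# 1) Chunk the contracted dimension: contract each chunk pair separately,
#    then add the partial results.
edges = [0, 10, 20, 30]
partial = [
    np.einsum("tk,ijk->tij", da[:, a:b], w[:, :, a:b])
    for a, b in zip(edges[:-1], edges[1:])
]
ncol_chunked = np.sum(partial, axis=0)

# 2) Chunk the weights along lat: each piece yields its own output rows.
lat_chunked = np.concatenate(
    [np.einsum("tk,ijk->tij", da, w[:2]), np.einsum("tk,ijk->tij", da, w[2:])],
    axis=1,
)

print(np.allclose(full, ncol_chunked), np.allclose(full, lat_chunked))
```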

Gainzzzz

  1. It'll work with chunked inputs (both weights and da). The core piece is _apply.
  2. You can delete smm.py :O
  3. I wrote this for 1D unstructured -> 2D regridding, but it should work even for structured 2D -> 2D regridding, once the source's spatial dimensions are stacked into a single ncol dimension in the order the weights expect.
  4. Directly using np.einsum, as xr.dot does by default, doesn't work well for chunked weights (bug report), and in my testing I also found opt_einsum to be a lot faster even for plain NumPy data x sparse weights.
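For the structured 2D -> 2D case in point 3, the source `(lat, lon)` dimensions first have to be flattened into one `ncol` dimension. The sketch below assumes C order (`ncol = ilat * nlon + ilon`), which, as we understand it, matches what `ds.stack(ncol=("lat", "lon"))` produces in xarray. Whether that matches a given weights file is an assumption to verify. The weights are a hypothetical 2 x 2 block average, so the result can be checked against a direct block mean.

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical structured source grid (time, lat, lon) = (3, 4, 4),
# regridded to a 2 x 2 destination grid by averaging 2 x 2 blocks.
ntime, nlat, nlon = 3, 4, 4
data = rng.random((ntime, nlat, nlon))

# Flatten (lat, lon) into one "ncol" dimension in C order.
stacked = data.reshape(ntime, nlat * nlon)

# Weights shaped (nlat_out, nlon_out, ncol): each destination cell takes
# 0.25 from each of the four source cells it covers.
weights = np.zeros((2, 2, nlat * nlon))
for ilat in range(nlat):
    for ilon in range(nlon):
        weights[ilat // 2, ilon // 2, ilat * nlon + ilon] = 0.25

# Contract over ncol, as the dot over "ncol" does.
regridded = np.einsum("tk,ijk->tij", stacked, weights)

# Cross-check against a direct 2 x 2 block mean of the unstacked data.
expected = data.reshape(ntime, 2, 2, 2, 2).mean(axis=(2, 4))
print(np.allclose(regridded, expected))
```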

See https://github.com/pydata/xarray/issues/7764 for the upstream issue that would make the monkey-patch unnecessary.
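Until then, one way to keep the monkey-patch from leaking is to wrap it in a context manager that always puts the previous attribute back, even if the regridding raises. `temporarily_patched` below is a hypothetical helper, demonstrated on a stand-in namespace rather than on xarray. The standard library's `unittest.mock.patch.object(module, "einsum", replacement)` offers the same save-and-restore behaviour. Either way, the patch changes a module attribute, so it is visible to every thread while it is active.

```python
import contextlib
import types


@contextlib.contextmanager
def temporarily_patched(module, name, replacement):
    """Hypothetical helper: set module.<name> to replacement inside the
    with-block and restore the previous value afterwards, even on error."""
    original = getattr(module, name)
    setattr(module, name, replacement)
    try:
        yield
    finally:
        setattr(module, name, original)


# Stand-in for xr.core.duck_array_ops, so the demo needs no xarray.
fake_ops = types.SimpleNamespace(einsum=lambda *args: "original")

with temporarily_patched(fake_ops, "einsum", lambda *args: "patched"):
    inside = fake_ops.einsum()
after = fake_ops.einsum()
print(inside, after)
```

With xarray this would read `with temporarily_patched(xr.core.duck_array_ops, "einsum", opt_einsum.contract): ...` around the `xr.dot` call.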

aulemahal commented 1 year ago

@charlesgauthier-udm Here's the "parallelize the application" issue.