Closed: zoj613 closed this 11 months ago
Sadly, not for now.
If your data has non-spatial dimensions, like time, I would suggest rechunking by merging all spatial chunks and splitting the time dimension so that the chunk sizes stay reasonable.
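As a minimal sketch of that rechunking suggestion (the dataset, variable name, and sizes here are made up for illustration):

```python
import numpy as np
import xarray as xr

# Hypothetical dataset: spatial dims (lat, lon) plus time.
ds = xr.Dataset(
    {"tas": (("time", "lat", "lon"), np.zeros((120, 90, 180)))}
)

# Merge all spatial chunks (-1 means one chunk along that dim)
# and split along time so each chunk stays a reasonable size.
chunked = ds.chunk({"lat": -1, "lon": -1, "time": 12})
```

The single spatial chunk is what lets the single-chunk sparse weights matrix be applied without distributing it.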
However, the regridding weights are currently stored in a single-chunk sparse matrix and distributing this across the chunks of your data is a complex problem...
If anyone has some ideas to solve this, that would be a great contribution.
See https://discourse.pangeo.io/t/conservative-region-aggregation-with-xarray-geopandas-and-sparse/2715 for a possible solution.
We're hoping to have an intern work on this next summer. If anyone has tips to share, please leave them here.
Here's how to do it:
```python
import numpy as np
import sparse
import xarray as xr


def read_xesmf_weights_file(filename):
    weights = xr.open_dataset(filename)

    # input variable shape
    in_shape = weights.src_grid_dims.load().data
    # output variable shape
    out_shape = weights.dst_grid_dims.load().data.tolist()[::-1]
    print(f"Regridding from {in_shape} to {out_shape}")

    # row/col indices in the file are 1-based; convert to 0-based
    rows = weights["row"] - 1
    cols = weights["col"] - 1

    # Construct a sparse array and reshape it to 3D (lat, lon, ncol).
    # This reshaping allows optional chunking along lat, lon later.
    sparse_array_data = sparse.COO(
        coords=np.stack([rows.data, cols.data]),
        data=weights.S.data,
        shape=(weights.sizes["n_b"], weights.sizes["n_a"]),
        fill_value=0,
    ).reshape((*out_shape, -1))

    # Create a DataArray with sparse weights and the output coordinates
    xsparse_wgts = xr.DataArray(
        sparse_array_data,
        dims=("lat", "lon", "ncol"),
        # Add useful coordinate information; this will get
        # propagated to the output
        coords={
            "lat": ("lat", weights.yc_b.data.reshape(out_shape)[:, 0]),
            "lon": ("lon", weights.xc_b.data.reshape(out_shape)[0, :]),
        },
        # Propagate useful information like the regridding algorithm
        attrs=weights.attrs,
    )
    return xsparse_wgts


xsparse_wgts = read_xesmf_weights_file(map_path + map_file)
```
To apply the weights, use opt_einsum (https://dgasmith.github.io/opt_einsum/) instead of `np.einsum`:
```python
import numpy as np
import opt_einsum
import xarray as xr


def apply_weights(dataset, weights):
    def _apply(da):
        # 🐵 🔧 monkey-patch: make xr.dot use opt_einsum
        xr.core.duck_array_ops.einsum = opt_einsum.contract
        ans = xr.dot(
            da,
            weights,
            # This dimension will be "contracted",
            # or summed over, after multiplying by the weights
            dims="ncol",
        )
        # 🐵 🔧 : restore the original einsum
        xr.core.duck_array_ops.einsum = np.einsum
        return ans

    vars_with_ncol = [
        name
        for name, array in dataset.variables.items()
        if "ncol" in array.dims and name not in weights.coords
    ]
    regridded = dataset[vars_with_ncol].map(_apply)

    # Merge in the other variables, but skip those that are
    # already set, like lat, lon
    return xr.merge([dataset.drop_vars(regridded.variables), regridded])


apply_weights(psfile, xsparse_wgts.chunk())
```
This works with chunked `weights` and `da`. The core piece is `_apply` (see smm.py). Using `np.einsum` like `xr.dot` does by default doesn't work so well for chunked weights (bug report), but in my testing I also found `opt_einsum` to be a lot faster for plain (numpy data × sparse weights) contractions.
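In plain numpy terms, the contraction that `xr.dot(da, weights, dims="ncol")` performs looks like the following (toy shapes and hypothetical weight values, dense arrays for clarity):

```python
import numpy as np

# Toy shapes: 2 times, 6 source columns, a 2x3 output grid.
da = np.arange(12.0).reshape(2, 6)    # dims (time, ncol)
weights = np.zeros((2, 3, 6))         # dims (lat, lon, ncol)
weights[0, 0, 0] = 1.0                # output cell (0, 0) copies column 0
weights[1, 2, 4] = 0.5                # output cell (1, 2) takes half of column 4

# Multiply by the weights and sum ("contract") over ncol;
# the result has dims (time, lat, lon).
out = np.einsum("tc,ijc->tij", da, weights)
```

`opt_einsum.contract` accepts the same subscript string, which is what makes the monkey-patch a drop-in swap.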
See https://github.com/pydata/xarray/issues/7764 for the upstream issue that would make the monkey-patch unnecessary.
@charlesgauthier-udm Here's the "parallelize the application" issue.
Is there a way I can reliably regrid an xarray.Dataset object to a lower/higher resolution if it has variables with dask-backed chunked arrays? Every single time I try to use the output of the call to `xesmf.Regridder` to regrid the input data, I get an exception. To get it to work, I have to force the datasets to have only a single chunk with `.chunk(-1)`. This can cause tasks to fail when the dask graph is computed, since a single chunk for large datasets can consume a lot of memory. Any workaround for this without using a single chunk?