xarray-contrib / xarray-regrid

Regridding utility for xarray
https://xarray-regrid.readthedocs.org/
Apache License 2.0

NaN threshold for conservative method #39

Closed: slevang closed this 2 months ago

slevang commented 6 months ago

Nice package! While testing it in places where I've previously used xesmf, I noticed two features missing from the conservative method:

  1. the regridded dataset can't be constructed lazily when the input is dask-backed, due to this line
  2. there is no way to keep target cells whose input points are partially NaN, as noted in #32

Number 1 is easy; number 2 is trickier. I added a naive implementation of xesmf-style nan_threshold capabilities here for discussion. As noted in the previous issue, doing this 100% correctly would require tracking the NaN fraction as we reduce over each dimension, which I'm not doing here. Because the reduction is sequential, the nan_threshold value doesn't translate directly into the total fraction of NaN cells. It would also get complicated for isolated NaNs in the temporal dimension.
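To illustrate why a per-axis threshold is not the same as a threshold on the total NaN fraction, here is a small NumPy sketch. It assumes separable per-axis conservative weights (reasonable for a rectilinear grid) and an assumed threshold convention modeled on xESMF's `na_thres`: a target cell is masked when its NaN area fraction exceeds `nan_threshold`. The helpers `apply_weights`, `regrid_naive` and `regrid_tracked` are hypothetical and not part of xarray-regrid. `regrid_naive` renormalizes and thresholds after every 1D reduction; `regrid_tracked` carries the valid-weight sum through all reductions and normalizes once.

```python
import numpy as np


def apply_weights(arr, weights, axis):
    """Contract `arr` along `axis` with a (n_target, n_source) weight matrix."""
    moved = np.moveaxis(arr, axis, -1)
    return np.moveaxis(moved @ weights.T, -1, axis)


def _normalize(num, valid_w, total_w, nan_threshold):
    """Renormalize by valid weight; mask cells whose NaN fraction exceeds the threshold."""
    nan_frac = 1.0 - valid_w / total_w
    with np.errstate(invalid="ignore", divide="ignore"):
        out = num / valid_w  # 0/0 -> NaN for all-NaN cells
    # Small tolerance guards against floating-point noise in nan_frac.
    return np.where(nan_frac > nan_threshold + 1e-12, np.nan, out)


def regrid_naive(data, weights_per_axis, nan_threshold):
    """Normalize and threshold after EACH 1D reduction (sequential, untracked)."""
    out = data
    for axis, w in weights_per_axis.items():
        valid = ~np.isnan(out)
        num = apply_weights(np.where(valid, out, 0.0), w, axis)
        valid_w = apply_weights(valid.astype(float), w, axis)
        total_w = apply_weights(np.ones_like(out), w, axis)
        out = _normalize(num, valid_w, total_w, nan_threshold)
    return out


def regrid_tracked(data, weights_per_axis, nan_threshold):
    """Carry the valid-weight sum through every reduction; normalize once at the end."""
    valid = ~np.isnan(data)
    num = np.where(valid, data, 0.0)
    valid_w = valid.astype(float)
    total_w = np.ones_like(num)
    for axis, w in weights_per_axis.items():
        num = apply_weights(num, w, axis)
        valid_w = apply_weights(valid_w, w, axis)
        total_w = apply_weights(total_w, w, axis)
    return _normalize(num, valid_w, total_w, nan_threshold)


# 2x2 source cells -> one target cell, equal overlaps along each axis.
data = np.array([[np.nan, 1.0], [3.0, 2.0]])
weights = {0: np.array([[0.5, 0.5]]), 1: np.array([[0.5, 0.5]])}

naive = regrid_naive(data, weights, 0.3)
tracked = regrid_tracked(data, weights, 0.3)
```

Only 25% of the target cell's area is NaN here, yet with `nan_threshold=0.3` the naive version masks it: the first reduction sees a 50% NaN column and drops it, and the second reduction then sees 50% NaN again. The tracked version keeps the cell with value 2.0. Even with `nan_threshold=1`, where both keep the cell, the naive result is 2.25 rather than 2.0, because the renormalized half-NaN column keeps its full weight in the second step.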

I'm not sure any of this matters much for a dataset with a fixed NaN mask, e.g. SST. Here's an example of the new functionality applied to the MUR dataset. Note that this is a 33TB array, but we can now generate the (lazily) regridded dataset instantly.

import xarray as xr
import xarray_regrid

sst = xr.open_zarr("https://mur-sst.s3.us-west-2.amazonaws.com/zarr-v1").analysed_sst
grid = xarray_regrid.Grid(
    north=90,
    east=180,
    south=-90,
    west=-180,
    resolution_lat=1,
    resolution_lon=1,
)
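target = xarray_regrid.create_regridding_dataset(grid)
# MUR uses "lat"/"lon" coordinate names, so rename the target grid to match
target = target.rename(latitude="lat", longitude="lon")

# Compare three NaN thresholds; each call only builds a lazy dask graph
ds0p0 = sst.regrid.conservative(target, nan_threshold=0)
ds0p5 = sst.regrid.conservative(target, nan_threshold=0.5)
ds1p0 = sst.regrid.conservative(target, nan_threshold=1)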

ds0p0

ds0p5

ds1p0

slevang commented 6 months ago

Thanks, and yes, for sure. Hopefully next week I'll find time to add tests and benchmarks.

slevang commented 6 months ago

I did some quick profiling on a ~4GB dask array of 1/4-degree global data, coarsening to 1 degree on a 32-CPU node. Results:

So adding skipna costs roughly one additional pass through the array for the weight renormalization. This PR is faster than main because the current code runs an np.any(np.isnan()) check, which forces computation, plus a separately calculated isnan array, for a total of 3 passes through the data. If I cut out the NaN-checking branch on main and go straight to the einsum, we recover the ~32s run above.
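A minimal sketch of why an np.any(np.isnan())-style branch defeats laziness: calling `bool()` on a value derived from a dask array computes it immediately, while a function that always takes the NaN-aware path only builds a graph. It counts scheduler runs with `dask.callbacks.Callback` (which fires on the local threaded/synchronous schedulers, not on a distributed cluster). The helpers `coarsen_eager` and `coarsen_lazy` are hypothetical, and `da.coarsen` with `np.nanmean` stands in for the actual conservative einsum.

```python
import numpy as np
import dask.array as da
from dask.callbacks import Callback

computes = []  # one entry per scheduler run (local schedulers only)


def record_compute(dsk):
    computes.append(1)  # the graph itself is not inspected


def coarsen_eager(arr):
    # Branching in Python on the data: bool() computes the whole array now.
    if bool(da.isnan(arr).any()):
        return da.coarsen(np.nanmean, arr, {0: 4, 1: 4})
    return da.coarsen(np.mean, arr, {0: 4, 1: 4})


def coarsen_lazy(arr):
    # Always take the NaN-aware path; only a task graph is built here.
    return da.coarsen(np.nanmean, arr, {0: 4, 1: 4})


x = da.random.random((400, 400), chunks=(100, 100))
x = da.where(x > 0.99, np.nan, x)  # sprinkle roughly 1% NaNs

with Callback(start=record_compute):
    eager = coarsen_eager(x)
    n_eager = len(computes)           # the NaN check has already run
    lazy = coarsen_lazy(x)
    n_lazy = len(computes) - n_eager  # nothing has been computed yet

print(n_eager, n_lazy)
```

The eager path also reads the data once more later when the result is actually computed, which is where the extra passes through the array come from.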

slevang commented 4 months ago

I made the modification to take notnull.any(non_regrid_dims), which leaves us at roughly a 3x performance penalty for skipna=True in the benchmarks I've run. I think this may need to be a configurable argument, though, for cases where you want to track NaNs very carefully throughout the dataset.
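A small xarray sketch of the trade-off, under stated assumptions: collapsing the validity mask with `notnull().any(...)` over the non-regridded dimensions yields a single 2D mask, so its regridded weights only need to be computed once, but an isolated NaN at one time step is no longer seen. The variables `collapsed_frac` and `tracked_frac` are hypothetical, and `coarsen(...).mean()` on an evenly divisible grid stands in for real conservative weights.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
data = rng.random((3, 4, 4))
data[:, 0, 0] = np.nan  # NaN at every time step (e.g. a land cell)
data[1, 2, 2] = np.nan  # isolated NaN at a single time step
arr = xr.DataArray(data, dims=("time", "lat", "lon"))

regrid_dims = ("lat", "lon")
non_regrid_dims = [d for d in arr.dims if d not in regrid_dims]

# Collapsed mask: a source cell counts as valid if it is non-null at ANY time,
# so the mask (and its regridded weights) is computed once, not per time step.
valid = arr.notnull().any(dim=non_regrid_dims)

# Stand-in for conservative weights: 2x2 block averages of the valid mask.
collapsed_frac = valid.astype(float).coarsen(lat=2, lon=2).mean()
# Fully tracked alternative: one valid fraction per time step.
tracked_frac = arr.notnull().astype(float).coarsen(lat=2, lon=2).mean()
```

The permanently NaN cell shows up in both masks (valid fraction 0.75 in the top-left target cell), but the isolated NaN at time index 1 only lowers the bottom-right target cell to 0.75 in `tracked_frac`; `collapsed_frac` reports 1.0 there. That is the precision a configurable option would buy back, at the cost of regridding a mask with a time dimension.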

BSchilperoort commented 2 months ago

Merged as part of #41