carbonplan / ndpyramid

A small utility for generating ND array pyramids using Xarray and Zarr.
https://ndpyramid.readthedocs.io
MIT License

Use xarray.map_blocks to speed up pyramid_reproject #10

Open jhamman opened 3 years ago

jhamman commented 3 years ago

@norlandrhagen and I have been experimenting with approaches for speeding up pyramid generation when using rioxarray's reproject functionality. We have this rough prototype to share:

import dask.array
import datatree as dt
import numpy as np
import xarray as xr

# _multiscales_template, _get_version, and make_grid_ds are ndpyramid
# internals, assumed to be importable in this module.


def pyramid_reproject(
    ds, levels: int = None, pixels_per_tile=128, resampling="average", extra_dim=None
) -> dt.DataTree:
    from rasterio.transform import Affine
    from rasterio.warp import Resampling

    # multiscales spec
    save_kwargs = {"levels": levels, "pixels_per_tile": pixels_per_tile}
    attrs = {
        "multiscales": _multiscales_template(
            datasets=[{"path": str(i)} for i in range(levels)],
            type="reduce",
            method="pyramid_reproject",
            version=_get_version(),
            kwargs=save_kwargs,
        )
    }

    # set up pyramid
    root = xr.Dataset(attrs=attrs)
    pyramid = dt.DataTree(data_objects={"root": root})

    def make_template(da, dim, dst_transform, shape=None):
        # lazy (dask-backed) template describing the output of one level
        template = xr.DataArray(
            data=dask.array.empty(shape, chunks=shape), dims=("y", "x"), attrs=da.attrs
        )
        template = make_grid_ds(template, dim, dst_transform)
        template.coords["spatial_ref"] = xr.DataArray(np.array(1.0))
        return template

    def reproject(da, shape=None, dst_transform=None, resampling="average"):
        return da.rio.reproject(
            "EPSG:3857",
            resampling=Resampling[resampling],
            shape=shape,
            transform=dst_transform,
        )

    for level in range(levels):
        lkey = str(level)
        dim = 2 ** level * pixels_per_tile

        # full Web Mercator extent, split into dim x dim pixels
        dst_transform = Affine.translation(-20026376.39, 20048966.10) * Affine.scale(
            (20026376.39 * 2) / dim, -(20048966.10 * 2) / dim
        )

        pyramid[lkey] = xr.Dataset(attrs=ds.attrs)
        for k, da in ds.items():
            template = make_template(ds[k], dim, dst_transform, (dim, dim))
            pyramid[lkey].ds[k] = xr.map_blocks(
                reproject,
                da,
                kwargs=dict(shape=(dim, dim), dst_transform=dst_transform),
                template=template,
            )

    return pyramid
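For reference, the per-level geometry the loop above computes can be sketched standalone. Plain tuples stand in for rasterio's `Affine` so the snippet has no dependencies; the bounds are the EPSG:3857 extent hard-coded in the prototype, and `level_grid` is a hypothetical helper name:

```python
def level_grid(level, pixels_per_tile=128):
    """Return (dim, transform) for one pyramid level.

    transform is (a, b, c, d, e, f) in the usual affine order:
    x = a*col + b*row + c ; y = d*col + e*row + f
    which is what translation(tx, ty) * scale(sx, sy) produces.
    """
    dim = 2 ** level * pixels_per_tile
    x_extent = 20026376.39 * 2  # full EPSG:3857 width in metres
    y_extent = 20048966.10 * 2  # full EPSG:3857 height in metres
    transform = (
        x_extent / dim, 0.0, -20026376.39,      # x resolution, upper-left x
        0.0, -(y_extent / dim), 20048966.10,    # -y resolution, upper-left y
    )
    return dim, transform
```

With the default tile size, level 0 is a single 128x128 tile spanning the whole Web Mercator extent, and each level doubles the grid dimension while halving the pixel resolution.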

dcherian commented 2 years ago

I thought @djhoese was working hard on dask-aware reprojection in pyresample?

dcherian commented 2 years ago

Oh this is very cool!

https://github.com/pytroll/pyresample/blob/93df018c3a5cbdd3ae52766adf0b2cea04c5c019/pyresample/resampler.py#L153-L160

djhoese commented 2 years ago

"working hard" has mostly been in my head as I haven't had time for any of the "real" work. Luckily, I'm not the only one worried about this. @mraspaud did the work on that resample_blocks function and it looks like it might be a game changer for some of our algorithms. The basic idea is:

  1. Get the bounds of each target/output chunk.
  2. Slice the input array that covers this output chunk.
  3. Resample that input slice to the output chunk.

I initially wasn't a fan of this strategy as it requires slicing and rechunking of the input data, but @mraspaud's experience shows that it performs much better than resampling all overlapping input chunks and then merging/reducing them later.
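The three steps above can be sketched in plain NumPy with a 1-D nearest-neighbour resampler. The helper name and chunking are hypothetical (the real implementation is pyresample's `resample_blocks`, linked above); the point is the structure: walk the output in chunks and slice only the covering input for each:

```python
import numpy as np

def resample_per_output_chunk(src, src_x, dst_x, chunk=4):
    """Sketch of the per-output-chunk strategy: for each output chunk,
    find the input slice that covers it, then resample only that slice
    (nearest neighbour here, for simplicity)."""
    out = np.empty(dst_x.size, dtype=src.dtype)
    for start in range(0, dst_x.size, chunk):
        stop = min(start + chunk, dst_x.size)
        lo, hi = dst_x[start], dst_x[stop - 1]            # 1. chunk bounds
        i0 = max(np.searchsorted(src_x, lo) - 1, 0)       # 2. covering input slice
        i1 = min(np.searchsorted(src_x, hi) + 1, src_x.size)
        sl_x, sl = src_x[i0:i1], src[i0:i1]
        # 3. resample the slice onto this output chunk only
        idx = np.abs(sl_x[None, :] - dst_x[start:stop, None]).argmin(axis=1)
        out[start:stop] = sl[idx]
    return out
```

Each output chunk touches only the input it overlaps, which is why the slicing/rechunking cost can still beat resampling every overlapping input chunk and merging afterwards.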

dcherian commented 2 years ago

> resampling all overlapping input chunks and then merging/reducing them later.

I think @gjoseph92 had something clever in stackstac for doing something like this while avoiding shuffling a bunch of NaNs (or other useless data) around. Can't find it now, though. I might have totally misinterpreted it.

djhoese commented 2 years ago

I have another very hacky implementation in pyresample for the "EWA" resampling algorithm (very specific to the VIIRS and MODIS instruments) where I do a dask reduction but pass tuples of values between functions. If the data is destined for the output chunk, the tuple contains arrays; if not, it contains Nones. It isn't how dask intends the functions to be used (array functions should return arrays), but it works for me to prevent unnecessary processing of chunks that I know would be all NaNs/fills.
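A toy version of that pattern, with hypothetical names (the real code is pyresample's EWA resampler): the per-chunk function returns arrays only when the input chunk can land in the target, otherwise a tuple of Nones, and the reduction step skips the Nones instead of summing full-size fill arrays:

```python
import numpy as np

def map_chunk(chunk, target_bounds, chunk_bounds):
    """Per-chunk step: return (weights, weighted values) arrays only if this
    input chunk overlaps the target; otherwise (None, None), so later steps
    can skip it without touching full-size NaN/fill arrays."""
    lo, hi = chunk_bounds
    t_lo, t_hi = target_bounds
    if hi < t_lo or lo > t_hi:  # no overlap with the output chunk
        return (None, None)
    w = np.ones_like(chunk)     # stand-in for real resampling weights
    return (w, chunk * w)

def reduce_chunks(pairs):
    """Combine step: weighted average over contributing chunks, ignoring
    the (None, None) placeholders entirely."""
    ws = [w for w, v in pairs if w is not None]
    vs = [v for w, v in pairs if v is not None]
    if not ws:
        return None
    return sum(vs) / sum(ws)
```

The combine step never allocates anything for the non-contributing chunks, which is the whole point of the hack.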