WFP-VAM / hdc-algo

Algorithmic code used in the HumanitarianDataCube
MIT License

Improvements to chunking / chunk based computation for algorithmic purposes #51

Open valpesendorfer opened 1 week ago

valpesendorfer commented 1 week ago

TLDR:

Most pixel-based hdc-algo functions require the full time dimension in memory, even when only part of it is used in the computation itself. Can this be improved?

Background

In general, most hdc-algo functions are applied to 3D spatiotemporal cubes: they compute an output for each pixel across the time dimension and produce a result that either has the same shape and dimensionality as the input (the most common case), changes the size of the time dimension, or reduces across the time dimension entirely.

Examples of the first case, where

f(A_{(m,n,p)}) = B_{(m,n,p)}

are:

Examples where the function changes the size of the last dimension,

f(A_{(m,n,p)}) = B_{(m,n,p')}

are:

And an example where the third dimension is reduced entirely,

f(A_{(m,n,p)}) = B_{(m,n)}

is the lag-1 autocorrelation.
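
For reference, a minimal numpy sketch of such a reduction (not the hdc-algo implementation, just to illustrate the (m, n, p) -> (m, n) shape change with time on the last axis):

```python
import numpy as np

def lag1_autocorrelation(cube: np.ndarray) -> np.ndarray:
    # cube has shape (m, n, p) with time last; the time axis is reduced away
    anom = cube - cube.mean(axis=-1, keepdims=True)
    num = (anom[..., :-1] * anom[..., 1:]).sum(axis=-1)
    den = (anom ** 2).sum(axis=-1)
    return num / den
```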

All of these functions were designed to work with dask-backed lazy arrays through xarray.apply_ufunc (except for zonal mean), with the downside that they require the time dimension to be a single chunk, i.e. chunking is only possible in the x/y dimensions.
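
As a rough sketch of that pattern (generic, not the actual hdc-algo wrapper code), the per-pixel kernel is mapped over dask blocks roughly like this, with `time` forced into a single chunk first:

```python
import xarray as xr

def per_pixel_kernel(arr):
    # placeholder for a pixel-wise hdc-algo function; arr arrives with shape (..., time)
    return arr

def apply_per_pixel(darr: xr.DataArray) -> xr.DataArray:
    darr = darr.chunk({"time": -1})  # the whole time dimension must sit in one chunk
    return xr.apply_ufunc(
        per_pixel_kernel,
        darr,
        input_core_dims=[["time"]],
        output_core_dims=[["time"]],
        dask="parallelized",
        output_dtypes=[darr.dtype],
    )
```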

New releases of hdc-algo, specifically 0.3.0 and the improvements introduced in 0.5.0, added a few simplifications for the user, specifically aimed at computation on a dekad basis:

An example notebook for these additions can be seen here: https://jupyter.earthobservation.vam.wfp.org/public/examples/grouped-operations.html

Issue

Requiring the full time dimension in memory at the time of computation is (assumed to be) a major blocker for the scalability of these computations. Also, for most of these computations we have prior knowledge of what is being computed and what data it requires; this information is currently not used, but it could inform more bespoke chunking and computation.

Questions / Ideas

Can we be smarter about how we chunk and compute the outputs so that the computations scale better, incorporating some of the knowledge about the computation and creating blocks that contain only the data needed at the time of computation? The idea would be to have concise chunks that can then be iterated over, for instance with map_blocks, rather than applying apply_ufunc across the full time dimension.

For example:

But how would that work for moving / rolling window functions?
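
One generic dask building block that might cover the rolling-window case, sketched here purely for illustration (it is not an existing hdc-algo mechanism), is map_overlap: it shares a small halo of time steps between neighbouring chunks, so a short window never needs the full time axis in a single chunk.

```python
import dask.array as da
import numpy as np

# made-up cube with layout (time, y, x) and a chunked time axis
cube = da.random.random((120, 256, 256), chunks=(12, 128, 128))

def rolling_sum2(block):
    # window-2 rolling sum along time; values at internal chunk edges are computed
    # from the shared halo, which dask trims from the output again
    out = np.empty_like(block)
    out[0] = block[0]  # the very first time step of the record has no predecessor
    out[1:] = block[1:] + block[:-1]
    return out

rolled = cube.map_overlap(rolling_sum2, depth={0: 1}, boundary="none")
```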

Kirill888 commented 4 days ago

So spatially this is always per pixel; along the time dimension, however, we have a variety of input-to-output dependencies.

What would be the largest (memory-wise) input collection?

One way to approach this without re-chunking is to build a more manual dask graph that accepts all required temporal chunks for a given output and stacks them before calling the processing method. So something like:

```python
from typing import Callable

import numpy as np
from dask import delayed

@delayed
def apply_to_bunch(func: Callable, *blocks):
    # concatenate the temporal chunks along the (assumed leading) time axis, then apply func
    data = np.concatenate(blocks, axis=0)
    return func(data)
```

This will likely result in peak memory usage of roughly twice the input size plus the output size, but we can reduce it to something closer to the input size plus the output size by adding an extra loop that iterates over spatial sub-chunks:

```python
@delayed
def apply_to_bunch(func: Callable, *blocks, work_size_xy: tuple[int, int] = (32, 32)):
    # out_shape / spatial_sub_blocks are placeholders: the output shape for this
    # collection of blocks, and an iterator over spatial windows of size work_size_xy
    out = np.empty(out_shape(blocks))
    for roi in spatial_sub_blocks(work_size_xy):
        # process one small spatial window at a time to keep peak memory near input + output
        out[..., roi] = func(np.concatenate([b[..., roi] for b in blocks], axis=0))
    return out
```
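
For illustration, such a graph could be assembled from the first variant above roughly as follows; the array shape, chunking and the reducing function are made up for the example, while to_delayed, from_delayed and block are standard dask.array APIs:

```python
from functools import partial

import dask.array as da
import numpy as np

# made-up stand-in for a dask-backed cube with layout (time, y, x)
cube = da.random.random((360, 4000, 4000), chunks=(36, 500, 500))
blocks = cube.to_delayed()  # object array of delayed chunks, indexed (t, y, x)
nt, ny, nx = blocks.shape

pieces = [
    [
        da.from_delayed(
            # feed every temporal chunk of this spatial tile into a single task
            apply_to_bunch(partial(np.mean, axis=0), *blocks[:, iy, ix]),
            shape=(500, 500),
            dtype=cube.dtype,
        )
        for ix in range(nx)
    ]
    for iy in range(ny)
]
result = da.block(pieces)  # lazy (4000, 4000) output, one task per spatial tile
```

Each task then needs roughly one spatial tile across the full time axis plus its output in memory, matching the estimate above.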
valpesendorfer commented 3 days ago

As discussed, computing the SPI is probably a good example due to the access requirements across different chunks and the algorithm being fairly expensive.

Let's work from the assumption that we need to calculate the "full archive" of SPIs, and for this example we calculate the 12- and 24-month SPI. We could calculate these directly from dekadal data by first computing a rolling sum, but selected aggregations are already precomputed and can be used directly (like the 12-month one).

We start by using the 12-month aggregated dekadal rainfall:

```python
from datacube import Datacube
import hdc.algo
from hdc.tools.utils.dcube import groupby

dc = Datacube()
x = dc.load(product="ryh_dekad", dask_chunks={}, group_by=groupby.start_time)
x = x.band

display(x)
```

[image: output of display(x)]

The first "case" is to calculate the SPIs directly from the aggregations. SPIs are calculated for each dekad in the year across the full archive, e.g. to calculate the SPI for 202401d1, we need the data for all dekads 01d1 across the entire record:

```python
# Note: I'm using the dekad accessor to select the first dekad in the year, which corresponds to `01d1`

xx = x.sel(time=x.time.dekad.yidx == 1)
display(xx)
```

[image: output of display(xx)]

To calculate the SPI, I need to rechunk the data to have the full time dimension in one block:

```python
xxspi = xx.chunk(time=-1, latitude=1000, longitude=1200).hdc.algo.spi()
display(xxspi)
```

[image: output of display(xxspi)]
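
For context, covering all dekads this way would mean repeating that selection and SPI call 36 times and stitching the results back together, roughly like this sketch, which only reuses the accessors shown above:

```python
import xarray as xr

# Sketch only: per-dekad loop that the grouped SPI below is meant to avoid
spis = []
for yidx in range(1, 37):  # 36 dekads per year
    sel = x.sel(time=x.time.dekad.yidx == yidx)
    spis.append(sel.chunk(time=-1, latitude=1000, longitude=1200).hdc.algo.spi())

xxspi = xr.concat(spis, dim="time").sortby("time")
```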

To avoid having to iterate over the individual dekads, hdc-algo>=0.3.0 introduced grouped SPIs, where we can supply an array of dekad indices and the SPIs are calculated for each individual group (i.e. dekad). The issue here is that the entire time dimension needs to be in memory beforehand, so a lot of memory is used to load the full dataset while each computation only needs a subset of the individual blocks:

```python
xxspi = x.chunk(time=-1, latitude=200, longitude=200).hdc.algo.spi(groups=x.time.dekad.yidx)
display(xxspi)
```

So the initial idea for improvements was to pack the blocks in such a way that they contain only the data necessary for computing the individual dekads, and then iterate over those blocks.

[image]
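
One way the packing could look in practice (a sketch, not an existing hdc-algo feature) is to sort the time axis by dekad index and pick time chunk sizes that coincide with the group boundaries, so each dask block holds exactly one dekad's full record:

```python
import numpy as np

yidx = x.time.dekad.yidx.values
order = np.argsort(yidx, kind="stable")  # group identical dekads together, keep time order within a group
_, counts = np.unique(yidx[order], return_counts=True)

# one time chunk per dekad group; spatial chunking stays as before
packed = x.isel(time=order).chunk(time=tuple(int(c) for c in counts), latitude=200, longitude=200)
```

Each time chunk of `packed` then corresponds to exactly one dekad group, so the SPI could be computed by iterating over blocks (e.g. with map_blocks or something like apply_to_bunch above) without ever holding the full record per pixel.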

The second "case" is calculating SPIs which have aggregation periods longer than 12 months (or similarly calculating SPIs directly from the non aggregated dekadal data), in which case the data needs to be aggregated before using a rolling sum before calculating the SPI.

For example, we can calculate the 24-month SPI from the 12-month aggregated rainfall by either using .rolling on the xarray data array or the .hdc.rolling.sum accessor introduced in hdc-algo>=0.3, which again requires the full time dimension in one chunk beforehand:

```python
xx = x.chunk(time=-1, latitude=200, longitude=200).hdc.rolling.sum(window_size=2, dtype="int16")
xxspi = xx.hdc.algo.spi(groups=xx.time.dekad.yidx)
display(xxspi)
```

[image: output of display(xxspi)]

Full notebook

```python
# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#     jupytext_version: 1.16.1
#   kernelspec:
#     display_name: hdc
#     language: python
#     name: conda-env-hdc-py
# ---

from datacube import Datacube

import hdc.algo
from hdc.tools.utils.dcube import groupby

dc = Datacube()

x = dc.load(product="ryh_dekad", dask_chunks={}, group_by=groupby.start_time)
x = x.band

# +
# Note: I'm using the dekad accessor to select the first dekad in the year, which corresponds to `01d1`
# -

xx = x.sel(time=x.time.dekad.yidx == 1)
display(xx)

xxspi = xx.chunk(time=-1, latitude=1000, longitude=1200).hdc.algo.spi()
display(xxspi)

xxspi = x.chunk(time=-1, latitude=200, longitude=200).hdc.algo.spi(groups=x.time.dekad.yidx)
display(xxspi)

xx = x.chunk(time=-1, latitude=200, longitude=200).hdc.rolling.sum(window_size=2, dtype="int16")
xxspi = xx.hdc.algo.spi(groups=xx.time.dekad.yidx)
display(xxspi)

# +
x.rolling(time=2).sum().round().astype("int16")
# -
```