alan-turing-institute / deepsensor

A Python package for tackling diverse environmental prediction tasks with NPs.
https://alan-turing-institute.github.io/deepsensor/
MIT License

Patchwise training and inference #22

Open · tom-andersson opened 1 year ago

tom-andersson commented 1 year ago

Some deepsensor users may have dense environmental data spanning large spatial areas.

In such cases, training and inference with a ConvNP over the entire region of data may be computationally prohibitive. Currently, the TaskLoader samples context and target data over the entire spatial region where data is available, which could produce out-of-memory (OOM) errors. So we need to support chopping the data into smaller spatial patches.

Training

Supporting patchwise ConvNP training should just be a matter of updating the TaskLoader to slice the context and target datasets spatially to subsetted squares/regions before proceeding with the TaskLoader.__call__ sampling functionality for generating Task objects. I believe this should be quite simple: for xarray data this would be ds.sel(x1=slice(...), x2=slice(...)), while for pandas data it would be df.loc[slice(...), slice(...)].
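A minimal sketch of that slicing on toy data (the patch bounds and variable names are illustrative, not deepsensor API):

```python
import numpy as np
import pandas as pd
import xarray as xr

# Toy gridded context data on the unit square
ds = xr.Dataset(
    {"temperature": (("x1", "x2"), np.random.rand(100, 100))},
    coords={"x1": np.linspace(0, 1, 100), "x2": np.linspace(0, 1, 100)},
)

# Toy off-grid data indexed by (x1, x2)
df = pd.DataFrame({
    "x1": np.random.rand(500),
    "x2": np.random.rand(500),
    "temperature": np.random.rand(500),
}).set_index(["x1", "x2"])

# Patch bounds (randomly sampled or user-specified)
x1_min, x1_max = 0.2, 0.5
x2_min, x2_max = 0.3, 0.6

# xarray: label-based slicing of the grid
ds_patch = ds.sel(x1=slice(x1_min, x1_max), x2=slice(x2_min, x2_max))

# pandas: masking on the MultiIndex levels (df.loc with two slices needs a
# lexsorted MultiIndex, so masking is a safer equivalent)
x1_vals = df.index.get_level_values("x1")
x2_vals = df.index.get_level_values("x2")
df_patch = df[(x1_vals >= x1_min) & (x1_vals <= x1_max)
              & (x2_vals >= x2_min) & (x2_vals <= x2_max)]
```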

Inference

Inference using the high-level DeepSensorModel.predict interface also needs support for patching. This requires functionality to stitch all the individual patch predictions together.

For on-grid xarray prediction, one solution might be to call .predict recursively over all the patches and then concatenate the resulting xr.Datasets into single objects. This would require some kind of patchify bool to control this and avoid infinite recursion within the inner call. Open to other ideas!
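For concreteness, a hypothetical sketch of that recursion, where patches is an iterable of (x1_slice, x2_slice) bounds and model.predict(task, X_t) stands in for the existing predict call; neither patchify nor patches exists in the current API:

```python
import xarray as xr

def patchwise_predict(model, task, X_t, patchify=False, patches=None):
    """Recursively predict over patches and stitch the results."""
    if not patchify:
        # Base case: a single prediction over the given target grid
        return model.predict(task, X_t)

    patch_preds = []
    for x1_slice, x2_slice in patches:
        X_t_patch = X_t.sel(x1=x1_slice, x2=x2_slice)
        # Inner call with patchify=False avoids infinite recursion
        patch_preds.append(patchwise_predict(model, task, X_t_patch))

    # Stitch the non-overlapping patch predictions into one object
    return xr.combine_by_coords(patch_preds)
```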

However, model predictions could differ substantially from one side of a patch border to another (due to differing context information in each patch). We therefore may need to think about having overlapping patches and averaging model predictions somehow.
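One way to do that averaging, sketched below under the assumption that each patch prediction is an xr.Dataset on a subset of one shared global grid (this is a uniform average; tapering the weights towards patch borders would be a refinement):

```python
import xarray as xr

def stitch_with_averaging(patch_preds, global_coords):
    """Average overlapping patch predictions onto a single global grid."""
    total, count = None, None
    for pred in patch_preds:
        # Align the patch onto the global grid, filling elsewhere with NaN
        aligned = pred.reindex(global_coords)
        mask = aligned.notnull().astype(float)
        if total is None:
            total, count = aligned.fillna(0.0), mask
        else:
            total, count = total + aligned.fillna(0.0), count + mask
    # Cells covered by N patches are averaged over those N predictions
    return total / count.where(count > 0)
```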

Patch size/location question

An open question is how the size and location of the patches should be determined. One option is for the user to pass the patch size to TaskLoader.__call__ or DeepSensorModel.predict, with the patch location then generated randomly unless further kwargs are passed to override this and specify exact x1/x2 spatial bounds.
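To illustrate that interface, a sketch of random patch placement with a user override (the argument names are placeholders, not proposed API):

```python
import numpy as np

def sample_patch_bounds(x1_extent, x2_extent, patch_size, bounds=None, rng=None):
    """Return (x1_slice, x2_slice) for one square patch."""
    if bounds is not None:
        # User override: exact spatial bounds take precedence
        (x1_min, x1_max), (x2_min, x2_max) = bounds
    else:
        rng = rng or np.random.default_rng()
        # Sample the patch's lower-left corner so the patch fits the extent
        x1_min = rng.uniform(x1_extent[0], x1_extent[1] - patch_size)
        x2_min = rng.uniform(x2_extent[0], x2_extent[1] - patch_size)
        x1_max, x2_max = x1_min + patch_size, x2_min + patch_size
    return slice(x1_min, x1_max), slice(x2_min, x2_max)

# e.g. a random 0.3 x 0.3 patch within the unit square
x1_slice, x2_slice = sample_patch_bounds((0, 1), (0, 1), patch_size=0.3)
```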

nilsleh commented 12 months ago

Hi @tom-andersson, I am attempting to prototype this for a use case I have. My question is what the interpretation of a patch_size should be (ignoring the location of the patch for the moment and just going with random placement). Take spatial xarray data as an example: with patch_size=100 and an xarray spanning the globe with dimensions [180, 360], should the patch be measured in grid cells or in coordinate units?
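To make the ambiguity concrete, a toy comparison (the coordinate names are illustrative):

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"t2m": (("lat", "lon"), np.random.rand(180, 360))},
    coords={"lat": np.arange(-89.5, 90.0, 1.0), "lon": np.arange(0.0, 360.0, 1.0)},
)

# Interpretation 1: patch_size counts grid cells (index space)
patch_cells = ds.isel(lat=slice(0, 100), lon=slice(0, 100))

# Interpretation 2: patch_size is a coordinate extent, e.g. degrees
patch_degrees = ds.sel(lat=slice(-89.5, 10.5), lon=slice(0.0, 100.0))

# These roughly agree on a 1-degree grid but diverge at any other
# resolution, and in deepsensor's normalised x1/x2 coordinates
print(patch_cells.sizes, patch_degrees.sizes)
```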

Maybe you have some thoughts/preferences on this that I could use as guidance for a draft implementation. (I haven't thought about inference at all at this stage.)

acocac commented 5 months ago

Just to add: the Pangeo ML group has worked extensively on optimising n-dimensional arrays for AI/ML pipelines. For patchwise training, I suggest building upon existing developments such as the xbatcher and zen3geo Python libraries.
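For example, a minimal xbatcher sketch yielding overlapping spatial patches from an xarray Dataset (the patch and overlap sizes are arbitrary):

```python
import numpy as np
import xarray as xr
import xbatcher

ds = xr.Dataset(
    {"t2m": (("lat", "lon"), np.random.rand(180, 360))},
    coords={"lat": np.arange(180), "lon": np.arange(360)},
)

# 100x100-cell patches with a 10-cell overlap between neighbours
bgen = xbatcher.BatchGenerator(
    ds,
    input_dims={"lat": 100, "lon": 100},
    input_overlap={"lat": 10, "lon": 10},
)

for patch in bgen:
    print(patch.sizes)  # each patch is itself an xr.Dataset
    break
```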