tom-andersson opened 1 year ago
Hi @tom-andersson, I am attempting to "prototype" this for my use case. My question is how `patch_size` should be interpreted (ignoring the location of the patch for the moment and just going with random). In the case of spatial xarray data, taking the example of `patch_size=100` and an xarray spanning the globe with dimensions `[180, 360]`:

- `patch_size=(50, 120)`?
- `patch_size=0.5` would yield a random `[90, 180]` array.

Maybe you have some thoughts/preferences on this that I could take as further guidance for a draft implementation. (I haven't thought about inference at all at this stage.)
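One possible convention distinguishes the cases in the comment by type: an `int` gives a square patch of that many grid cells, a 2-tuple gives per-dimension sizes in grid cells, and a `float` in (0, 1] gives a fraction of each dimension (so `0.5` on a `[180, 360]` grid gives `[90, 180]`). This is a minimal sketch, not `deepsensor` API: `resolve_patch_size` and `random_patch_slices` are hypothetical helpers, and reading `patch_size=100` as a 100 x 100 patch is an assumption.

```python
import random
from typing import Optional, Tuple, Union

# Hypothetical spec type: int (square), float (fraction) or (n_x1, n_x2) tuple.
PatchSize = Union[int, float, Tuple[int, int]]


def resolve_patch_size(patch_size: PatchSize, shape: Tuple[int, int]) -> Tuple[int, int]:
    """Convert a patch_size spec into a (n_x1, n_x2) patch shape in grid cells.

    Hypothetical convention (not deepsensor API):
      int N     -> N x N square patch
      (n1, n2)  -> n1 x n2 patch
      float f   -> fraction f in (0, 1] of each grid dimension
    """
    n1, n2 = shape
    if isinstance(patch_size, bool):  # bool is a subclass of int; reject it
        raise TypeError("patch_size must not be a bool")
    if isinstance(patch_size, float):
        if not 0.0 < patch_size <= 1.0:
            raise ValueError("fractional patch_size must be in (0, 1]")
        size = (max(1, round(n1 * patch_size)), max(1, round(n2 * patch_size)))
    elif isinstance(patch_size, int):
        size = (patch_size, patch_size)
    elif isinstance(patch_size, tuple) and len(patch_size) == 2:
        size = (int(patch_size[0]), int(patch_size[1]))
    else:
        raise TypeError(f"unsupported patch_size: {patch_size!r}")
    if not (0 < size[0] <= n1 and 0 < size[1] <= n2):
        raise ValueError(f"patch {size} does not fit in grid {shape}")
    return size


def random_patch_slices(
    shape: Tuple[int, int],
    patch_size: PatchSize,
    rng: Optional[random.Random] = None,
) -> Tuple[slice, slice]:
    """Place a patch uniformly at random; return index slices for (x1, x2)."""
    rng = rng or random.Random()
    p1, p2 = resolve_patch_size(patch_size, shape)
    # randint is inclusive at both ends, so the patch always stays in bounds.
    i = rng.randint(0, shape[0] - p1)
    j = rng.randint(0, shape[1] - p2)
    return slice(i, i + p1), slice(j, j + p2)
```

For example, `resolve_patch_size(0.5, (180, 360))` returns `(90, 180)`, and the slices from `random_patch_slices` could be passed to xarray's positional indexer, `ds.isel(x1=s1, x2=s2)`. One drawback of dispatching on type is that `1` and `1.0` mean very different things; an explicit units argument might be less error-prone.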
Some `deepsensor` users may have dense environmental data spanning large spatial areas. For example:

In such cases, training and inference with a `ConvNP` over the entire region of data may be computationally prohibitive. Currently, the `TaskLoader` will sample context and target data over the entire spatial region where data is available, which could lead to out-of-memory (OOM) errors. So we need to support chopping the data into smaller spatial patches.

## Training

Supporting patchwise `ConvNP` training should just be a matter of updating the `TaskLoader` to slice the context and target datasets spatially into smaller square/rectangular regions before proceeding with the `TaskLoader.__call__` sampling functionality for generating `Task` objects. I believe this should be quite simple: for `xarray` data this would be `ds.sel(x1=slice(...), x2=slice(...))`, while for `pandas` data it would be `df.loc[slice(...), slice(...)]`.

## Inference

Inference using the high-level
`DeepSensorModel.predict` interface also needs support for patching. This requires functionality to stitch all the individual patch predictions together.

For on-grid `xarray` prediction, one solution might be to call `.predict` recursively over all the patches and then concatenate the resulting `xr.Dataset`s into single objects. This would require some kind of `patchify` bool to control this and avoid infinite recursion within the inner call. Open to other ideas!

However, model predictions could differ substantially from one side of a patch border to the other (due to differing context information in each patch). We may therefore need to consider overlapping patches and averaging the model predictions where they overlap.
## Patch size/location question

An open question is how the size and location of the patches should be determined. One option is to have the user pass the patch size to `TaskLoader.__call__` or `DeepSensorModel.predict`, and then have the location generated randomly unless further kwargs are passed to override this and specify exact `x1`/`x2` spatial bounds.
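The option above (random placement by default, explicit bounds as an override) can be sketched as a small helper that returns `x1`/`x2` bounds in coordinate units, plus a function that cuts a `pandas` patch out with them. Everything here is hypothetical: `patch_bounds`, `slice_df_patch` and the `bounds` kwarg are illustrative names, and the data is assumed to be indexed by a `(time, x1, x2)` MultiIndex. A boolean mask on the `x1`/`x2` index levels is used instead of `.loc` label slicing, which on a MultiIndex generally requires a sorted index. For `xarray`, the same bounds could go straight into `ds.sel(x1=slice(...), x2=slice(...))` as in the issue.

```python
from typing import Optional, Tuple

import numpy as np
import pandas as pd

Bounds = Tuple[float, float]


def patch_bounds(
    x1_range: Bounds,
    x2_range: Bounds,
    patch_size: Tuple[float, float],
    bounds: Optional[Tuple[Bounds, Bounds]] = None,
    rng: Optional[np.random.Generator] = None,
) -> Tuple[Bounds, Bounds]:
    """Return ((x1_min, x1_max), (x2_min, x2_max)) for one patch.

    Hypothetical API: explicit `bounds` override random placement;
    otherwise a patch with extent `patch_size` (in x1/x2 coordinate units)
    is placed uniformly at random inside the data extent.
    """
    if bounds is not None:
        return bounds
    rng = rng if rng is not None else np.random.default_rng()
    (lo1, hi1), (lo2, hi2) = x1_range, x2_range
    w1, w2 = patch_size
    if w1 > hi1 - lo1 or w2 > hi2 - lo2:
        raise ValueError("patch is larger than the data extent")
    x1_min = float(rng.uniform(lo1, hi1 - w1))
    x2_min = float(rng.uniform(lo2, hi2 - w2))
    return (x1_min, x1_min + w1), (x2_min, x2_min + w2)


def slice_df_patch(df: pd.DataFrame, x1_bounds: Bounds, x2_bounds: Bounds) -> pd.DataFrame:
    """Keep rows whose x1/x2 index values lie inside the (inclusive) bounds.

    Assumes index levels named "x1" and "x2" (e.g. a (time, x1, x2) MultiIndex).
    A boolean mask works whether or not the MultiIndex is sorted.
    """
    x1 = df.index.get_level_values("x1")
    x2 = df.index.get_level_values("x2")
    mask = (
        (x1 >= x1_bounds[0]) & (x1 <= x1_bounds[1])
        & (x2 >= x2_bounds[0]) & (x2 <= x2_bounds[1])
    )
    return df[mask]
```

In a `TaskLoader`-style workflow, the same bounds would be applied to every context and target dataset before sampling, so that all of a task's data comes from the same spatial patch.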