xarray-contrib / xbatcher

Batch generation from xarray datasets
https://xbatcher.readthedocs.io
Apache License 2.0
167 stars 27 forks

Support for valid examples #158

Open ljstrnadiii opened 1 year ago

ljstrnadiii commented 1 year ago

Is your feature request related to a problem?

There is currently no support for serving only batches that satisfy some validity criterion. It would be nice to filter out batches based on such a criterion.

Consider this dataset:

import numpy as np
import xarray as xr
import xbatcher

w = 100
da = xr.DataArray(np.random.rand(2, w, w), name='foo', dims=['variable','y', 'x'])

# simulate 10% sparse, expensive target data
percent_nans = .90
number_nans = (w ** 2) * percent_nans
da[0] = xr.where(da[1] < .1, da[1], np.nan)

bgen = xbatcher.BatchGenerator(
    da, 
    {'variable': 2, 'x':10, 'y': 10}, 
    input_overlap={'x': 0, 'y': 0}, 
    batch_dims={'x': 100, 'y': 100}, 
    concat_input_dims=True
)

for batch in bgen:
    pass

If we are serving this to a machine learning process and only care about locations where we have target data, many of these examples will not be valid, i.e. there will be no target value to use for training.

Describe the solution you'd like

I would like to see something like:

w = 100
da = xr.DataArray(np.random.rand(2, w, w), name='foo', dims=['variable','y', 'x'])

# simulate 10% sparse, expensive target data
percent_nans = .90
number_nans = (w ** 2) * percent_nans
da[0] = xr.where(da[1] < .1, da[1], np.nan)

bgen = xbatcher.BatchGenerator(
    da, 
    {'variable': 2, 'x':10, 'y': 10}, 
    input_overlap={'x': 0, 'y': 0}, 
    batch_dims={'x': 100, 'y': 100}, 
    concat_input_dims=True,
    valid_example=lambda x: ~np.isnan(x[0][5,5])
)

for batch in bgen:
    pass

where we satisfy: np.all(~np.isnan(batch[:,0,5,5]))
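To make the criterion concrete, here is a minimal sketch that applies the same predicate by hand to non-overlapping 10x10 windows of a comparable array. It assumes plain NumPy as a stand-in for the DataArray; `valid_example` is only the proposed name (it is not an existing xbatcher parameter) and `iter_windows` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)
w = 100
arr = rng.random((2, w, w))                        # (variable, y, x)
arr[0] = np.where(arr[1] < 0.1, arr[1], np.nan)    # sparse target in variable 0


def iter_windows(a, size=10):
    """Yield non-overlapping (variable, size, size) windows (hypothetical helper)."""
    for y in range(0, a.shape[1], size):
        for x in range(0, a.shape[2], size):
            yield a[:, y:y + size, x:x + size]


def valid_example(x):
    # Same criterion as the proposal: target pixel (5, 5) must not be NaN.
    return not np.isnan(x[0][5, 5])


windows = list(iter_windows(arr))
valid = [win for win in windows if valid_example(win)]
batch = np.stack(valid)                            # (example, variable, y, x)
```

Every example in `batch` then satisfies the condition stated above, at the cost of evaluating the predicate one window at a time.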

Describe alternatives you've considered

see: https://discourse.pangeo.io/t/efficiently-slicing-random-windows-for-reduced-xarray-dataset/2447

I typically extract all valid "chips" or "patches" in advance and persist them as a "training dataset" to get all the computation out of the way. The dims would look something like {'i': number of valid chips, 'variable': 2, 'x': 10, 'y': 10}. I could then simply use xbatcher to batch over the i dimension.
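A rough sketch of that precomputation, assuming non-overlapping 10x10 chips and using plain NumPy in place of xarray (the variable names are illustrative, and the axis order here is (i, variable, y, x)):

```python
import numpy as np

rng = np.random.default_rng(0)
w, chip = 100, 10
arr = rng.random((2, w, w))                        # (variable, y, x)
arr[0] = np.where(arr[1] < 0.1, arr[1], np.nan)    # sparse target in variable 0

n = w // chip
# Split y and x into (block, within-block) axes, then move the block axes first.
chips = (
    arr.reshape(2, n, chip, n, chip)               # (variable, yb, y, xb, x)
    .transpose(1, 3, 0, 2, 4)                      # (yb, xb, variable, y, x)
    .reshape(n * n, 2, chip, chip)                 # (i, variable, y, x)
)

# Evaluate the validity criterion for every chip in one vectorized step.
is_valid = ~np.isnan(chips[:, 0, 5, 5])
training = chips[is_valid]                         # (valid chips, variable, y, x)
```

Because the criterion is evaluated for all chips at once, the result can be persisted and later batched along the leading dimension without re-running the filter.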

Additional context

No response

weiji14 commented 1 year ago

Agreed, there should be a way to filter out invalid values. There's a newer duplicate issue at #162 about having a predicate function, similar to the valid_example parameter you are proposing here (I had to look up https://dcl-prog.stanford.edu/function-predicate.html to learn that predicate functions are those that return True or False, i.e. a boolean), but I'll post here on a first come, first served basis.

At https://github.com/xarray-contrib/xbatcher/issues/162#issuecomment-1431902345, @cmdupuis3 showed this example code snippet:

Better code sample, which wraps xbatcher and also offers fixed batch sizes:

import xarray as xr
import xbatcher as xb

# `ds` is assumed to be an existing xarray Dataset with dims 'd1' and 'd2'
bgen = xb.BatchGenerator(
    ds,
    {'d1':5, 'd2':5},
    {'d1':2, 'd2':2}
)
def my_gen2(bgen, batch_size=5, predicate=None):
    b = (batch for batch in bgen)
    n = 0
    batch_stack = []
    while n < 400: # hardcoded n is a kludge; while-loop is necessary
        try:
            this_batch = next(b)
        except StopIteration:
            # an uncaught StopIteration inside a generator becomes a RuntimeError
            return
        n += 1
        if predicate and not predicate(this_batch):
            continue
        batch_stack.append(this_batch)
        if len(batch_stack) == batch_size:
            yield xr.concat(batch_stack, 'sample')
            batch_stack = []

This code can be summarized as 3 main steps:

  1. Use xbatcher.BatchGenerator to generate the chips/patches
  2. Filter out invalid values based on a predicate True/False condition
  3. Use xr.concat to create the stack of tensors

The fact that someone has to concat the tensors together after having already used BatchGenerator (which according to its name, should be for generating batches) indicates that BatchGenerator is sometimes used for half of the job (the chipping/slicing part). I've had to do the same xbatcher.BatchGenerator + concat workflow at https://zen3geo.readthedocs.io/en/v0.5.0/chipping.html#pool-chips-into-mini-batches, so this isn't an isolated incident.

While we could add a valid_example parameter to filter out NaNs or invalid values, my suggestion is to follow the torchdata compositional style and have a Slicer, a Filter and a Batcher handle each of the 3 steps above. The reasoning is laid out in #172: valid_example would not be the only parameter people would like to add. There's also caching at #109, creating train/val/test splits, shuffling, and so on, all of which would lead to an overly complicated BatchGenerator.
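As a rough illustration of that compositional style, the three steps can be written as small, chainable generators. This is only a sketch on a plain Python list; `slicer`, `keep_if` and `batcher` are hypothetical names and are not part of xbatcher or torchdata.

```python
from itertools import islice


def slicer(data, size):
    """Step 1: cut a sequence into consecutive chips of length `size`."""
    for start in range(0, len(data) - size + 1, size):
        yield data[start:start + size]


def keep_if(chips, predicate):
    """Step 2: keep only the chips for which `predicate` returns True."""
    return (chip for chip in chips if predicate(chip))


def batcher(chips, batch_size):
    """Step 3: group chips into lists of at most `batch_size` items."""
    it = iter(chips)
    while batch := list(islice(it, batch_size)):
        yield batch


# None marks a missing value in this toy input.
data = [1.0, 2.0, None, 4.0, 5.0, 6.0, 7.0, None, 9.0, 10.0, 11.0, 12.0]
pipeline = batcher(
    keep_if(slicer(data, size=2), lambda chip: None not in chip),
    batch_size=2,
)
batches = list(pipeline)
```

Each stage only knows about the iterable it consumes, so further stages such as caching or shuffling could be inserted between them without growing a single constructor signature.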

That said, we could theoretically add a valid_example filter parameter quite easily now, and handle all the extra Slicer/Filter/Batcher stuff in the background, hidden from the user. That would suit people who are interested in using xbatcher.BatchGenerator as a 'one-liner' that does everything, similar to something like pandas.read_csv.

ljstrnadiii commented 1 year ago

@weiji14 thanks for showing interest in this problem!

The term 'predicate function' makes way more sense and I should have used that terminology from the start.

The main issue I see with the three additional steps is that the predicate gets applied to the batches sequentially and we lose the parallel and potentially distributed power of dask, which is critical for decently-scaled ML problems.

I sometimes have >1 TB DataArrays with dims (variable, y, x) where only 10% of the xy coordinates are valid. The sparse target variable might be 10+ GB, and all of it would have to come down sequentially to apply the predicate. Instead of trying to get BatchGenerator to solve this, I create "Training Datasets" in advance, with the first dimension being the batched dimension. We persist to zarr or to cluster memory because we also shuffle, which is a relatively expensive op. Then we can iterate over the first dim for batching.
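The last two steps (shuffle once, then iterate over the first dim) can be sketched as follows. This assumes a small in-memory NumPy array standing in for the persisted training dataset; `iter_batches` is a hypothetical helper, not an xbatcher function.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for a precomputed training dataset: (example, variable, y, x).
training = rng.random((23, 2, 10, 10))

# Shuffle once up front so that batch serving is a plain slice.
order = rng.permutation(len(training))
shuffled = training[order]


def iter_batches(a, batch_size):
    """Yield consecutive slices along the first dimension (hypothetical helper)."""
    for start in range(0, len(a), batch_size):
        yield a[start:start + batch_size]


sizes = [len(b) for b in iter_batches(shuffled, batch_size=8)]
```

With 23 examples and a batch size of 8, the final batch is shorter than the others; whether to drop or pad it is left to the training loop.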

Not to open a can of worms, but I think adding a concept like "Training Dataset" to xbatcher to precompute costly predicate functions, reshaping/windowing and shuffling could help decouple the preprocessing from batch serving and be more performant. Then again, anyone can do this in advance and then use the BatchGenerator over the first dim in that dataset.

We still don't do this because even with all those ops out of the way, batch generator still only loads one batch into memory at a time unless it is already persisted (if this can be afforded). This could be fine if the dataset is persisted, but is limited. This is obviously out of scope, but relates to https://github.com/xarray-contrib/xbatcher/pull/161

cmdupuis3 commented 1 year ago

Hi @ljstrnadiii, thanks for elaborating on your workflow. Do you have something working now? I'm curious to see what you had to do to get this working in a parallel-performant way.

ljstrnadiii commented 1 year ago

@cmdupuis3

The biggest step in gains for my use case comes from computing the training dataset in advance where the first dim contains the dim to batch over.

something like

dset = xr.open_zarr(...).to_array() # (# variable, y, x)
# extract valid training examples with extract_training_examples
training_dataset = dset.map_blocks(extract_training_examples) # (# valid examples, #variable, ...)

# persist to zarr or to cluster to get ops out of the way
# if training dataset can not fit into distributed memory
training_dataset.to_zarr(...)
training_dataset = xr.open_zarr(new_persisted_training_dataset_zarr_path)
# or if it can fit into memory
training_dataset =  training_dataset.persist()

# then try various methods

Does that add any clarification?

cmdupuis3 commented 1 year ago

Yeah, that's a lot clearer, thank you!