xarray-contrib / xbatcher

Batch generation from xarray datasets
https://xbatcher.readthedocs.io
Apache License 2.0

Add ability to shuffle (and reshuffle) batches #170

Open arbennett opened 1 year ago

arbennett commented 1 year ago

Description of proposed changes

This relatively simple addition adds a `shuffle` flag and a `reshuffle` method to allow randomizing the ordering of batches. This can be useful for reducing the effect of autocorrelation between samples that are nearby in space/time. My implementation simply turns the `patch_selectors` into a list preemptively, which may not be optimal. But in my testing these are usually explicitly loaded at some point before the batch generator is iterated over anyway, so hopefully that's not a huge blocker.

Fixes # <--- I thought there used to be an issue around this, but I was unsuccessful in finding it. I'll update this if someone links the relevant issue.
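For readers skimming the PR, here is a minimal sketch of the idea. The class and attribute names (`ShuffledBatches`, `selectors`) are hypothetical stand-ins, not the actual xbatcher API: the point is just that shuffling a materialized list of selectors reorders lightweight pointers, never the data itself.

```python
import random


class ShuffledBatches:
    """Sketch of a batch generator with an optional shuffle flag.

    The selectors (e.g. index slices into a dataset) are turned into a
    list up front, so shuffling only reorders lightweight pointers.
    """

    def __init__(self, selectors, shuffle=False, seed=None):
        self.selectors = list(selectors)  # preemptively materialize
        self._rng = random.Random(seed)
        if shuffle:
            self.reshuffle()

    def reshuffle(self):
        # Randomize batch order in place; call between epochs.
        self._rng.shuffle(self.selectors)

    def __iter__(self):
        return iter(self.selectors)


gen = ShuffledBatches(range(10), shuffle=True, seed=0)
batch_order = list(gen)  # same selectors, randomized order
```

Calling `reshuffle()` between epochs gives a fresh random ordering without touching any array data.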

arbennett commented 1 year ago

I've been using this long enough on my own work that I think it's behaving as intended. If the code/approach is good I would be happy to add some tests.

weiji14 commented 1 year ago

Going on a bit of a tangent, but continuing from https://github.com/xarray-contrib/xbatcher/issues/176#issuecomment-1478695567, have you tried Shuffler in torchdata? My impression is that torchdata's Shuffler is slow, so it may well be worth adding a shuffle option to xbatcher, but I just wanted to hear about your experience so far.

arbennett commented 1 year ago

I'll give it a shot! Apparently I need to dig into the torchdata docs a bit more closely :sweat_smile:

arbennett commented 1 year ago

I tried using the built-in torchdata shuffler and, at least for subsampling from a large zarr file, it is extremely slow. The method implemented here is much faster and more lightweight.

weiji14 commented 1 year ago

> I tried using the built in torchdata shuffler and, at least for subsampling from a large zarr file, it is extremely slow. Using the method implemented here is much faster/lightweight.

Hmm yes, that's what I expected. You could change the buffer_size in Shuffler from 10000 to a smaller number, but this would still be slower for the reason below.

What you've done in this xbatcher PR is essentially shuffling of the indexes (lightweight on RAM). With torchdata's Shuffler, you would be shuffling the arrays (heavy on RAM), unless you find a way to get in between the slicing and batching part.
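To make the memory difference concrete, here is a rough back-of-the-envelope sketch (plain Python/NumPy, not xbatcher or torchdata code; the patch count and shape are made up for illustration):

```python
import random
import sys

import numpy as np

rng = random.Random(42)

# Pretend the dataset is 1000 patches of 256x256 float64, each sliced
# lazily from storage when a batch is actually needed.
n_patches = 1000
indices = list(range(n_patches))

# Index shuffling (this PR's approach): reorder lightweight pointers.
rng.shuffle(indices)
index_cost = sys.getsizeof(indices)  # a few KB, independent of patch size

# Array shuffling (what a buffer-based Shuffler effectively does): each
# buffered element carries a full patch, here 512 KB apiece.
buffer = [np.zeros((256, 256)) for _ in indices[:10]]  # 10-element buffer
buffer_cost = sum(a.nbytes for a in buffer)  # ~5 MB for just 10 patches
```

Even a modest shuffle buffer of materialized patches dwarfs the cost of shuffling the full index list, and the gap grows with patch size.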

This sort of ties in to my proposal at https://github.com/xarray-contrib/xbatcher/issues/172 on decomposing xbatcher into a Slicer and a Batcher. Instead of a slow datapipe like Slicer -> Batcher -> Shuffler, you would do Slicer -> Shuffler -> Batcher, which is faster (as you've shown in this PR) because it just randomizes the indexes/pointers to the arrays.
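The ordering above can be sketched as follows. This is a toy illustration of the proposed decomposition, with hypothetical function names (`slicer`, `shuffler`, `batcher`), not the API from #172: the key point is that the shuffle step sits between slicing and batching, so it only ever sees slice objects.

```python
import random


def slicer(n, size):
    # Yield index slices: cheap pointers into the underlying data.
    for start in range(0, n, size):
        yield slice(start, min(start + size, n))


def shuffler(slices, seed=None):
    # Shuffle the slice objects themselves, not the data they point to.
    slices = list(slices)
    random.Random(seed).shuffle(slices)
    return slices


def batcher(data, slices):
    # Only this stage touches the data, in the randomized order.
    for s in slices:
        yield data[s]


data = list(range(10))
batches = list(batcher(data, shuffler(slicer(len(data), 3), seed=0)))
```

Swapping `shuffler` and `batcher` would instead shuffle materialized batches, which is exactly the heavy-on-RAM case described above.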