Guitaricet opened 5 months ago
Thanks for the feature request @Guitaricet. This has been a recurrent feature request for a while, but there is no obvious path to support it just yet, as most kernel implementations are simply not easily parallelizable over RNG. vmap is a potential avenue, but the last time I tried it I quickly got blocked.
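For reference, the vmap avenue would conceptually look like the toy sketch below, which uses a hand-rolled brightness jitter rather than an actual torchvision kernel; the real kernels are where things currently break down:

```python
import torch
from torch.func import vmap

# Toy illustration of per-sample RNG via vmap (not a real torchvision kernel):
# randomness="different" lets each sample in the batch draw its own random factor.
def jitter_brightness(img):
    factor = 0.5 + torch.rand(())          # per-sample brightness factor in [0.5, 1.5)
    return (img * factor).clamp(0.0, 1.0)

batch = torch.rand(8, 3, 64, 64)
out = vmap(jitter_brightness, randomness="different")(batch)
```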
You'll see torchaug mentioned in https://github.com/pytorch/vision/issues/2929#issuecomment-1991650510, which supports what you need. Some kernels are natively implemented with support for per-sample RNG, but for most transforms the batch is instead split into sub-batches, each sub-batch is transformed, and the results are re-concatenated, so it's more of a per-subbatch RNG AFAIU. If you try it and it works for your use-case I'd love to hear about it, and perhaps consider something similar in torchvision (CC @juliendenize, who's the author of torchaug).
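To make the per-subbatch idea concrete, here's a rough sketch using plain torchvision v2 transforms (this only illustrates the splitting strategy, it is not Torchaug's actual API):

```python
import torch
from torchvision.transforms import v2

transform = v2.ColorJitter(brightness=0.5, contrast=0.5)

def transform_per_subbatch(batch: torch.Tensor, num_chunks: int = 8) -> torch.Tensor:
    # Each chunk goes through a separate forward call, so each chunk
    # gets its own randomly sampled parameters.
    chunks = batch.chunk(num_chunks)
    return torch.cat([transform(chunk) for chunk in chunks])

images = torch.rand(64, 3, 224, 224)
augmented = transform_per_subbatch(images)  # 8 sub-batches, 8 parameter draws
```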
Thank you, I will take a look!
Hi, thanks for the ping @NicolasHug. You're right about how Torchaug works: it offers the flexibility to do per-subbatch RNG when per-sample RNG is not possible in Torch.
I'd also like to hear from you @Guitaricet if you try it and whether it works for you. I try to maintain it as much as I can and keep up with Torchvision's advances, but I'd need more feedback from users to make it better. I'd love it if Torchvision could eventually support such features natively :smile:.
🚀 The feature
Randomly sample augmentation parameters for each image in the batch
Motivation, pitch
Torchvision augmentations have two big advantages over alternatives like Albumentations: they can be applied to a batch of images and they can run on GPU. This can make a big difference when the batch size is very large (e.g., 1000). However, right now the augmentation pipeline applies exactly the same transformation to every image in the batch, which makes augmentations less valuable and can destabilize training because the batch statistics change significantly.
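For example, here is a minimal sketch (assuming the torchvision v2 API) of the current behaviour and the slow workaround:

```python
import torch
from torchvision.transforms import v2

batch = torch.rand(16, 3, 32, 32)
rotate = v2.RandomRotation(degrees=30)

# Current behaviour: one angle is sampled per call and applied to all 16 images.
same_for_all = rotate(batch)

# Per-sample parameters today require a Python loop, which defeats the
# purpose of batched / GPU augmentation.
per_sample = torch.stack([rotate(img) for img in batch])
```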
Alternatives
Any recommendations for existing libraries that allow batched + GPU augmentations?
Additional context
In some of my training pipelines I observed that Albumentations is not fast enough and that we could benefit from switching back to torchvision, but then I noticed that all images in a batch are augmented in the same way.