pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

Per element sampling for Mixup/Cutmix #8191

Open davidpicard opened 6 months ago

davidpicard commented 6 months ago

🚀 The feature

Sample the random parameters of Mixup/Cutmix independently for each element of the batch, to avoid loss instability in large-batch setups.

Motivation, pitch

Hello!

My understanding is that Mixup/Cutmix draw their random parameters only once for the entire batch, so every element of the batch receives the same augmentation. For example, with CutMix, every image in the batch gets the same box, at exactly the same location.
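A minimal sketch of the batch-wide behaviour described above, using plain PyTorch tensors. It is loosely modelled on torchvision's CutMix (one Beta(alpha, alpha) draw, and each image paired with its neighbour via a roll along the batch dimension), but the function name cutmix_batchwise is hypothetical, label mixing is omitted, and this is not the library's actual code.

```python
import math

import torch


def cutmix_batchwise(images: torch.Tensor, alpha: float = 1.0):
    """Batch-wide CutMix sketch: ONE box shared by every image (hypothetical helper)."""
    _, _, H, W = images.shape
    dist = torch.distributions.Beta(torch.tensor(alpha), torch.tensor(alpha))
    lam = float(dist.sample(()))  # a single draw for the whole batch

    # One random centre and one box size, both derived from that single draw.
    cx = torch.randint(W, (1,)).item()
    cy = torch.randint(H, (1,)).item()
    r = 0.5 * math.sqrt(1.0 - lam)
    half_w, half_h = int(r * W), int(r * H)
    x1, x2 = max(cx - half_w, 0), min(cx + half_w, W)
    y1, y2 = max(cy - half_h, 0), min(cy + half_h, H)

    # The same slice is overwritten in every image, using pixels from its batch neighbour.
    mixed = images.clone()
    mixed[..., y1:y2, x1:x2] = images.roll(1, dims=0)[..., y1:y2, x1:x2]

    # Fraction of each image that is still original (the box may be clipped at the border).
    lam_adj = 1.0 - (x2 - x1) * (y2 - y1) / (W * H)
    return mixed, (x1, y1, x2, y2), lam_adj


torch.manual_seed(0)
batch = torch.rand(8, 3, 32, 32)
mixed, box, lam_adj = cutmix_batchwise(batch)
print(box, lam_adj)  # a single box, applied to all 8 images
```

Because lam, the box centre and the box size are each drawn once, the returned (x1, y1, x2, y2) is the same region in every image of the batch.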

This is particularly harmful with very large batches, because it makes training unstable. In one batch you get lucky: the parameters are easy and the loss is already low. In the next batch you get unlucky: the parameters are very hard and the loss is very high. If the transform parameters were sampled per element, the per-batch loss would average over many independent draws, which mitigates this issue for large batch sizes.
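The averaging argument can be checked with a toy simulation. Here the "difficulty" of an element is modelled as min(lam, 1 - lam), i.e. how evenly two images are mixed; this proxy is an assumption for illustration only, not a real training loss. The helper batch_mean_spread is hypothetical: it measures how much the per-batch mean difficulty fluctuates from batch to batch under shared versus per-element sampling.

```python
import torch


def batch_mean_spread(batch_size: int, num_batches: int = 1000,
                      alpha: float = 1.0, per_element: bool = False) -> float:
    """Std-dev, across batches, of a toy per-batch mean 'difficulty' (hypothetical helper)."""
    dist = torch.distributions.Beta(torch.tensor(alpha), torch.tensor(alpha))
    if per_element:
        # Independent draw for every element of every batch.
        lam = dist.sample((num_batches, batch_size))
    else:
        # One draw per batch, shared by all of its elements.
        lam = dist.sample((num_batches, 1)).expand(num_batches, batch_size)
    # Toy proxy (assumption): the more evenly two images are mixed, the "harder" the sample.
    difficulty = torch.minimum(lam, 1.0 - lam)
    return difficulty.mean(dim=1).std().item()


torch.manual_seed(0)
for bs in (8, 1024):
    shared = batch_mean_spread(bs)
    per_elem = batch_mean_spread(bs, per_element=True)
    print(f"batch_size={bs}: shared={shared:.4f} per_element={per_elem:.4f}")
```

With a single shared draw, the spread does not shrink as the batch grows, since every element carries the same value. With independent per-element draws, the per-batch mean averages batch_size i.i.d. values, so its standard deviation falls roughly as 1/sqrt(batch_size).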

This could be as simple as replacing the various calls to self._dist.sample(()) with self._dist.sample((batch_size,)). The bounding boxes would need a bit more work, but nothing particularly hard.
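For MixUp, sampling with shape (batch_size,) is close to the whole change: the resulting lam vector only needs reshaping to broadcast over channels and pixels. The bounding boxes are the more involved part. One way to handle them, sketched below under the assumption that every image in the batch has the same height and width, is to sample one box per element and build a per-element boolean mask instead of slicing. cutmix_per_element is a hypothetical helper, not a torchvision API; it pairs each image with its neighbour via a roll along the batch dimension.

```python
import torch
import torch.nn.functional as F


def cutmix_per_element(images: torch.Tensor, labels: torch.Tensor,
                       num_classes: int, alpha: float = 1.0):
    """Per-element CutMix sketch: every image gets its own lam and its own box."""
    B, _, H, W = images.shape
    dist = torch.distributions.Beta(torch.tensor(alpha), torch.tensor(alpha))
    lam = dist.sample((B,))  # shape (B,): one draw per element instead of sample(())

    # Independent box centre and box size for each element.
    cx = torch.randint(W, (B,))
    cy = torch.randint(H, (B,))
    r = 0.5 * torch.sqrt(1.0 - lam)
    half_w = (r * W).long()
    half_h = (r * H).long()
    x1, x2 = (cx - half_w).clamp(min=0), (cx + half_w).clamp(max=W)
    y1, y2 = (cy - half_h).clamp(min=0), (cy + half_h).clamp(max=H)

    # Boolean mask of shape (B, 1, H, W): True inside each element's own box.
    xs = torch.arange(W).view(1, 1, 1, W)
    ys = torch.arange(H).view(1, 1, H, 1)
    inside = ((xs >= x1.view(B, 1, 1, 1)) & (xs < x2.view(B, 1, 1, 1))
              & (ys >= y1.view(B, 1, 1, 1)) & (ys < y2.view(B, 1, 1, 1)))

    # Paste each image's box from its batch neighbour (pairing by a roll of one).
    mixed = torch.where(inside, images.roll(1, dims=0), images)

    # Per-element fraction of original pixels, correcting for boxes clipped at the border.
    lam_adj = 1.0 - ((x2 - x1) * (y2 - y1)).float() / (H * W)
    targets = F.one_hot(labels, num_classes).float()
    mixed_targets = (lam_adj.view(B, 1) * targets
                     + (1.0 - lam_adj.view(B, 1)) * targets.roll(1, dims=0))
    boxes = torch.stack([x1, y1, x2, y2], dim=1)  # (B, 4) as (x1, y1, x2, y2)
    return mixed, mixed_targets, boxes, lam_adj


torch.manual_seed(0)
images = torch.rand(8, 3, 32, 32)
labels = torch.randint(0, 10, (8,))
mixed, targets, boxes, lam_adj = cutmix_per_element(images, labels, num_classes=10)
print(boxes)  # one (x1, y1, x2, y2) row per image
```

The mask costs one (B, 1, H, W) boolean tensor, which keeps the operation vectorised instead of looping over the batch in Python.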

Best! David.

Alternatives

No response

Additional context

No response

NicolasHug commented 6 months ago

Thanks for the feature request @davidpicard. Yes, we should try to support a mode parameter (or similar), like in timm.
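For reference, timm's Mixup class takes a mode argument ('batch', 'pair' or 'elem') that selects whether mixing parameters are drawn once per batch, per pair of samples, or per element. Below is a sketch of what a similar switch could look like for MixUp, with only 'batch' and 'elem' implemented; the class name MixUpWithMode and its signature are hypothetical and not part of torchvision.

```python
import torch
import torch.nn.functional as F


class MixUpWithMode(torch.nn.Module):
    """Hypothetical MixUp with a timm-style `mode` switch; not a torchvision API."""

    def __init__(self, num_classes: int, alpha: float = 1.0, mode: str = "batch"):
        super().__init__()
        if mode not in ("batch", "elem"):
            raise ValueError(f"mode must be 'batch' or 'elem', got {mode!r}")
        self.num_classes = num_classes
        self.mode = mode
        self._dist = torch.distributions.Beta(torch.tensor(alpha), torch.tensor(alpha))

    def forward(self, images: torch.Tensor, labels: torch.Tensor):
        B = images.shape[0]
        if self.mode == "batch":
            # Batch-wide behaviour: one draw, repeated for every element.
            lam = self._dist.sample((1,)).repeat(B)
        else:
            # Proposed behaviour: an independent draw per element.
            lam = self._dist.sample((B,))
        lam_img = lam.reshape(B, 1, 1, 1)  # broadcast over C, H, W
        lam_tgt = lam.reshape(B, 1)        # broadcast over classes
        mixed = lam_img * images + (1.0 - lam_img) * images.roll(1, dims=0)
        targets = F.one_hot(labels, self.num_classes).float()
        mixed_targets = lam_tgt * targets + (1.0 - lam_tgt) * targets.roll(1, dims=0)
        return mixed, mixed_targets, lam


torch.manual_seed(0)
images = torch.rand(4, 3, 32, 32)
labels = torch.randint(0, 10, (4,))
for mode in ("batch", "elem"):
    _, _, lam = MixUpWithMode(num_classes=10, mode=mode)(images, labels)
    print(mode, lam)  # "batch": four identical values; "elem": four independent draws
```

Keeping "batch" as the default would preserve the current behaviour for existing users, while "elem" opts into the per-element sampling requested in this issue.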