Sample different random parameters of Mixup/Cutmix for different elements of the batch to avoid loss instability in large batch setups.
🚀 The feature
Sample different random parameters of Mixup/Cutmix for different elements of the batch to avoid loss instability in large batch setups.
Motivation, pitch
Hello!
My understanding is that MixUp/CutMix sample their random parameters only once for the entire batch, so every element of the batch receives the same augmentation. For example, when using CutMix, all images in the batch end up with the same box at exactly the same location.
This is particularly harmful with very large batches, as it leads to unstable training. In one batch you get lucky: the parameters are easy and the loss is already low. In the next batch you get unlucky: the parameters are very hard and the loss is very high. If the transform parameters were sampled per element, the loss would be averaged over many different parameter draws within each batch, which mitigates this issue when the batch size is large.
This could be as easy as replacing the various instances of self._dist.sample(()) with self._dist.sample((batch_size,)). The bounding boxes would be a bit more involved, but nothing especially hard.
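To make the idea concrete, here is a rough sketch of per-element sampling, not a patch against the actual torchvision code. The function names per_sample_mixup and per_sample_cutmix_masks are hypothetical. The sketch assumes the distribution is a Beta built from one-element tensors, and that each image is mixed with its neighbour obtained by rolling the batch by one position.

```python
import torch


def per_sample_mixup(images, labels_onehot, alpha=1.0):
    """Mix each image with its rolled neighbour, using one lambda per element.

    Hypothetical helper, not part of torchvision.
    images: (B, C, H, W) float tensor, labels_onehot: (B, num_classes) float tensor.
    """
    batch_size = images.shape[0]
    a = torch.tensor([float(alpha)])
    dist = torch.distributions.Beta(a, a)
    # sample((B,)) on a distribution with batch_shape (1,) returns shape (B, 1).
    lam = dist.sample((batch_size,)).squeeze(-1)  # (B,)
    # Reshape so lambda broadcasts over every non-batch dimension.
    lam_img = lam.view(batch_size, *([1] * (images.dim() - 1)))
    lam_lbl = lam.view(batch_size, 1)
    mixed_images = lam_img * images + (1.0 - lam_img) * images.roll(1, dims=0)
    mixed_labels = lam_lbl * labels_onehot + (1.0 - lam_lbl) * labels_onehot.roll(1, dims=0)
    return mixed_images, mixed_labels, lam


def per_sample_cutmix_masks(lam, height, width):
    """Build one box mask per element from per-element lambdas.

    Hypothetical helper. Returns a (B, H, W) bool mask that is True inside
    each box, and the lambda recomputed from the actual (clipped) box area.
    """
    batch_size = lam.shape[0]
    half_ratio = 0.5 * torch.sqrt(1.0 - lam)
    half_h = (half_ratio * height).long()
    half_w = (half_ratio * width).long()
    # One random box centre per element.
    cy = torch.randint(0, height, (batch_size,))
    cx = torch.randint(0, width, (batch_size,))
    y1 = (cy - half_h).clamp(min=0).view(-1, 1, 1)
    y2 = (cy + half_h).clamp(max=height).view(-1, 1, 1)
    x1 = (cx - half_w).clamp(min=0).view(-1, 1, 1)
    x2 = (cx + half_w).clamp(max=width).view(-1, 1, 1)
    ys = torch.arange(height).view(1, height, 1)
    xs = torch.arange(width).view(1, 1, width)
    mask = (ys >= y1) & (ys < y2) & (xs >= x1) & (xs < x2)  # (B, H, W)
    lam_adjusted = 1.0 - mask.float().mean(dim=(1, 2))
    return mask, lam_adjusted
```

For CutMix, the mask would then be applied with something like torch.where(mask.unsqueeze(1), images.roll(1, dims=0), images), and the labels mixed with lam_adjusted in the same way as in the MixUp helper. Building the boxes as a boolean mask avoids a Python loop over the batch, at the cost of materialising a (B, H, W) tensor.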
Best! David.
Alternatives
No response
Additional context
No response