VioletEqz / rE-ConvNeXt


Augmentation too strong #26

Open tokudayo opened 2 years ago

tokudayo commented 2 years ago

The current augmentation is so strong that the model barely learns anything (high loss, accuracy increasing by less than 1% per epoch and capping at around 10%, etc.). Once all augmentations are turned off, the model learns normally and quickly overfits the train set as expected. So we need to adjust the augmentation strength to suit CIFAR.

tokudayo commented 2 years ago

Either the augmentation is too strong or there is some bug in the augmentation process. At the moment I'm not too sure.

VioletEqz commented 2 years ago

Affirmative, the default augmentation is tuned for ImageNet, so it's no surprise it isn't performing well on CIFAR. Just like you said, we should dial it down a notch so that it suits CIFAR. I propose we drop some of the augmentations, mostly those that won't have a big impact at such a small image size.

tokudayo commented 2 years ago

The first and foremost problem is the AA policy imo. Applying RA(9, 0.5) to a 32x32 image distorts it too much. Consider switching to an appropriate policy for CIFAR instead.

Edit: The params used actually mean M=9, N=2 (implicitly), and M-std=0.5.

Edit 2:

> The first and foremost problem is AA policy imo

And it is probably not the main problem.
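For reference, timm encodes the RandAugment policy in strings like `rand-m9-mstd0.5`. Below is a minimal illustrative parser (not timm's actual code) showing how such a string decodes into the M/N/mstd values discussed above; the defaults here are assumptions for the sketch:

```python
# Sketch of how a timm-style RandAugment config string decodes.
# M = magnitude, N = number of ops applied, mstd = std-dev of the
# noise added to M. When N is omitted it defaults to 2, which is why
# 'rand-m9-mstd0.5' implicitly means M=9, N=2, M-std=0.5.
def parse_ra_config(config: str):
    parts = config.split('-')
    assert parts[0] == 'rand'
    cfg = {'m': 10, 'n': 2, 'mstd': 0.0}  # assumed defaults for this sketch
    for p in parts[1:]:
        if p.startswith('mstd'):          # check 'mstd' before bare 'm'
            cfg['mstd'] = float(p[4:])
        elif p.startswith('m'):
            cfg['m'] = int(p[1:])
        elif p.startswith('n'):
            cfg['n'] = int(p[1:])
    return cfg

print(parse_ra_config('rand-m9-mstd0.5'))     # {'m': 9, 'n': 2, 'mstd': 0.5}
print(parse_ra_config('rand-m2-n1-mstd0.5'))  # {'m': 2, 'n': 1, 'mstd': 0.5}
```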

VioletEqz commented 2 years ago

> And it is probably not the main problem.

That's really strange; I had mainly been suspecting the AA as well. Did you reach your conclusion after using no AA at all? I looked for augmentation settings for CIFAR in other SOTA work that we might replicate, but it wasn't fruitful at all.

tokudayo commented 2 years ago

Tried many things but no clear conclusion. The changes that seem to help are lowering the magnitude of AA and tuning mixup_alpha to 0.2 instead of 0.8. The difference was noticeable but not too significant.

The problem may lie somewhere else or this is just the reality (it starts slow but then converges eventually).
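Why a lower mixup_alpha is milder can be seen from the sampling itself: mixup draws its mixing weight from Beta(alpha, alpha), and smaller alpha concentrates the draws near 0 and 1, so most mixed images stay close to one of the two originals. A quick numpy sketch (illustrative, not repo code):

```python
import numpy as np

# mixup mixes two images as lam*x1 + (1-lam)*x2 with lam ~ Beta(a, a).
# Smaller alpha pushes lam toward 0 or 1, i.e. the mix barely changes
# one of the images -- a milder augmentation than alpha=0.8.
rng = np.random.default_rng(0)
for alpha in (0.2, 0.8):
    lam = rng.beta(alpha, alpha, size=100_000)
    # fraction of draws that keep more than 90% of one image
    mild = np.mean((lam < 0.1) | (lam > 0.9))
    print(f"alpha={alpha}: {mild:.0%} of draws keep >90% of one image")
```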

tokudayo commented 2 years ago

The common theme is that validation acc is always higher than training acc, which means the training examples are augmented to be much harder to classify. So I tried turning off all augmentations:

```python
transform = create_transform(
    input_size=32,
    is_training=is_training,
    auto_augment=None,        # AA off
    interpolation='bicubic',
    re_prob=0.0,              # random erasing off
    re_mode='pixel',
    re_count=1,
    mean=(0.5071, 0.4867, 0.4408),  # CIFAR-100 stats
    std=(0.2675, 0.2565, 0.2761)
)
if is_training:
    mixup = Mixup(
        mixup_alpha=0.2,
        cutmix_alpha=1.0,
        prob=0.0,             # mixup/cutmix off
        switch_prob=0.5,
        mode='batch',
        label_smoothing=0.1,
        num_classes=100
    )
...
```

But the same problem persists. Only when I explicitly turn off training mode, `transform = CIFAR100_augmentation(is_training=False)`, does the training pattern occur normally, i.e. it quickly overfits the training set, with train acc always > val acc.

So it means that even when I thought I had turned off all augmentations (first block of code above), the images were still being augmented in some way. Very confused, send help.
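One generic way to confirm this (a sketch, not tied to timm) is to apply the supposedly augmentation-free pipeline twice to the same image: a deterministic pipeline must return identical outputs, so any difference means a stochastic transform is still active. The helper name below is illustrative:

```python
import numpy as np

# If a pipeline is truly deterministic / augmentation-free, applying it
# repeatedly to the same image must give identical outputs. Any mismatch
# means some stochastic transform (crop, flip, erasing...) is still on.
def is_stochastic(transform, image, trials=8):
    ref = np.asarray(transform(image))
    return any(not np.array_equal(np.asarray(transform(image)), ref)
               for _ in range(trials))

rng = np.random.default_rng(42)
img = rng.random((32, 32, 3))

identity = lambda x: x                                            # "aug off"
random_flip = lambda x: x[:, ::-1] if rng.random() < 0.5 else x   # hidden aug

print(is_stochastic(identity, img))     # False
print(is_stochastic(random_flip, img))  # True (unless every draw lands on
                                        # the same side, vanishingly unlikely)
```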

VioletEqz commented 2 years ago

At this point I think there is something wrong with this implementation of augmentation using timm. I suggest we use torchvision.transforms for now, at least for the basic transformations, to see whether the problem persists. Also, looking at their code, it seems color_jitter, hflip, and vflip have default values that aren't 0. Maybe we should check that as well.

```python
def transforms_imagenet_train(
        img_size=224,
        scale=None,
        ratio=None,
        hflip=0.5,
        vflip=0.,
        color_jitter=0.4,
        auto_augment=None,
        interpolation='random',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        re_prob=0.,
        re_mode='const',
        re_count=1,
        re_num_splits=0,
        separate=False,
):
```

Their code is here.

tokudayo commented 2 years ago

Ah yes. Initially I removed color_jitter because it is automatically disabled when AA is on, but then I disabled AA, so the default color_jitter=0.4 kicked back in.

VioletEqz commented 2 years ago

Also it seems like these lines from the function:

```python
scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
ratio = tuple(ratio or (3./4., 4./3.))  # default imagenet ratio range
```

which default to ImageNet's values, might be the culprits as well, since we didn't touch them at all in our implementation. The fact that when is_training=False the code switches to a different function that doesn't use these two makes them even more suspicious.
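Some quick arithmetic shows why that scale range hurts CIFAR so much: RandomResizedCrop with scale=(0.08, 1.0) may keep as little as 8% of the image area, which on a 32x32 image is roughly a 9x9 patch blown back up to full size. A small sketch (helper name is illustrative):

```python
import math

# Smallest crop edge (square aspect) allowed by RandomResizedCrop's
# minimum scale: crop_area = img_area * min_scale, edge = sqrt(area).
def min_crop_edge(img_size, min_scale):
    return math.sqrt(img_size * img_size * min_scale)

print(round(min_crop_edge(32, 0.08)))   # 9  -> a ~9x9 patch on CIFAR
print(round(min_crop_edge(224, 0.08)))  # 63 -> a ~63x63 patch on ImageNet
```

A 9-pixel crop of a 32x32 CIFAR image leaves almost no recognizable content, whereas a 63-pixel crop of a 224x224 ImageNet image usually still contains the object.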

tokudayo commented 2 years ago

Wow, yeah. There might be other problems as well, because iirc create_transform is a wrapper for imagenet_preprocess/transform or something.

tokudayo commented 2 years ago

For the image transform, filling in all the params provides better control:

```python
transform = create_transform(
    input_size=32,
    is_training=is_training,
    scale=(1., 1.),          # disable random resized crop scaling
    ratio=(1., 1.),          # keep aspect ratio fixed
    hflip=0.5,
    vflip=0.,
    color_jitter=0.,         # off (would otherwise default to 0.4)
    auto_augment='rand-m2-n1-mstd0.5',  # much milder RA policy
    interpolation='bicubic',
    re_prob=0.5,
    re_mode='pixel',
    re_count=1,
    mean=(0.5071, 0.4867, 0.4408),
    std=(0.2675, 0.2565, 0.2761)
)
```

Subsequently turning off each augmentation (also mixup+cutmix below) gives faster overfitting on the train set.

I recommend visually inspecting some batches of samples and adjusting the augmentations accordingly. For small (32x32) images, stacking too many augs is bad.
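For the visual inspection, a small numpy helper can tile a batch into one image that is easy to save or plot; the function name and layout here are illustrative, not part of the repo:

```python
import numpy as np

# Tile a batch of images (N, H, W, C) into a single (rows*H, cols*W, C)
# grid for eyeballing what the augmentation pipeline actually produces.
def make_grid(batch, cols=8):
    n, h, w, c = batch.shape
    rows = -(-n // cols)  # ceiling division
    grid = np.zeros((rows * h, cols * w, c), dtype=batch.dtype)
    for i, img in enumerate(batch):
        r, col = divmod(i, cols)
        grid[r*h:(r+1)*h, col*w:(col+1)*w] = img
    return grid

batch = np.random.rand(16, 32, 32, 3)   # e.g. 16 augmented CIFAR samples
print(make_grid(batch).shape)           # (64, 256, 3)
```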