SHI-Labs / Compact-Transformers

Escaping the Big Data Paradigm with Compact Transformers, 2021 (Train your Vision Transformers in 30 mins on CIFAR-10 with a single GPU!)
https://arxiv.org/abs/2104.05704
Apache License 2.0

Thank you for your nice work | Question on Flowers dataset #65

Closed JosephKJ closed 2 years ago

JosephKJ commented 2 years ago

Hi @alihassanijr,

Many thanks for your super interesting work, and sharing the elegant code with the community.

I am able to replicate your CIFAR-10 and CIFAR-100 results perfectly, but there is a large gap when it comes to the Flowers dataset.

After running the following command:

python train.py -c configs/datasets/flowers102.yml --model cct_7_7x2_224_sine ./data/flowers102 --log-wandb

I am able to get only 62% accuracy. Please find the wandb report here. I am attaching the logs too:
output.log

The only change that I made to the code was to use the PyTorch dataloaders:

from torchvision.datasets import Flowers102
dataset_train = Flowers102(root=args.data_dir, split="train", download=True)
dataset_eval = Flowers102(root=args.data_dir, split="test", download=True)

I suspect this is just a minor configuration issue for the Flowers dataset, since I am able to replicate the results on CIFAR-10 and CIFAR-100.

Thanks again, and it would be very kind of you if you could help me.

Thanks, Joseph

JosephKJ commented 2 years ago

Small update: I just ran the same experiment with learnable positional embeddings; the results are similar to what I observed with the sine-based embeddings.

alihassanijr commented 2 years ago

Hello and thank you for your interest.

So, a few notes:

  1. At the time we downloaded Flowers-102 as an image dataset (split into class subdirectories) and loaded it with the ImageFolder dataset class. The latter is unlikely to be the reason, but the former could be the issue here. I'm not sure why that would be the case, but I'll look into it and check back here with what I can find.

  2. The other thing I'm noticing is that you're using the test set to evaluate, while we used the validation split.

  3. One more item that the wandb report doesn't let me see is world size: how many GPUs are you using to train?

  4. One last thing: are you keeping the transforms the same? Because if you're not, the transforms we specified, especially the mean and std for normalization, will not be applied.

Any of these could be the reason you're running into the issue. We'll look into them, but I thought I'd share them in case you'd like to try changing things in the meantime.
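
For reference, a minimal sketch of how a class-subdirectory copy of Flowers-102 could be loaded with ImageFolder; the directory layout and the normalization statistics below are placeholders for illustration, not the values from flowers102.yml:

from torchvision import transforms
from torchvision.datasets import ImageFolder

# Hypothetical layout: ./data/flowers102_kaggle/{train,valid}/<class_name>/*.jpg
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # placeholder stats

train_tf = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
eval_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

dataset_train = ImageFolder(root="./data/flowers102_kaggle/train", transform=train_tf)
dataset_eval = ImageFolder(root="./data/flowers102_kaggle/valid", transform=eval_tf)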

JosephKJ commented 2 years ago

Hi @alihassanijr,

Many thanks for your response.

  1. At the time we downloaded Flowers-102 as an image dataset (split into class subdirectories) and loaded it with the ImageFolder dataset class. The latter is unlikely to be the reason, but the former could be the issue here. I'm not sure why that would be the case, but I'll look into it and check back here with what I can find.

    I will check the data-loading and keep it similar to yours.

  2. The other thing I'm noticing is that you're using the test set to evaluate, while we used the validation split.

    I just tried evaluating using the val split. Here is the result: [W&B chart, 13_09_2022 22_05_41]

  3. One more item that the wandb report doesn't let me see is world size: how many GPUs are you using to train?

    I use one GPU. I was able to reproduce CIFAR-10 and CIFAR-100 using the same GPU.

  4. One last thing: are you keeping the transforms the same? Because if you're not, the transforms we specified, especially the mean and std for normalization, will not be applied.

    Yes, I didn't touch that part. I am using the default values provided in flowers102.yml.

Thanks again for your response. I will update here in case I succeed in replicating the Flowers result. Thank you!

alihassanijr commented 2 years ago

Okay, so the number of GPUs might be the reason. We only trained the CIFAR-10/100 and MNIST datasets on single GPUs; the rest, including Flowers, were trained on 8 GPUs with distributed training. I would recommend you give that a shot. Note that you don't "have to" use 8 GPUs, as long as you maintain the total batch size of 768 = 96 x 8.
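
For example, an 8-GPU run could look roughly like the command below. This is only a sketch assuming the script keeps the timm-style distributed launcher and per-GPU batch-size flag it is derived from; the exact flags may differ in this repo's train.py:

python -m torch.distributed.launch --nproc_per_node=8 train.py -c configs/datasets/flowers102.yml --model cct_7_7x2_224_sine -b 96 ./data/flowers102 --log-wandb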

JosephKJ commented 2 years ago

Thanks again for your response. I will try using a batch size of 768 and let you know. Thank you!

JosephKJ commented 2 years ago

Hi @alihassanijr, I did a few more experiments:

1) Used a batch size of 768, as you suggested.
2) Ran for 8x more epochs with an 8x lower learning rate, following this paper.

Here is the result: [W&B chart, 14_09_2022 10_59_10]

alihassanijr commented 2 years ago

Hi, could you try it with the same learning rate and number of epochs? The training script we use doesn't factor batch size into the epoch count or LR, so I'd suggest just increasing the batch size until the total reaches 768.

JosephKJ commented 2 years ago

The green line above uses the same LR and epochs, Ali; the red line just increases the batch size to 768.

Thanks again for your reply and willingness to help.

alihassanijr commented 2 years ago

Hi,

So I think I have the issue narrowed down. I'm not sure why this discrepancy exists, but we seem to be working with different splits of Flowers-102. I just checked, and the torchvision dataset has 1020 training samples, 1020 validation samples, and 6149 test samples; the test set is also labeled. However, the copy we have was downloaded from Kaggle, because at the time Flowers-102 was not part of torchvision (it was added in January 2022). Our copy has a different split: 6552 training samples, 818 validation samples, and 819 unlabeled test samples.

This is what’s causing the issue here.
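
As a quick sanity check, the torchvision split sizes can be confirmed with something like the following minimal sketch (the root path is just an example):

from torchvision.datasets import Flowers102

root = "./data/flowers102"  # example path
for split in ("train", "val", "test"):
    ds = Flowers102(root=root, split=split, download=True)
    # Expected for the torchvision/original split: 1020 / 1020 / 6149
    print(split, len(ds))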

Upon looking further, we found that there are a number of papers on the leaderboard that train on the Kaggle split, as opposed to the original (not many shared where they downloaded the dataset, or even checkpoints). We also found that some papers merge the training and validation splits, treat the result as a training set with 2040 samples, and use the test set as a validation/test set.
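
For illustration only, merging the two torchvision splits in the way some of those papers appear to do could look like the sketch below (this is not what we did for our reported numbers):

from torch.utils.data import ConcatDataset
from torchvision.datasets import Flowers102

root = "./data/flowers102"  # example path
# Merge train + val (1020 + 1020 = 2040 samples) into a single training set...
train_val = ConcatDataset([
    Flowers102(root=root, split="train", download=True),
    Flowers102(root=root, split="val", download=True),
])
# ...and use the labeled 6149-sample test split for evaluation.
eval_set = Flowers102(root=root, split="test", download=True)
print(len(train_val), len(eval_set))  # 2040, 6149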

Meanwhile, we tried training CCT, ViT, and CaiT variants from scratch on the torchvision split, and found that CCT leads the other models of similar size by about 20% in accuracy. You can check those runs in this WandB report.

JosephKJ commented 2 years ago

Thank you very much @alihassanijr for taking the time to identify what's going wrong. I really appreciate your help!