lucidrains / vit-pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch
MIT License

Trained on a small dataset with pre-trained weights, don't get good results. #34

Open JamesQFreeman opened 3 years ago

JamesQFreeman commented 3 years ago
import timm
import torch.nn as nn

# load the ImageNet-pretrained ViT from timm and swap in a 2-class head
pretrained_v = timm.create_model('vit_base_patch16_224', pretrained=True)
pretrained_v.head = nn.Linear(768, 2)  # 768 = embedding dim of vit_base

I tried the Kaggle Cats vs Dogs dataset for binary classification. It didn't work; the output is all cat or all dog.

Any idea how to make it work on a small dataset? (fewer than 10,000 images, or even fewer than 1,000)

PS: Adam, lr = 1e-2
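
For reference, a minimal sketch of the setup described here; the data path, transforms, and training loop are illustrative assumptions, not the author's actual script:

import timm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# standard ImageNet-style preprocessing for a 224x224 ViT
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# 'data/train' is a placeholder for the cats-vs-dogs image folder
train_set = datasets.ImageFolder('data/train', transform=transform)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)

model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.head = nn.Linear(768, 2)  # binary head: cat vs dog

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)  # the LR used here
criterion = nn.CrossEntropyLoss()

model.train()
for images, labels in train_loader:
    images, labels = images.to(device), labels.to(device)
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()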

JamesQFreeman commented 3 years ago

update: "Didn't work, output is all cat or all dog." was trained on only 1k images. Now I train the ViT on whole dataset which have 20k images and it kind of works.

0.73 acc @ 10 epochs, 45 min on an RTX Titan (an identical run took 100 min on a Titan X). Not very impressive compared with a CNN so far.

JamesQFreeman commented 3 years ago
from vit_pytorch import ViT

v = ViT(
    image_size = 224,
    patch_size = 32,
    num_classes = 2,   # cat vs dog
    dim = 512,
    depth = 4,
    heads = 8,
    mlp_dim = 512,
    dropout = 0.1,
    emb_dropout = 0.1
)

BTW, using the non-pre-trained model above, I got around 0.8 acc in the same amount of time.
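
As a quick sanity check of the model above (the random tensor is just a stand-in for a real image batch):

import torch

img = torch.randn(1, 3, 224, 224)  # dummy batch of one 224x224 RGB image
preds = v(img)                     # class logits, shape (1, 2)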

JamesQFreeman commented 3 years ago

The pretrained model had a peak acc of 0.796 after 100 epochs of training. On this dataset, resnet50 can reach 0.90 without any modification. Is there any tuning trick I can use?

lucidrains commented 3 years ago

Hi James, attention excels in the regime of big data, as shown in the paper. However, I am curious why fine-tuning did not work. Are you using Ross' model? Perhaps submit an issue at his repository?

JamesQFreeman commented 3 years ago

Yes, Ross' model (which is uploaded to timm) is used. Does a pretrained model always work on a small dataset?

lucidrains commented 3 years ago

@JamesQFreeman I think fine-tuning from a pretrained model should generally work well. Maybe you should raise the issue with him.

lucidrains commented 3 years ago

@JamesQFreeman ohh... well, I think I spot the error: your learning rate of 1e-2 is way too high. Try Karpathy's favorite LR, 3e-4.
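
In code, the suggested change is just the optimizer's lr argument (a sketch, assuming the pretrained_v model from the first comment):

import torch

# 3e-4 instead of 1e-2
optimizer = torch.optim.Adam(pretrained_v.parameters(), lr=3e-4)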

JamesQFreeman commented 3 years ago

> @JamesQFreeman ohh... well, I think I spot the error: your learning rate of 1e-2 is way too high. Try Karpathy's favorite LR, 3e-4.

Thanks! I'll give it a try.

Lin-Zhipeng commented 3 years ago

I also tried the experiment: lr = 3e-5, batch_size = 8.

Epoch : 1 - loss : 0.0648 - acc: 0.9752 - val_loss : 0.0592 - val_acc: 0.9782
Epoch : 2 - loss : 0.0561 - acc: 0.9773 - val_loss : 0.0531 - val_acc: 0.9790
Epoch : 3 - loss : 0.0513 - acc: 0.9795 - val_loss : 0.0677 - val_acc: 0.9750
Epoch : 4 - loss : 0.0473 - acc: 0.9809 - val_loss : 0.0479 - val_acc: 0.9804
Epoch : 5 - loss : 0.0473 - acc: 0.9800 - val_loss : 0.0567 - val_acc: 0.9780
Epoch : 6 - loss : 0.0466 - acc: 0.9806 - val_loss : 0.0526 - val_acc: 0.9780
Epoch : 7 - loss : 0.0413 - acc: 0.9826 - val_loss : 0.0615 - val_acc: 0.9774
Epoch : 8 - loss : 0.0430 - acc: 0.9833 - val_loss : 0.0619 - val_acc: 0.9746
Epoch : 9 - loss : 0.0411 - acc: 0.9832 - val_loss : 0.0616 - val_acc: 0.9784
Epoch : 10 - loss : 0.0450 - acc: 0.9824 - val_loss : 0.0483 - val_acc: 0.9830
Epoch : 11 - loss : 0.0374 - acc: 0.9842 - val_loss : 0.0598 - val_acc: 0.9746
Epoch : 12 - loss : 0.0393 - acc: 0.9844 - val_loss : 0.1202 - val_acc: 0.9602
Epoch : 13 - loss : 0.0418 - acc: 0.9830 - val_loss : 0.0547 - val_acc: 0.9806
Epoch : 14 - loss : 0.0380 - acc: 0.9846 - val_loss : 0.0578 - val_acc: 0.9760
Epoch : 15 - loss : 0.0376 - acc: 0.9852 - val_loss : 0.0557 - val_acc: 0.9786
Epoch : 16 - loss : 0.0372 - acc: 0.9845 - val_loss : 0.0595 - val_acc: 0.9790
Epoch : 17 - loss : 0.0379 - acc: 0.9846 - val_loss : 0.0560 - val_acc: 0.9802
Epoch : 18 - loss : 0.0353 - acc: 0.9859 - val_loss : 0.0561 - val_acc: 0.9818
Epoch : 19 - loss : 0.0361 - acc: 0.9860 - val_loss : 0.0482 - val_acc: 0.9810
Epoch : 20 - loss : 0.0349 - acc: 0.9864 - val_loss : 0.0547 - val_acc: 0.9792

Emmmmm, not bad. I think it will be better if I can tune the parameters.

XA-kirino commented 3 years ago

A lower learning rate and SGD are better for fine-tuning; don't use Adam.
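
A sketch of that suggestion (the exact lr and momentum values here are illustrative, not tuned):

import torch

# SGD with momentum and a small LR for fine-tuning, instead of Adam
optimizer = torch.optim.SGD(pretrained_v.parameters(), lr=1e-3, momentum=0.9)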

myt889 commented 3 years ago

> I also tried the experiment: lr = 3e-5, batch_size = 8.
> [20-epoch training log quoted from the comment above]
> Emmmmm, not bad. I think it will be better if I can tune the parameters.

Hi, you really got good accuracy. I wonder how you did that? I mean, which dataset and which pre-trained weights did you use?

abhigoku10 commented 3 years ago

@JamesQFreeman @myt889 @Lin-Zhipeng can you share your train.py for loading from a pretrained model? It would be very helpful. Did you try loading the variants of the models?