jeonsworld / ViT-pytorch

PyTorch reimplementation of the Vision Transformer (An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale)
MIT License

Training accuracy much lower than validation accuracy #28

Closed ganeshkumarashok closed 3 years ago

ganeshkumarashok commented 3 years ago

Thanks for creating and uploading this easily usable repo!

In addition to the validation accuracy over the entire validation set that is printed by default, we also printed the model's training accuracy and observed that it is 6-8% lower than the validation accuracy. Is that reasonable, given that we usually expect training accuracy to be higher than validation accuracy?

This was for a ViT-B_16 model, pretrained on ImageNet-21k, during the fine-tuning phase on CIFAR-10. To get the training accuracies, we used model(x)[0] to get the logits, loss, and predictions for each batch, and used AverageMeter() to track the running accuracy. Additionally, to get an accurate training accuracy over the entire training set, we passed the training set to a copy of valid() (with changes only to the print statements). Both the running training accuracy and the training accuracy over the entire training set were lower than the validation accuracy by 6-8%. For instance, after 10k steps, the training accuracy was 92.9% (over the entire training set) while the validation accuracy was 98.7%. We used mostly the default hyperparameters (besides batch size and fp16) and did not make other changes to the code.
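For reference, the full-training-set measurement described above could look like the following minimal sketch. It assumes the model's forward returns `(logits, attn_weights)` when no labels are passed (as in the repo's modeling.py); `eval_train_loader` is a hypothetical DataLoader over the training set.

```python
import torch
import numpy as np

def train_accuracy(model, eval_train_loader, device):
    """Accuracy over the entire training set, analogous to valid() in train.py."""
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for x, y in eval_train_loader:
            x = x.to(device)
            logits = model(x)[0]                  # forward returns (logits, attn_weights)
            preds = torch.argmax(logits, dim=-1)
            all_preds.append(preds.cpu().numpy())
            all_labels.append(y.numpy())
    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    return (all_preds == all_labels).mean()      # same idea as simple_accuracy()
```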

Please let us know if this lower training accuracy is expected or if its calculation is incorrect. Thanks in advance.

ganeshkumarashok commented 3 years ago

The discrepancy is reasonable: the training dataset is loaded with additional augmentations that are not present in the validation transforms, so the training batches are harder to classify than the validation batches. Using the same transforms for both the training and validation sets restores the expected pattern of training accuracy being higher than validation accuracy.
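For concreteness, here is a sketch of the kind of transform mismatch involved, using torchvision transforms along the lines of those in the repo's utils/data_utils.py (the exact crop scale, image size, and normalization values here are assumptions):

```python
from torchvision import transforms, datasets

img_size = 224  # hypothetical fine-tuning resolution

# Training pipeline: random crops make each batch harder than the
# deterministic resize used at validation time.
transform_train = transforms.Compose([
    transforms.RandomResizedCrop((img_size, img_size), scale=(0.05, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Validation pipeline: deterministic resize only.
transform_test = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# To measure training accuracy on an equal footing with validation,
# evaluate the training images with the deterministic test transform:
train_eval_set = datasets.CIFAR10(root="./data", train=True, download=True,
                                  transform=transform_test)
```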