google-research / vision_transformer

Hyper-parameters of ViT-B/16 training from scratch #2

Closed liuyuyuil closed 4 years ago

liuyuyuil commented 4 years ago

Thanks for sharing your code. Can you provide the hyper-parameters (e.g. learning rate, weight decay, optimizer type, training epochs) for training ViT-B/16 from scratch on the ImageNet dataset? Many thanks.

andsteing commented 4 years ago

Note that for the published checkpoints we pretrained on imagenet21k (see README), using ~~102.4M~~ 12.4M examples for training.

As for the hyperparameters:

batch_size=4096
lr.base=1e-3
lr.decay_type=linear
lr.linear_end=1e-5
lr.warmup_steps=10_000
dropout_rate=0.1
num_epochs=90
weight_decay=0.03
optimizer=Adam
representation_size=768

We used the same cropping code but an image size of 224 (thus a 14x14 patch grid).

The model was exactly the same, other than the additional penultimate layer with dimensionality representation_size. The final classification layer's bias weights were initialized to -10.
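
A minimal optax sketch of the learning-rate schedule and optimizer above, assuming optax.adamw as a stand-in for "Adam + weight_decay" and a step count derived from 12.4M examples, 90 epochs and batch 4096; the repo's actual training code lives in vit_jax and may differ:

import optax

# Assumed setup (from the comment above): 12.4M imagenet21k examples, 90 epochs, batch 4096.
TOTAL_STEPS = 90 * 12_400_000 // 4096   # ~272k steps (derived, not quoted from the repo)
WARMUP_STEPS = 10_000                   # lr.warmup_steps

# lr.decay_type=linear: warm up from 0 to lr.base=1e-3, then decay linearly to lr.linear_end=1e-5.
schedule = optax.join_schedules(
    schedules=[
        optax.linear_schedule(init_value=0.0, end_value=1e-3,
                              transition_steps=WARMUP_STEPS),
        optax.linear_schedule(init_value=1e-3, end_value=1e-5,
                              transition_steps=TOTAL_STEPS - WARMUP_STEPS),
    ],
    boundaries=[WARMUP_STEPS],
)

# optimizer=Adam with weight_decay=0.03, modeled here as AdamW (an assumption, not the repo's exact code).
tx = optax.adamw(learning_rate=schedule, weight_decay=0.03)

# Image size 224 with 16x16 patches gives the 224 / 16 = 14x14 token grid mentioned above.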

liuyuyuil commented 4 years ago

Thanks for the reply!

liuyuyuil commented 4 years ago

By the way, what's the top1 accuracy of ViT-B/16 trained from scratch on ImageNet with an image size of 224? There is a statement in the paper:

'With self-supervised pre-training, our smaller ViT-B/16 model achieves 79.9% accuracy on ImageNet, a significant improvement of 2% to training from scratch'

Is it 77.9%? Thanks.

andsteing commented 4 years ago

The 79.9% refers to the self-supervised pretraining - see B.1.2. in the appendix for details. The B/16 model pre-trained and fine-tuned on imagenet2012 achieves 77.9% (see table 5 in the appendix).

andsteing commented 3 years ago

That was a typo (now corrected) - it should have said 12.4M examples. See this comment for more details.

cissoidx commented 3 years ago

> The 79.9% refers to the self-supervised pretraining - see B.1.2. in the appendix for details. The B/16 model pre-trained and fine-tuned on imagenet2012 achieves 77.9% (see table 5 in the appendix).

What is the top1 acc of pretraining (without finetuning) on imagenet2012?

andsteing commented 3 years ago

The top1 acc (evaluated on a 50k holdout from the training set) at the end of pre-training for the original ViT paper models was as follows:

| name | val_acc |
| --- | --- |
| ViT-B/32 i1k | 69.19% |
| ViT-B/16 i1k | 74.79% |
| ViT-L/32 i1k | 66.90% |
| ViT-L/16 i1k | 72.59% |

Note that we have much more detail about pre-training from scratch in the paper "How to train your ViT? ..."; check out the database in our Colab:

https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax_augreg.ipynb

For example, to show you the final pre-training top1 accuracy of a variety of models and pre-training settings:

import plotly.express as px

# df is the results dataframe loaded earlier in the Colab
# (an index of the published checkpoints and their metrics).
px.scatter(
    # one point per checkpoint: keep i1k pre-training runs with final_val > 0.4
    df.drop_duplicates('filename').query('ds=="i1k" and final_val>0.4'),
    y='final_val',      # final pre-training validation (top1) accuracy
    x='aug',            # data augmentation setting
    color='wd',         # weight decay
    symbol='do',        # dropout
    facet_col='name',   # one panel per model (B/16, B/32, ...)
    facet_col_wrap=4,
)

cissoidx commented 3 years ago

> Note that for the published checkpoints we pretrained on imagenet21k (see README), using ~~102.4M~~ 12.4M examples for training.
>
> As for the hyperparameters:
>
> batch_size=4096
> lr.base=1e-3
> lr.decay_type=linear
> lr.linear_end=1e-5
> lr.warmup_steps=10_000
> dropout_rate=0.1
> num_epochs=90
> weight_decay=0.03
> optimizer=Adam
> representation_size=768
>
> We used the same cropping code but an image size of 224 (thus a 14x14 patch grid).
>
> The model was exactly the same, other than the additional penultimate layer with dimensionality representation_size. The final classification layer's bias weights were initialized to -10.

Hi, do you use the same hyperparameters for pre-training on imagenet1k?

andsteing commented 3 years ago

For Imagenet1k pre-training we used the following hparams, which differ from those used for pre-training on Imagenet21k:

grad_clip_norm=1.0
lr.base=3e-3
lr.decay_type=cosine
dropout_rate=0.1  # B/16, B/32
dropout_rate=0.2  # L/16, L/32
num_epochs=300  # B/16, B/32
weight_decay=0.3

(note that training L/16 and L/32 on i1k can be tricky; you might want to reduce the number of epochs or augment the data as described in the How to train your ViT? paper)
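
A matching sketch of the imagenet1k overrides above, on the same assumptions as the earlier snippet (optax.adamw standing in for Adam + weight_decay, warmup of 10_000 steps carried over from the imagenet21k config, and a step count derived from ~1.28M imagenet1k training examples); again, this is a reconstruction, not the repo's actual code:

import optax

# Assumed setup: ~1.28M imagenet1k examples, 300 epochs (B/16, B/32), batch 4096.
TOTAL_STEPS = 300 * 1_281_167 // 4096   # ~94k steps (derived, not quoted from the repo)
WARMUP_STEPS = 10_000                   # assumed unchanged from the imagenet21k config

# lr.decay_type=cosine with lr.base=3e-3.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=3e-3,
    warmup_steps=WARMUP_STEPS,
    decay_steps=TOTAL_STEPS,
    end_value=0.0,
)

# grad_clip_norm=1.0 and weight_decay=0.3 (AdamW again an assumption).
tx = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=schedule, weight_decay=0.3),
)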