zihangJiang / TokenLabeling

PyTorch implementation of "All Tokens Matter: Token Labeling for Training Better Vision Transformers"
Apache License 2.0

Pretrained weights for LV-ViT-T #23

Closed · marc345 closed this 2 years ago

marc345 commented 2 years ago

Hi,

Thanks for sharing your work. Could you also provide the pre-trained weights for the LV-ViT-T model variant, the one that achieves 79.1% top-1 accuracy, as reported in Table 1 of your paper?

All the best, Marc

zihangJiang commented 2 years ago

Hi Marc,

Thanks for your interest. You can simply train LV-ViT-T with

CUDA_VISIBLE_DEVICES=4,5,6,7 ./distributed_train.sh 4 /path/to/imagenet --model lvvit_t -b 256 --apex-amp --img-size 224 --drop-path 0.1 --token-label --token-label-data /path/to/label_top5_train_nfnet --token-label-size 14 --model-ema

We also provide the pre-trained LV-ViT-T checkpoint here for your reference.
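
For reference, a minimal loading sketch (not part of the original comment): it assumes the repo's tlt package registers lvvit_t as a timm entry point and that the checkpoint is either a plain state_dict or a dict wrapping one under a state_dict key; the paths are placeholders.

import torch
from timm.models import create_model
import tlt.models  # assumed to register the lvvit_* entry points with timm

# Build the model and load the released weights (paths are placeholders).
model = create_model('lvvit_t', pretrained=False)
checkpoint = torch.load('/path/to/lvvit_t.pth', map_location='cpu')
state_dict = checkpoint.get('state_dict', checkpoint)  # unwrap if nested (assumption)

# strict=False so any mismatched keys are reported instead of raising.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print('missing keys:', missing)
print('unexpected keys:', unexpected)
model.eval()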

marc345 commented 2 years ago

@zihangJiang Thanks for the information. The LV-ViT-T model initialized with the provided checkpoint only achieves 57% top-1 accuracy. Am I missing something?

zihangJiang commented 2 years ago

You can run

python3 validate.py /path/to/imagenet/val  --model lvvit_t --checkpoint /path/to/lvvit_t.pth --no-test-pool --amp --img-size 224 -b 64

The log should be

Test: [   0/782]  Time: 0.972s (0.972s,   65.83/s)  Loss:  0.3701 (0.3701)  Acc@1:  92.188 ( 92.188)  Acc@5:  98.438 ( 98.438)
Test: [  50/782]  Time: 0.034s (0.162s,  394.67/s)  Loss:  0.8394 (0.7368)  Acc@1:  79.688 ( 84.222)  Acc@5:  96.875 ( 96.722)
Test: [ 100/782]  Time: 0.034s (0.159s,  403.65/s)  Loss:  0.3413 (0.6856)  Acc@1:  93.750 ( 85.597)  Acc@5:  98.438 ( 96.597)
Test: [ 150/782]  Time: 0.034s (0.152s,  422.18/s)  Loss:  1.1211 (0.6938)  Acc@1:  68.750 ( 85.048)  Acc@5:  92.188 ( 96.575)
Test: [ 200/782]  Time: 0.034s (0.149s,  429.19/s)  Loss:  0.6631 (0.7009)  Acc@1:  89.062 ( 84.359)  Acc@5:  95.312 ( 96.642)
Test: [ 250/782]  Time: 0.034s (0.152s,  419.71/s)  Loss:  0.0801 (0.7183)  Acc@1:  96.875 ( 84.008)  Acc@5: 100.000 ( 96.782)
Test: [ 300/782]  Time: 0.034s (0.154s,  415.10/s)  Loss:  0.8501 (0.7185)  Acc@1:  81.250 ( 84.224)  Acc@5:  92.188 ( 96.823)
Test: [ 350/782]  Time: 0.034s (0.152s,  421.04/s)  Loss:  0.3701 (0.7466)  Acc@1:  90.625 ( 83.342)  Acc@5: 100.000 ( 96.492)
Test: [ 400/782]  Time: 0.386s (0.151s,  422.95/s)  Loss:  0.4163 (0.8073)  Acc@1:  85.938 ( 81.944)  Acc@5: 100.000 ( 95.889)
Test: [ 450/782]  Time: 0.454s (0.151s,  423.10/s)  Loss:  0.4368 (0.8178)  Acc@1:  92.188 ( 81.676)  Acc@5:  98.438 ( 95.673)
Test: [ 500/782]  Time: 0.034s (0.150s,  427.98/s)  Loss:  0.4084 (0.8468)  Acc@1:  92.188 ( 80.888)  Acc@5:  98.438 ( 95.337)
Test: [ 550/782]  Time: 0.399s (0.150s,  427.75/s)  Loss:  0.5044 (0.8655)  Acc@1:  85.938 ( 80.360)  Acc@5: 100.000 ( 95.145)
Test: [ 600/782]  Time: 0.034s (0.149s,  429.50/s)  Loss:  0.7329 (0.8806)  Acc@1:  84.375 ( 80.033)  Acc@5:  95.312 ( 94.930)
Test: [ 650/782]  Time: 0.035s (0.148s,  431.66/s)  Loss:  0.3967 (0.8937)  Acc@1:  92.188 ( 79.738)  Acc@5:  96.875 ( 94.732)
Test: [ 700/782]  Time: 0.034s (0.148s,  431.70/s)  Loss:  0.5571 (0.9045)  Acc@1:  87.500 ( 79.271)  Acc@5:  95.312 ( 94.648)
Test: [ 750/782]  Time: 0.034s (0.149s,  430.46/s)  Loss:  2.1934 (0.9040)  Acc@1:  56.250 ( 79.232)  Acc@5:  81.250 ( 94.659)
 * Acc@1 79.166 (20.834) Acc@5 94.686 (5.314)

If not, you can evaluate other models on the same data to check whether the problem is with the dataset.
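
For example, a rough sanity check (not from this thread) that evaluates a standard pretrained timm model on the same val folder; the model name and path are placeholders, and it assumes the usual synset-named ImageNet folder layout.

import torch
import timm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Any well-known pretrained model works as a reference point (assumption: resnet50).
model = timm.create_model('resnet50', pretrained=True).eval().cuda()

# Standard ImageNet eval preprocessing.
tfm = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
loader = DataLoader(datasets.ImageFolder('/path/to/imagenet/val', tfm),
                    batch_size=64, num_workers=8)

correct = total = 0
with torch.no_grad():
    for images, targets in loader:
        preds = model(images.cuda()).argmax(dim=1).cpu()
        correct += (preds == targets).sum().item()
        total += targets.size(0)
print(f'reference top-1: {100.0 * correct / total:.2f}%')

If the reference model also scores far below its published accuracy, the dataset or preprocessing is the likely culprit rather than the checkpoint.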

zihangJiang commented 2 years ago


You should also check the model config for lvvit_t here: https://github.com/zihangJiang/TokenLabeling/blob/35815d929579a797f1fbe52fa7f8dff13950f7d1/tlt/models/lvvit.py#L239-L244. I suspect you may have set the wrong config, which would largely affect the performance.
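
If it helps, a small diagnostic sketch (not part of the original comment) that diffs parameter shapes between the locally built model and the checkpoint; a wrong embedding dimension or depth shows up immediately as shape mismatches or extra keys. It makes the same registration and path assumptions as the loading sketch above.

import torch
from timm.models import create_model
import tlt.models  # assumed to register the lvvit_* entry points with timm

model = create_model('lvvit_t')
ckpt = torch.load('/path/to/lvvit_t.pth', map_location='cpu')
ckpt = ckpt.get('state_dict', ckpt)  # unwrap if nested (assumption)

model_sd = model.state_dict()
# Report keys that exist only on one side and any tensors whose shapes disagree.
for name, tensor in ckpt.items():
    if name not in model_sd:
        print(f'checkpoint-only key: {name}')
    elif model_sd[name].shape != tensor.shape:
        print(f'shape mismatch at {name}: model {tuple(model_sd[name].shape)} '
              f'vs checkpoint {tuple(tensor.shape)}')
for name in model_sd:
    if name not in ckpt:
        print(f'model-only key: {name}')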

marc345 commented 2 years ago

@zihangJiang Thanks for your help. The dataset, checkpoint, and config are not the problem; I have to check my model implementation again.