A-ViT: Adaptive Tokens for Efficient Vision Transformer

This repository is the official PyTorch implementation of A-ViT: Adaptive Tokens for Efficient Vision Transformer presented at CVPR 2022.

The code enables adaptive inference of tokens for vision transformers. The current version includes training, evaluation, and visualization of models on ImageNet-1K. We plan to update the repository with run-time token zipping for inference, distillation models, and base models in coming versions.

Useful links:
[Camera ready]
[ArXiv Full PDF]
[Project page]

Teaser

Requirements

The code was tested in a virtual environment with Python 3.6. Install the dependencies listed in requirements.txt.
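From the repository root (where requirements.txt lives), the standard pip invocation should suffice:

pip install -r requirements.txt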

Commands

Data Preparation

Please arrange the ImageNet dataset in the following structure:

  imagenet
  ├── train
  │   ├── class1
  │   │   ├── img1.jpeg
  │   │   ├── img2.jpeg
  │   │   └── ...
  │   ├── class2
  │   │   ├── img3.jpeg
  │   │   └── ...
  │   └── ...
  └── val
      ├── class1
      │   ├── img4.jpeg
      │   ├── img5.jpeg
      │   └── ...
      ├── class2
      │   ├── img6.jpeg
      │   └── ...
      └── ...
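This is the standard class-subfolder layout consumed by torchvision's ImageFolder. As a quick sanity check of your copy, a minimal sketch (not part of the official code; assumes torchvision is installed and the path below points at the imagenet folder above):

    # Quick sanity check of the dataset layout (not part of the official code).
    # Assumes torchvision is installed; replace the path with your imagenet folder.
    from torchvision import datasets

    root = "<data to imagenet folder>"
    val = datasets.ImageFolder(root + "/val")
    print(len(val.classes), "classes,", len(val), "validation images")  # expect 1000 classes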

Training an A-ViT on ImageNet-1K

The commands below train A-ViT on ImageNet-1K. Please place the ImageNet dataset as described above. With the --pretrained flag, the code automatically loads pretrained DeiT weights as the starting point. Please refer to here or to the URLs in models_act.py if you face loading or downloading issues.

The code was tested on a cluster of 4 NVIDIA V100 GPUs with 32 GB of memory each. Training takes 100 epochs.

For tiny model:

python -m torch.distributed.launch --nproc_per_node=4 --use_env main_act.py --model avit_tiny_patch16_224 --data-path <data to imagenet folder> --output_dir ./results/<name your exp for tensorboard files and ckpt> --pretrained --batch-size 128 --lr 0.0005 --tensorboard --epochs 100 --gate_scale 10.0 --gate_center 30 --warmup-epochs 5 --ponder_token_scale 0.0005 --distr_prior_alpha 0.001

For small model:

python -m torch.distributed.launch --nproc_per_node=4 --use_env main_act.py --model avit_small_patch16_224 --data-path <data to imagenet folder> --output_dir ./results/<name your exp name> --pretrained --batch-size 96 --lr 0.0003 --tensorboard --epochs 100 --gate_scale 10.0 --gate_center 75 --warmup-epochs 5 --ponder_token_scale 0.0005 --distr_prior_alpha 0.001

Arguments: --gate_scale and --gate_center control the scale and offset of the token halting gate, --ponder_token_scale weights the token ponder loss, and --distr_prior_alpha weights the distributional prior regularization; see main_act.py for the full list of training options.
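For intuition on what the gating arguments control, below is a minimal, unofficial sketch of ACT-style per-token halting. It assumes the halting score is a sigmoid of the first embedding channel, scaled by --gate_scale and shifted by --gate_center; see models_act.py for the exact formulation used in the paper.

    # Unofficial sketch of ACT-style per-token halting (see models_act.py for
    # the real implementation). Assumes the halting score is a sigmoid of the
    # first embedding channel, scaled by gate_scale and shifted by gate_center.
    import torch

    def update_token_halting(tokens, cumul, gate_scale=10.0, gate_center=30.0, eps=0.01):
        """tokens: (B, N, C) token embeddings at the current layer.
        cumul:  (B, N) cumulative halting score carried across layers.
        Returns the updated cumulative score and a (B, N, 1) mask that zeroes
        out tokens whose cumulative score reached 1 - eps."""
        halt = torch.sigmoid(tokens[:, :, 0] * gate_scale - gate_center)  # (B, N)
        cumul = cumul + halt
        active = (cumul < 1.0 - eps).float().unsqueeze(-1)                # (B, N, 1)
        return cumul, active

    # Usage: multiply each block's output by `active` so halted tokens stop
    # being updated (and can be skipped entirely at inference time).
    x = torch.randn(2, 197, 192)                       # e.g. DeiT-tiny tokens
    cumul = torch.zeros(x.shape[:2])
    cumul, mask = update_token_halting(x, cumul)
    x = x * mask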

Pretrained Weights on ImageNet-1K

Pretrained A-ViT model weights can be downloaded here.

Use the following command to unpack it into the folder ./a-vit-weights inside the main repository, so that the training script can load the weights directly.

tar -xzf <path to your downloaded folder>/a-vit-weights.tar.gz ./a-vit-weights

  Name     Acc@1  Acc@5  Resolution  #Params (M)  FLOPs (G)  Path
  A-ViT-T  71.4   90.4   224x224     5            0.8        a-vit-weights/tiny-10-30.pth
  A-ViT-S  78.8   93.9   224x224     22           3.6        a-vit-weights/small-10-75.pth
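If loading fails, a quick way to inspect a downloaded checkpoint is to open it directly with torch.load; the key names below are only illustrative, since the exact layout of the released .pth files may differ:

    # Inspect a downloaded checkpoint (illustrative only; the exact dictionary
    # layout of the released .pth files may differ).
    import torch

    ckpt = torch.load("a-vit-weights/tiny-10-30.pth", map_location="cpu")
    print(type(ckpt))
    if isinstance(ckpt, dict):
        print(list(ckpt.keys())[:10])  # e.g. 'model', 'optimizer', or raw state_dict keys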

For evaluation, simply append the following flags to the training command, with nproc_per_node set to 1:

--eval --finetune <path to ckpt folder>/{name of checkpoint}
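For example, evaluating the tiny checkpoint from the table above (assuming the weights were unpacked to ./a-vit-weights as described) might look like:

python -m torch.distributed.launch --nproc_per_node=1 --use_env main_act.py --model avit_tiny_patch16_224 --data-path <data to imagenet folder> --eval --finetune ./a-vit-weights/tiny-10-30.pth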

Visualization of Token Depths

To get a quick visualization of the learned adaptive inference, kindly download the pretrained weights as instructed and pass the --demo flag.

This will plot the token depth distribution in a new folder named token_act_visualization, with one demo image for each of the 1K validation classes, saving (i) the original image, (ii) the token depth map, and (iii) a concatenated left-to-right comparison image. A full run generates 3K images.

The batch size is set to 50 for per-class analysis. Note that this function is not fully optimized and can be slow.

python -m torch.distributed.launch --nproc_per_node=1 --use_env main_act.py --model avit_tiny_patch16_224 --data-path <data to imagenet folder> --finetune <path to ckpt folder>/visualization-tiny.pth --demo

Some examples (left: ImageNet validation image, 224x224, unseen during training | right: token depth, whiter is deeper | legend: depth, where 12 is full depth). More examples are in the example_images folder.

Please feel free to generate more examples using the --demo flag, or via the visualize function in engine_act.py. Note that this is a quick demo from an intermediate checkpoint of the tiny A-ViT; we observe that the depth distribution continues to converge slightly as accuracy improves. Analyzing other checkpoints is easily supported by changing which .pth file is loaded, and the semantic meaning of the distribution holds.
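For reference, rendering a token-depth map of this kind is straightforward; the sketch below is not the official visualize function, and assumes a 224x224 input with patch size 16 (a 14x14 token grid) and an array depths holding each patch token's exit layer:

    # Unofficial sketch of rendering a per-token depth map ("whiter is deeper").
    # Assumes 224x224 input, patch size 16 (14x14 tokens), and that `depths`
    # holds the exit layer of each patch token (12 = full depth).
    import numpy as np
    from PIL import Image

    depths = np.random.randint(1, 13, size=196)             # placeholder depths
    grid = depths.reshape(14, 14).astype(np.float32) / 12.0
    gray = (255.0 * grid).astype(np.uint8)                   # full depth -> white
    img = Image.fromarray(gray, mode="L").resize((224, 224), resample=Image.NEAREST)
    img.save("token_depth_map.png")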

License

Copyright (C) 2022 NVIDIA Corporation. All rights reserved.

This work is made available under the Nvidia Source Code License (1-Way Commercial). To view a copy of this license, visit the LICENSE file in this repository.

The pre-trained models are shared under CC-BY-NC-SA-4.0. If you remix, transform, or build upon the material, you must distribute your contributions under the same license as the original.

For license information regarding the timm repository, please refer to the official website.

For license information regarding the DEIT repository, please refer to the official website.

For license information regarding the ImageNet dataset, please refer to the official website.

Citation

@InProceedings{Yin_2022_CVPR,
    author    = {Yin, Hongxu and Vahdat, Arash and Alvarez, Jose M. and Mallya, Arun and Kautz, Jan and Molchanov, Pavlo},
    title     = {{A}-{V}i{T}: {A}daptive Tokens for Efficient Vision Transformer},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2022},
    pages     = {10809-10818}
}