juhongm999 / pervit

Official PyTorch Implementation of Peripheral Vision Transformer, NeurIPS 2022
Apache License 2.0

Peripheral Vision Transformer

This is the implementation of the paper "Peripheral Vision Transformer" by Juhong Min, Yucheng Zhao, Chong Luo, and Minsu Cho. Implemented on Python 3.7 and PyTorch 1.8.1.

For more information, check out the project [website] and the paper on [arXiv].

Requirements

conda create -n pervit python=3.7
conda activate pervit

conda install pytorch=1.8.1 torchvision=0.9.1 cudatoolkit=10.1 -c pytorch
conda install -c conda-forge tensorflow
conda install -c conda-forge matplotlib
pip install timm==0.3.2
pip install tensorboardX
pip install einops
pip install ptflops

Data preparation

Download and extract the ImageNet train and val images from http://image-net.org/ (ILSVRC2012). The directory structure follows the standard layout expected by torchvision's datasets.ImageFolder, with the training and validation images in the train/ and val/ folders, respectively:

/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
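ImageFolder treats every subdirectory of a split as one class and assigns label indices in sorted folder-name order, so train/ and val/ must contain exactly the same class folders or the labels will not line up. (A common pitfall: the ILSVRC2012 validation archive extracts flat, so its images must first be sorted into class subfolders.) The following is a minimal, hypothetical sanity check written with the standard library only; check_imagenet_layout is not part of this repository.

```python
import os


def find_classes(split_dir):
    # Same convention as torchvision's ImageFolder: each subdirectory is a class,
    # and classes are indexed in sorted (lexicographic) order.
    with os.scandir(split_dir) as entries:
        classes = sorted(e.name for e in entries if e.is_dir())
    return {name: idx for idx, name in enumerate(classes)}


def check_imagenet_layout(root):
    """Hypothetical helper: verify train/ and val/ expose identical class folders."""
    train = find_classes(os.path.join(root, "train"))
    val = find_classes(os.path.join(root, "val"))
    if train != val:
        # Symmetric difference = folders present in only one of the two splits.
        mismatched = sorted(set(train) ^ set(val))
        raise ValueError(f"class folders differ between train/ and val/: {mismatched[:5]}")
    return train
```

Calling check_imagenet_layout("/path/to/imagenet") on a correctly prepared ILSVRC2012 copy should return a mapping with 1000 entries; a ValueError points at the class folders to fix before training.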

Training

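To train PerViT-{T, S, M} on ImageNet-1K on a single node with 8 GPUs for 300 epochs, run the following, taking the value at the same position in each braced list (tiny, small, medium in that order):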

python -m torch.distributed.launch --nproc_per_node=8 \
                                   --use_env main.py \
                                   --batch-size 128 \
                                   --model pervit_{tiny, small, medium} \
                                   --weight-decay {0.03, 0.05, 0.05} \
                                   --warmup-epochs {5, 5, 20} \
                                   --drop-path {0.0, 0.1, 0.2} \
                                   --data-set IMNET \
                                   --data-path /path/to/imagenet \
                                   --output_dir /path/to/output_dir
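The braced lists encode three separate configurations. A minimal sketch that expands them into one concrete command per variant follows; the hyperparameter values are copied from the command above, while VARIANTS and build_train_cmd are hypothetical names, not part of the repository.

```python
import shlex

# Per-variant hyperparameters, copied from the training command in this README.
VARIANTS = {
    "tiny":   {"weight_decay": 0.03, "warmup_epochs": 5,  "drop_path": 0.0},
    "small":  {"weight_decay": 0.05, "warmup_epochs": 5,  "drop_path": 0.1},
    "medium": {"weight_decay": 0.05, "warmup_epochs": 20, "drop_path": 0.2},
}


def build_train_cmd(variant, data_path, output_dir, nproc=8, batch_size=128):
    """Hypothetical helper: return the argv list for one PerViT variant."""
    hp = VARIANTS[variant]
    return [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={nproc}", "--use_env", "main.py",
        "--batch-size", str(batch_size),
        "--model", f"pervit_{variant}",
        "--weight-decay", str(hp["weight_decay"]),
        "--warmup-epochs", str(hp["warmup_epochs"]),
        "--drop-path", str(hp["drop_path"]),
        "--data-set", "IMNET",
        "--data-path", data_path,
        "--output_dir", output_dir,
    ]


if __name__ == "__main__":
    cmd = build_train_cmd("small", "/path/to/imagenet", "/path/to/output_dir")
    # shlex.quote keeps paths with spaces safe (shlex.join needs Python 3.8+).
    print(" ".join(shlex.quote(arg) for arg in cmd))
```

Assuming --batch-size is per process, as in the DeiT code this repository builds on, the global batch is 8 × 128 = 1024 images per optimization step; changing --nproc_per_node therefore also changes the effective batch size.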

Evaluation

To evaluate PerViT-{T, S, M} on the ImageNet-1K validation set, run:

python main.py --eval --pretrained \
               --model pervit_{tiny, small, medium} \
               --data-set IMNET \
               --data-path /path/to/imagenet \
               --load /path/to/pretrained_model \
               --output_dir /path/to/output_dir
  • Pretrained PerViT-{T, S, M} checkpoints are available at this [link].
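ImageNet-1K classification results are conventionally reported as top-1 and top-5 accuracy. The README does not state exactly what main.py prints in --eval mode, so the following is only an illustrative sketch of the metric itself; topk_accuracy is a hypothetical helper, written in plain Python rather than with the repository's PyTorch code.

```python
def topk_accuracy(logits, labels, ks=(1, 5)):
    """Fraction of samples whose true label is among the k highest-scoring classes.

    logits: one list of class scores per sample; labels: true class index per sample.
    """
    n = len(labels)
    result = {}
    for k in ks:
        hits = 0
        for row, label in zip(logits, labels):
            # Indices of the k largest scores for this sample.
            top = sorted(range(len(row)), key=row.__getitem__, reverse=True)[:k]
            hits += label in top
        result[k] = hits / n
    return result
```

For example, topk_accuracy([[0.1, 0.9, 0.0], [0.7, 0.1, 0.2]], [1, 2], ks=(1, 2)) gives 0.5 for top-1 (only the first sample's argmax is correct) and 1.0 for top-2.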

Learned attention visualization

To visualize the learned attention maps (e.g., position-based attention) during evaluation, add the --visualize argument; the images will be saved under the vis/ directory:

python main.py --eval --pretrained '...other arguments...' --visualize  

The learned position-based attention maps are visualized sorted in order of nonlocality.
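The README does not define "nonlocality". ConViT, one of the projects this code borrows from, measures a head's nonlocality as the attention-weighted distance between each query patch and the key patches, averaged over query patches. Assuming a similar measure, a minimal sketch of computing it and ordering heads from most local to most nonlocal could look like this; nonlocality and sort_heads_by_nonlocality are hypothetical names, and distances are in patch units on a row-major grid.

```python
import math


def nonlocality(attn, grid_w):
    """Attention-weighted mean query-key distance, averaged over queries.

    attn[q][k] is one head's attention from query patch q to key patch k,
    with patches flattened row-major on a grid of width grid_w.
    """
    n = len(attn)
    total = 0.0
    for q in range(n):
        qy, qx = divmod(q, grid_w)
        for k in range(n):
            ky, kx = divmod(k, grid_w)
            # Euclidean distance between patch centers, in patch units.
            total += attn[q][k] * math.hypot(qy - ky, qx - kx)
    return total / n


def sort_heads_by_nonlocality(heads, grid_w):
    """Return head indices ordered from most local to most nonlocal."""
    return sorted(range(len(heads)), key=lambda h: nonlocality(heads[h], grid_w))
```

A head that attends only to its own patch (identity attention) scores 0, while uniform attention on a 2×2 grid scores (2 + √2) / 4 ≈ 0.854, so the identity head sorts first.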

Acknowledgement

We mainly borrow code from the public [ConViT] and [DeiT] projects.

BibTeX

If you use this code for your research, please consider citing:

@InProceedings{min2022pervit,
    author = {Juhong Min and Yucheng Zhao and Chong Luo and Minsu Cho},
    booktitle = {Advances in Neural Information Processing Systems},
    title = {{Peripheral Vision Transformer}},
    year = {2022}
}

License

The majority of this repository is released under the Apache 2.0 license as found in the LICENSE file.