
Official PyTorch implementation of Fully Attentional Networks
https://arxiv.org/abs/2204.12451

Fully Attentional Networks


Project Page | Paper | Slides | Poster

Understanding The Robustness in Vision Transformers.
Daquan Zhou, Zhiding Yu, Enze Xie, Chaowei Xiao, Anima Anandkumar, Jiashi Feng, and Jose M. Alvarez.
International Conference on Machine Learning (ICML), 2022.

This repository contains the official PyTorch implementation of the training/evaluation code and the pretrained models of the Fully Attentional Network (FAN).

FAN is a family of general-purpose Vision Transformer backbones that are highly robust to unseen natural corruptions in various visual recognition tasks.

Dependencies

The repo is built on the timm library, which can be installed together with the pinned torchvision version via:

pip3 install timm==0.5.4
pip3 install torchvision==0.9.0
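
As a quick sanity check after installation, a FAN classifier can be built through timm's model registry once the repo's model definitions have been imported. A minimal sketch, assuming that importing the repo's models module is what registers the FAN variants (the module name is an assumption about this repo's layout; fan_tiny_8_p4_hybrid is the variant name used in the training command below):

```python
import torch
import timm

import models  # hypothetical: importing the repo's model definitions registers FAN variants with timm

# Build the FAN-T hybrid backbone as a 1000-way ImageNet classifier.
model = timm.create_model('fan_tiny_8_p4_hybrid', pretrained=False, num_classes=1000)
model.eval()

# Forward a dummy 224x224 RGB image.
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # expected: torch.Size([1, 1000])
```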

Dataset preparation

Download the clean ImageNet dataset and the ImageNet-C dataset, and structure them as follows:

/path/to/imagenet-C/
  clean/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
  corruption1/
    severity1/
      class1/
        img3.jpeg
      class2/
        img4.jpeg
    severity2/
      class1/
        img3.jpeg
      class2/
        img4.jpeg
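
Each corruption/severity split in this layout is a standard class-per-folder dataset, so it can be loaded directly with torchvision's ImageFolder. A minimal sketch of iterating one split for evaluation (the preprocessing shown is the standard ImageNet pipeline, an assumption rather than the exact recipe used by the eval scripts):

```python
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

# Standard ImageNet evaluation preprocessing (assumed).
transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# One corruption at one severity, following the layout above.
split = ImageFolder('/path/to/imagenet-C/corruption1/severity1', transform=transform)
loader = DataLoader(split, batch_size=64, shuffle=False, num_workers=8)
```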

For other out-of-distribution benchmarks, we evaluate on ImageNet-A and ImageNet-R.

Results and Pre-trained Models

FAN-ViT ImageNet-1K trained models

All numbers are top-1 accuracy (%).

| Model | Resolution | IN-1K | IN-C | IN-A | IN-R | #Params | Download |
|-----------|---------|------|------|------|------|-------|-----------|
| FAN-T-ViT | 224x224 | 79.2 | 57.5 | 15.6 | 42.5 | 7.3M  | model |
| FAN-S-ViT | 224x224 | 82.5 | 64.5 | 29.1 | 50.4 | 28.0M | model |
| FAN-B-ViT | 224x224 | 83.6 | 67.0 | 35.4 | 51.8 | 54.0M | model |
| FAN-L-ViT | 224x224 | 83.9 | 67.7 | 37.2 | 53.1 | 80.5M | [model]() |

FAN-Hybrid ImageNet-1K trained models

IN-1K / IN-C are ImageNet top-1 accuracy (%), City / City-C are Cityscapes segmentation mIoU, and COCO / COCO-C are COCO detection mAP.

| Model | Resolution | IN-1K / IN-C | City / City-C | COCO / COCO-C | #Params | Download |
|--------------|---------|-------------|--------------|--------------|-------|-----------|
| FAN-T-Hybrid | 224x224 | 80.1 / 57.4 | 81.2 / 57.1 | 50.2 / 33.1 | 7.4M  | model |
| FAN-S-Hybrid | 224x224 | 83.5 / 64.7 | 81.5 / 66.4 | 53.3 / 38.7 | 26.3M | model |
| FAN-B-Hybrid | 224x224 | 83.9 / 66.4 | 82.2 / 66.9 | 54.2 / 40.6 | 50.4M | model |
| FAN-L-Hybrid | 224x224 | 84.3 / 68.3 | 82.3 / 68.7 | 55.1 / 42.0 | 76.8M | [model]() |

FAN-Hybrid ImageNet-22K trained models

| Model | Resolution | IN-1K / IN-C | #Params | Download |
|--------------|---------|-------------|-------|----------|
| FAN-B-Hybrid | 224x224 | 85.3 / 70.5 | 50.4M | model |
| FAN-B-Hybrid | 384x384 | 85.6 / -   | 50.4M | model |
| FAN-L-Hybrid | 224x224 | 86.5 / 73.6 | 76.8M | model |
| FAN-L-Hybrid | 384x384 | 87.1 / -   | 76.8M | model |

Pre-trained weights for FAN-B-Hybrid and FAN-L-Hybrid trained on ImageNet-22K without fine-tuning on ImageNet-1K are also provided. Checkpoints can be downloaded by clicking on the model name.
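
For reference, a minimal sketch of restoring a downloaded checkpoint into a model built as in the Dependencies section (the checkpoint path is a placeholder, and the 'state_dict' unwrapping is an assumption based on common timm checkpoint formats):

```python
import torch
import timm

import models  # hypothetical: registers FAN variants with timm (see Dependencies)

model = timm.create_model('fan_tiny_8_p4_hybrid', num_classes=1000)
ckpt = torch.load('/path/to/downloaded/checkpoint.pth.tar', map_location='cpu')
state_dict = ckpt.get('state_dict', ckpt)  # unwrap if the weights are nested
model.load_state_dict(state_dict)
model.eval()
```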

Demos

Semantic Segmentation on Cityscapes-C

[Animated demo: FAN semantic segmentation predictions on Cityscapes-C]

ImageNet-1K Training

FAN-T training on ImageNet-1K with four 8-GPU nodes (effective global batch size: 32 per GPU x 32 GPUs = 1024):

python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=$rank_num \
    --node_rank=$rank_index --master_addr="ip.addr" --master_port=$MASTER_PORT \
    main.py /PATH/TO/IMAGENET/ --model fan_tiny_8_p4_hybrid -b 32 --sched cosine --epochs 300 \
    --opt adamw -j 16 --warmup-epochs 5 \
    --lr 10e-4 --drop-path .1 --img-size 224 \
    --output ../fan_tiny_8_p4_hybrid/ \
    --amp --model-ema

Robustness on ImageNet-C

bash scripts/imagenet_c_val.sh $model_name $ckpt
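
The IN-C numbers in the tables above summarize many (corruption, severity) splits. If you collect one top-1 accuracy per split (e.g., by running the script once per folder), a simple unweighted mean produces that kind of summary; a sketch with placeholder values (whether the script weights splits differently is not specified here):

```python
from statistics import mean

# accs[corruption][severity] = top-1 accuracy in percent (placeholder values).
accs = {
    'corruption1': {1: 70.1, 2: 65.3, 3: 60.2, 4: 52.8, 5: 44.5},
    'corruption2': {1: 72.4, 2: 68.0, 3: 61.7, 4: 55.1, 5: 47.9},
    # ... remaining corruptions
}

# Average over severities first, then over corruptions.
per_corruption = {c: mean(sev.values()) for c, sev in accs.items()}
overall = mean(per_corruption.values())
print(f'ImageNet-C mean top-1: {overall:.1f}%')
```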

Measurement on ImageNet-A

bash scripts/imagenet_a_val.sh $model_name $ckpt
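
ImageNet-A covers a 200-class subset of the ImageNet-1K label space, so 1000-way logits are typically restricted to those 200 classes before taking the argmax (the script above presumably handles this internally). An illustrative sketch, with the index list left as a placeholder:

```python
import torch

# Placeholder: replace with the published list of the 200 ImageNet-1K
# class indices that appear in ImageNet-A.
imagenet_a_indices = list(range(200))

def predict_imagenet_a(logits: torch.Tensor) -> torch.Tensor:
    # Restrict to the ImageNet-A classes, then argmax within the subset;
    # predictions are positions in the subset's label order.
    sub = logits[:, imagenet_a_indices]
    return sub.argmax(dim=1)
```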

Measurement on ImageNet-R

bash scripts/imagenet_r_val.sh $model_name $ckpt

Acknowledgement

This repository is built using the timm library and the DeiT, PVT, and SegFormer repositories.

Citation

If you find this repository helpful, please consider citing:

@inproceedings{zhou2022understanding,
  title     = {Understanding The Robustness in Vision Transformers},
  author    = {Daquan Zhou and Zhiding Yu and Enze Xie and Chaowei Xiao and Anima Anandkumar and Jiashi Feng and Jose M. Alvarez},
  booktitle = {International Conference on Machine Learning (ICML)},
  year      = {2022},
}