LeapLabTHU / FLatten-Transformer

Official repository of FLatten Transformer (ICCV2023)

FLatten Transformer

This repo contains the official PyTorch code and pre-trained models for FLatten Transformer (ICCV 2023).

Updates

Introduction

Motivation

The quadratic computational complexity $\mathcal{O}(N^2)$ of self-attention has been a long-standing problem when applying Transformer models to vision tasks. Besides restricting the attention region, linear attention is considered an effective way to avoid excessive computation cost. By approximating Softmax with carefully designed mapping functions, linear attention can change the computation order of the self-attention operation and achieve linear complexity $\mathcal{O}(N)$. Nevertheless, current linear attention approaches either suffer from a severe performance drop or incur additional computation overhead from the mapping function. In this paper, we propose a novel Focused Linear Attention module that achieves both high efficiency and expressiveness.
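The change of computation order can be illustrated with a small plain-Python sketch. It is not the repository's implementation: the ReLU-based `feature_map` and the helper names are assumptions chosen for illustration. Both functions return the same result, but the second never builds the $N \times N$ similarity matrix.

```python
# Illustrative sketch only: generic linear attention with an assumed ReLU feature map.

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    """Swap rows and columns."""
    return [list(col) for col in zip(*m)]

def feature_map(m, eps=1e-6):
    """Hypothetical mapping phi: ReLU plus a small constant so row sums stay positive."""
    return [[max(x, 0.0) + eps for x in row] for row in m]

def attention_quadratic(q, k, v):
    """(phi(Q) phi(K)^T) V: materialises an N x N matrix, cost grows as N^2."""
    pq, pk = feature_map(q), feature_map(k)
    sim = matmul(pq, transpose(pk))                      # N x N
    sim = [[s / sum(row) for s in row] for row in sim]   # normalise each row
    return matmul(sim, v)

def attention_linear(q, k, v):
    """phi(Q) (phi(K)^T V): only d x d intermediates, cost grows as N."""
    pq, pk = feature_map(q), feature_map(k)
    kv = matmul(transpose(pk), v)                        # d x d_v
    k_sum = [sum(col) for col in zip(*pk)]               # length d
    out = matmul(pq, kv)                                 # N x d_v
    norm = [sum(a * b for a, b in zip(row, k_sum)) for row in pq]
    return [[x / z for x in row] for row, z in zip(out, norm)]
```

The equivalence relies only on the associativity of matrix multiplication, which holds once Softmax is replaced by a similarity of the form $\phi(Q)\phi(K)^\top$.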

Method

In this paper, we first analyze in detail why linear attention performs worse, from two perspectives: focus ability and feature diversity. We then introduce a simple yet effective mapping function and an efficient rank restoration module, and propose Focused Linear Attention (FLatten), which addresses both concerns while retaining high efficiency and expressive capability.
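To give a feel for what "focus ability" means, here is a hedged plain-Python sketch of one possible focusing map: ReLU, an elementwise power, and a rescale that keeps the vector norm unchanged. The exact form, the function name `focused_map`, and the default `p=3` are assumptions for illustration and should be checked against the paper and the model code in this repo.

```python
import math

def focused_map(x, p=3):
    """Assumed focusing map: ReLU, elementwise power p, then restore the original norm."""
    relu = [max(v, 0.0) for v in x]
    powered = [v ** p for v in relu]
    norm_in = math.sqrt(sum(v * v for v in relu))
    norm_out = math.sqrt(sum(v * v for v in powered))
    if norm_out == 0.0:
        return powered  # vector is all zeros after ReLU
    return [v * norm_in / norm_out for v in powered]

def cosine(a, b):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))
```

Under this sketch, two vectors that point in different directions become less similar after mapping, so the resulting attention weights are more concentrated, while each vector keeps its length.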

Results

Dependencies

Data preparation

The ImageNet dataset should be prepared as follows:

$ tree data
imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...
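Before launching a long run, it can help to confirm that the folder matches the layout above. The following is a minimal sketch using only the standard library; the function name `check_imagenet_layout` and the accepted image suffixes are assumptions, not part of this repo.

```python
from pathlib import Path

# Assumption: adjust to the file extensions actually present in your copy of ImageNet.
IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}

def check_imagenet_layout(root):
    """Return a list of problems found under root/train and root/val."""
    problems = []
    root = Path(root)
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing split directory: {split_dir}")
            continue
        class_dirs = [d for d in sorted(split_dir.iterdir()) if d.is_dir()]
        if not class_dirs:
            problems.append(f"no class folders in {split_dir}")
        for class_dir in class_dirs:
            has_image = any(
                p.suffix.lower() in IMAGE_SUFFIXES for p in class_dir.iterdir()
            )
            if not has_image:
                problems.append(f"no images in {class_dir}")
    return problems
```

An empty list means both splits exist and every class folder contains at least one image file.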

Pretrained Models

Based on different model architectures, we provide several pretrained models, as listed below.

| Model | Resolution | Acc@1 | Config | Pretrained weights |
| --- | --- | --- | --- | --- |
| FLatten-PVT-T | $224^2$ | 77.8 (+2.7) | config | TsinghuaCloud |
| FLatten-PVTv2-B0 | $224^2$ | 71.1 (+0.6) | config | TsinghuaCloud |
| FLatten-Swin-T | $224^2$ | 82.1 (+0.8) | config | TsinghuaCloud |
| FLatten-Swin-S | $224^2$ | 83.5 (+0.5) | config | TsinghuaCloud |
| FLatten-Swin-B | $224^2$ | 83.8 (+0.3) | config | TsinghuaCloud |
| FLatten-Swin-B | $384^2$ | 85.0 (+0.5) | config | TsinghuaCloud |
| FLatten-CSwin-T | $224^2$ | 83.1 (+0.4) | config | TsinghuaCloud |

Evaluate one model on ImageNet:

python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg <path-to-config-file> --data-path <imagenet-path> --output <output-path> --eval --resume <path-to-pretrained-weights>

Outputs of the four T/B0 pretrained models are:

[2023-07-21 07:50:09 flatten_pvt_tiny] (main.py 294): INFO  * Acc@1 77.758 Acc@5 93.910
[2023-07-21 07:50:09 flatten_pvt_tiny] (main.py 149): INFO Accuracy of the network on the 50000 test images: 77.8%

[2023-07-21 07:51:36 flatten_pvt_v2_b0] (main.py 294): INFO  * Acc@1 71.098 Acc@5 90.596
[2023-07-21 07:51:36 flatten_pvt_v2_b0] (main.py 149): INFO Accuracy of the network on the 50000 test images: 71.1%

[2023-07-21 07:46:13 flatten_swin_tiny_patch4_224] (main.py 294): INFO  * Acc@1 82.106 Acc@5 95.900
[2023-07-21 07:46:13 flatten_swin_tiny_patch4_224] (main.py 149): INFO Accuracy of the network on the 50000 test images: 82.1%

[2023-07-21 07:52:46 FLatten_CSWin_tiny](main.py 294): INFO  * Acc@1 83.130 Acc@5 96.376
[2023-07-21 07:52:46 FLatten_CSWin_tiny](main.py 149): INFO Accuracy of the network on the 50000 test images: 83.1%
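To compare your own evaluation logs against the numbers above, the summary lines can be parsed with a short script. This is a sketch that assumes the log format is exactly as shown in the excerpts; `parse_accuracy` is a hypothetical helper, not part of this repo.

```python
import re

# Matches summary lines such as:
# [2023-07-21 07:50:09 flatten_pvt_tiny] (main.py 294): INFO  * Acc@1 77.758 Acc@5 93.910
# The space between "]" and "(" is optional because the CSwin log omits it.
SUMMARY_RE = re.compile(
    r"\[(?P<date>\S+) (?P<time>\S+) (?P<model>[^\]]+)\]\s*\(main\.py \d+\): INFO\s+\* "
    r"Acc@1 (?P<acc1>\d+\.\d+) Acc@5 (?P<acc5>\d+\.\d+)"
)

def parse_accuracy(log_text):
    """Map model name -> (Acc@1, Acc@5) for every summary line in log_text."""
    results = {}
    for match in SUMMARY_RE.finditer(log_text):
        results[match["model"]] = (float(match["acc1"]), float(match["acc5"]))
    return results
```

Lines that only report the rounded accuracy ("Accuracy of the network on the 50000 test images") are ignored, so each model appears once in the result.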

Train Models from Scratch

python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_t.yaml --data-path <imagenet-path> --output <output-path> --find-unused-params
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_s.yaml --data-path <imagenet-path> --output <output-path> --find-unused-params
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_m.yaml --data-path <imagenet-path> --output <output-path> --find-unused-params
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_b.yaml --data-path <imagenet-path> --output <output-path> --find-unused-params
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b0.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b1.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b2.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b3.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b4.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_t.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_s.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_b.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_t.yaml --data-path <imagenet-path> --output <output-path> --model-ema --model-ema-decay 0.99984
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_s.yaml --data-path <imagenet-path> --output <output-path> --model-ema --model-ema-decay 0.99984
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_b.yaml --data-path <imagenet-path> --output <output-path> --model-ema --model-ema-decay 0.99982
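The CSwin commands pass `--model-ema` with a decay close to 1. As a rough guide to what that value means, the sketch below shows the usual exponential-moving-average update of weights. It assumes `main_ema.py` follows this common convention, which should be confirmed in the code; `ema_update` and `ema_horizon` are hypothetical helper names.

```python
def ema_update(ema_weights, weights, decay):
    """Return the averaged weights after one training step (standard EMA rule)."""
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, weights)]

def ema_horizon(decay):
    """Rough number of recent steps that dominate the average: 1 / (1 - decay)."""
    return 1.0 / (1.0 - decay)
```

Under that assumption, a decay of 0.99984 averages over roughly 6250 recent steps, and 0.99982 over roughly 5556.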

Fine-tuning on higher resolution

Fine-tune a FLatten-Swin-B model pre-trained on 224x224 resolution to 384x384 resolution:

python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_b_384.yaml --data-path <imagenet-path> --output <output-path> --pretrained <path-to-224x224-pretrained-weights>

Fine-tune a FLatten-CSwin-B model pre-trained on 224x224 resolution to 384x384 resolution:

python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_b_384.yaml --data-path <imagenet-path> --output <output-path> --pretrained <path-to-224x224-pretrained-weights> --model-ema --model-ema-decay 0.99982

Visualization

We provide code for visualizing flatten attention. For example, to visualize flatten attention in FLatten-Swin-T, add the following to this line.

from visualize import AttnVisualizer
visualizer = AttnVisualizer(qk=[q, k], kernel=self.dwc.weight, name='flatten_swin_t')
visualizer.visualize_all_attn(max_num=196, image='./visualize/img_ori_00809.png')

Then run:

python visualize.py

Note: Remember to update the path to the FLatten-Swin-T pretrained weights in visualize.py.

Acknowledgements

This code is built on top of Swin Transformer. The computational resources supporting this work are provided by Hangzhou High-Flyer AI Fundamental Research Co., Ltd.

Citation

If you find this repo helpful, please consider citing us.

@InProceedings{han2023flatten,
  title={FLatten Transformer: Vision Transformer using Focused Linear Attention},
  author={Han, Dongchen and Pan, Xuran and Han, Yizeng and Song, Shiji and Huang, Gao},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  year={2023}
}

Contact

If you have any questions, please feel free to contact the authors.

Dongchen Han: hdc23@mails.tsinghua.edu.cn

Xuran Pan: pxr18@mails.tsinghua.edu.cn