This repo contains the official PyTorch code and pre-trained models for FLatten Transformer (ICCV 2023).
The quadratic computational complexity of self-attention, $\mathcal{O}(N^2)$, has been a long-standing problem when applying Transformer models to vision tasks. Apart from restricting the attention region, linear attention is considered an effective way to avoid excessive computation cost. By approximating Softmax with carefully designed mapping functions, linear attention can switch the computation order in the self-attention operation and achieve linear complexity $\mathcal{O}(N)$. Nevertheless, current linear attention approaches either suffer from a severe performance drop or incur additional computational overhead from the mapping function. In this paper, we propose a novel Focused Linear Attention module that achieves both high efficiency and expressiveness.
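The switch of computation order mentioned above can be illustrated with a minimal sketch. It is not the code in this repo: it assumes a single head, no projections, and a plain ReLU as the mapping function, and the function names are made up for illustration. Both functions return the same result, but the second one never materializes the $N \times N$ attention map.

```python
import torch


def quadratic_kernel_attention(q, k, v, eps=1e-6):
    """Kernel attention computed as (phi(Q) phi(K)^T) V, which costs O(N^2 d)."""
    phi_q, phi_k = torch.relu(q), torch.relu(k)   # assumed mapping function: ReLU
    scores = phi_q @ phi_k.transpose(-2, -1)      # (B, N, N) attention map
    denom = scores.sum(dim=-1, keepdim=True) + eps
    return (scores / denom) @ v                   # (B, N, d)


def linear_kernel_attention(q, k, v, eps=1e-6):
    """Same quantity computed as phi(Q) (phi(K)^T V), which costs O(N d^2)."""
    phi_q, phi_k = torch.relu(q), torch.relu(k)
    kv = phi_k.transpose(-2, -1) @ v              # (B, d, d), independent of N^2
    k_sum = phi_k.sum(dim=-2, keepdim=True)       # (B, 1, d)
    denom = (phi_q * k_sum).sum(dim=-1, keepdim=True) + eps  # (B, N, 1)
    return (phi_q @ kv) / denom                   # (B, N, d)


torch.manual_seed(0)
q, k, v = (torch.randn(2, 196, 32) for _ in range(3))
out_quadratic = quadratic_kernel_attention(q, k, v)
out_linear = linear_kernel_attention(q, k, v)
# The two orders agree up to floating-point error.
print(torch.allclose(out_quadratic, out_linear, atol=1e-4))
```

The equivalence relies only on associativity of matrix multiplication, which is why Softmax has to be replaced by a separable mapping function before the order can be changed.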
We first analyze the inferior performance of linear attention from two perspectives: focus ability and feature diversity. We then introduce a simple yet effective mapping function and an efficient rank restoration module, and propose Focused Linear Attention (FLatten), which addresses both concerns while retaining high efficiency and expressive capability.
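As a rough illustration of these two components, the sketch below follows our reading of the paper rather than the exact code in this repo. It assumes that the mapping function is an element-wise power applied after ReLU and rescaled to keep the original norm, and that rank restoration is a depthwise convolution on V added to the attention output. The names `focused_map` and `FocusedLinearAttentionSketch`, the default power `p=3.0`, and the kernel size are illustrative assumptions; heads, projections, and any learnable scaling are omitted.

```python
import torch
import torch.nn as nn


def focused_map(x, p=3.0, eps=1e-6):
    """Assumed mapping: sharpen the direction with an element-wise power, keep the norm."""
    x = torch.relu(x) + eps                       # non-negative features
    norm = x.norm(dim=-1, keepdim=True)
    x_p = x ** p                                  # pushes mass toward dominant channels
    return x_p / x_p.norm(dim=-1, keepdim=True) * norm


class FocusedLinearAttentionSketch(nn.Module):
    """Single-head sketch: linear attention plus a depthwise conv on V."""

    def __init__(self, dim, p=3.0, kernel_size=5):
        super().__init__()
        self.p = p
        # Depthwise convolution (groups=dim), assumed here as the rank restoration term.
        self.dwc = nn.Conv2d(dim, dim, kernel_size,
                             padding=kernel_size // 2, groups=dim)

    def forward(self, q, k, v, h, w):
        # q, k, v: (B, N, d) with N == h * w
        q, k = focused_map(q, self.p), focused_map(k, self.p)
        kv = k.transpose(-2, -1) @ v                              # (B, d, d)
        denom = (q * k.sum(dim=-2, keepdim=True)).sum(dim=-1, keepdim=True)
        out = (q @ kv) / denom                                    # (B, N, d)
        b, n, d = v.shape
        v_map = v.transpose(1, 2).reshape(b, d, h, w)             # tokens -> feature map
        local = self.dwc(v_map).reshape(b, d, n).transpose(1, 2)  # back to (B, N, d)
        return out + local


attn = FocusedLinearAttentionSketch(dim=32)
q, k, v = (torch.randn(2, 14 * 14, 32) for _ in range(3))
out = attn(q, k, v, h=14, w=14)
print(out.shape)
```

Because the power is applied element-wise and the result is rescaled, the mapping changes only the direction of each feature vector, which is what makes similar query-key pairs stand out more.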
The ImageNet dataset should be prepared as follows:
$ tree data
imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...
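Before launching a long run, it can help to confirm that the directory follows the layout above. The helper below is a small standalone sketch, not part of this repo; the function names are hypothetical, and it only checks that `train` and `val` exist and contain the same class folders.

```python
from pathlib import Path


def summarize_imagefolder(root):
    """Return {split: {class_name: file_count}} for train/ and val/ under root."""
    root = Path(root)
    summary = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        summary[split] = {
            class_dir.name: sum(1 for f in class_dir.iterdir() if f.is_file())
            for class_dir in sorted(split_dir.iterdir())
            if class_dir.is_dir()
        }
    return summary


def check_layout(root):
    """Raise if train/ and val/ do not contain the same class folders."""
    summary = summarize_imagefolder(root)
    train_classes, val_classes = set(summary["train"]), set(summary["val"])
    if train_classes != val_classes:
        mismatch = sorted(train_classes ^ val_classes)
        raise ValueError(f"class folders differ between splits: {mismatch}")
    return summary


# Example (path is an assumption based on the tree above):
# summary = check_layout("data/imagenet")
# print(len(summary["train"]), "classes")
```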
Based on different model architectures, we provide several pretrained models, as listed below.
Model | Resolution | Acc@1 (%) | Config | Pretrained weights |
---|---|---|---|---|
FLatten-PVT-T | $224^2$ | 77.8 (+2.7) | config | TsinghuaCloud |
FLatten-PVTv2-B0 | $224^2$ | 71.1 (+0.6) | config | TsinghuaCloud |
FLatten-Swin-T | $224^2$ | 82.1 (+0.8) | config | TsinghuaCloud |
FLatten-Swin-S | $224^2$ | 83.5 (+0.5) | config | TsinghuaCloud |
FLatten-Swin-B | $224^2$ | 83.8 (+0.3) | config | TsinghuaCloud |
FLatten-Swin-B | $384^2$ | 85.0 (+0.5) | config | TsinghuaCloud |
FLatten-CSwin-T | $224^2$ | 83.1 (+0.4) | config | TsinghuaCloud |
Evaluate one model on ImageNet:
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg <path-to-config-file> --data-path <imagenet-path> --output <output-path> --eval --resume <path-to-pretrained-weights>
The expected outputs for the four T/B0 pretrained models are:
[2023-07-21 07:50:09 flatten_pvt_tiny] (main.py 294): INFO * Acc@1 77.758 Acc@5 93.910
[2023-07-21 07:50:09 flatten_pvt_tiny] (main.py 149): INFO Accuracy of the network on the 50000 test images: 77.8%
[2023-07-21 07:51:36 flatten_pvt_v2_b0] (main.py 294): INFO * Acc@1 71.098 Acc@5 90.596
[2023-07-21 07:51:36 flatten_pvt_v2_b0] (main.py 149): INFO Accuracy of the network on the 50000 test images: 71.1%
[2023-07-21 07:46:13 flatten_swin_tiny_patch4_224] (main.py 294): INFO * Acc@1 82.106 Acc@5 95.900
[2023-07-21 07:46:13 flatten_swin_tiny_patch4_224] (main.py 149): INFO Accuracy of the network on the 50000 test images: 82.1%
[2023-07-21 07:52:46 FLatten_CSWin_tiny](main.py 294): INFO * Acc@1 83.130 Acc@5 96.376
[2023-07-21 07:52:46 FLatten_CSWin_tiny](main.py 149): INFO Accuracy of the network on the 50000 test images: 83.1%
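When evaluating several checkpoints, the accuracy lines can be collected from the logs with a short script. This is a convenience sketch that assumes the log format shown above; the pattern and function name are not part of the repo.

```python
import re

# Matches lines such as:
# [2023-07-21 07:50:09 flatten_pvt_tiny] (main.py 294): INFO * Acc@1 77.758 Acc@5 93.910
LOG_PATTERN = re.compile(
    r"\[(?P<timestamp>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}) (?P<model>[^\]]+)\]"
    r".*?Acc@1 (?P<acc1>\d+\.\d+) Acc@5 (?P<acc5>\d+\.\d+)"
)


def parse_eval_log(text):
    """Return {model_name: (acc1, acc5)} for every matching line in text."""
    results = {}
    for line in text.splitlines():
        match = LOG_PATTERN.search(line)
        if match:
            results[match.group("model")] = (
                float(match.group("acc1")),
                float(match.group("acc5")),
            )
    return results


sample_log = """\
[2023-07-21 07:50:09 flatten_pvt_tiny] (main.py 294): INFO * Acc@1 77.758 Acc@5 93.910
[2023-07-21 07:50:09 flatten_pvt_tiny] (main.py 149): INFO Accuracy of the network on the 50000 test images: 77.8%
[2023-07-21 07:52:46 FLatten_CSWin_tiny](main.py 294): INFO * Acc@1 83.130 Acc@5 96.376
"""

for model, (acc1, acc5) in parse_eval_log(sample_log).items():
    print(f"{model}: top-1 {acc1:.3f}, top-5 {acc5:.3f}")
```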
To train FLatten-PVT-T/S/M/B on ImageNet from scratch, run:
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_t.yaml --data-path <imagenet-path> --output <output-path> --find-unused-params
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_s.yaml --data-path <imagenet-path> --output <output-path> --find-unused-params
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_m.yaml --data-path <imagenet-path> --output <output-path> --find-unused-params
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_b.yaml --data-path <imagenet-path> --output <output-path> --find-unused-params
To train FLatten-PVT-v2-b0/1/2/3/4 on ImageNet from scratch, run:
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b0.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b1.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b2.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b3.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b4.yaml --data-path <imagenet-path> --output <output-path>
To train FLatten-Swin-T/S/B on ImageNet from scratch, run:
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_t.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_s.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_b.yaml --data-path <imagenet-path> --output <output-path>
To train FLatten-CSwin-T/S/B on ImageNet from scratch, run:
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_t.yaml --data-path <imagenet-path> --output <output-path> --model-ema --model-ema-decay 0.99984
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_s.yaml --data-path <imagenet-path> --output <output-path> --model-ema --model-ema-decay 0.99984
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_b.yaml --data-path <imagenet-path> --output <output-path> --model-ema --model-ema-decay 0.99982
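For readers unfamiliar with the `--model-ema` options, the sketch below shows the standard exponential moving average update of model weights. How `main_ema.py` implements it is not shown here, so treat this as an assumption about the general technique; `ema_update` is a hypothetical helper. With this formulation, a decay of 0.99984 averages over roughly 1 / (1 - 0.99984) = 6250 updates.

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def ema_update(ema_model, model, decay):
    """In place: ema = decay * ema + (1 - decay) * current weights."""
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)
    # Buffers (e.g. BatchNorm statistics) are copied rather than averaged.
    for ema_b, b in zip(ema_model.buffers(), model.buffers()):
        ema_b.copy_(b)


model = nn.Linear(4, 2)
ema_model = copy.deepcopy(model)          # EMA copy starts from the same weights
for p in ema_model.parameters():
    p.requires_grad_(False)               # the EMA copy is never trained directly

# After each optimizer step on `model`, refresh the averaged copy:
ema_update(ema_model, model, decay=0.99984)
```

The averaged copy is typically the one used for evaluation, since it smooths out the noise of the last few optimization steps.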
To fine-tune a FLatten-Swin-B model pre-trained at 224x224 resolution to 384x384 resolution, run:
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_b_384.yaml --data-path <imagenet-path> --output <output-path> --pretrained <path-to-224x224-pretrained-weights>
To fine-tune a FLatten-CSwin-B model pre-trained at 224x224 resolution to 384x384 resolution, run:
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_b_384.yaml --data-path <imagenet-path> --output <output-path> --pretrained <path-to-224x224-pretrained-weights> --model-ema --model-ema-decay 0.99982
We provide code for visualizing flatten attention. For example, to visualize flatten attention in FLatten-Swin-T, add the following to this line.
from visualize import AttnVisualizer
visualizer = AttnVisualizer(qk=[q, k], kernel=self.dwc.weight, name='flatten_swin_t')
visualizer.visualize_all_attn(max_num=196, image='./visualize/img_ori_00809.png')
Then run:
python visualize.py
Note: Don't forget to modify the path of the FLatten-Swin-T pretrained weights in visualize.py.
This code is developed on top of Swin Transformer. The computational resources supporting this work are provided by Hangzhou High-Flyer AI Fundamental Research Co., Ltd.
If you find this repo helpful, please consider citing us.
@InProceedings{han2023flatten,
title={FLatten Transformer: Vision Transformer using Focused Linear Attention},
author={Han, Dongchen and Pan, Xuran and Han, Yizeng and Song, Shiji and Huang, Gao},
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
year={2023}
}
If you have any questions, please feel free to contact the authors.
Dongchen Han: hdc23@mails.tsinghua.edu.cn
Xuran Pan: pxr18@mails.tsinghua.edu.cn