This repository contains code for the paper "Revisiting Token Pruning for Object Detection and Instance Segmentation".
Vision Transformers (ViTs) have shown impressive performance in computer vision, but their high computational cost, which is quadratic in the number of tokens, limits their adoption in computation-constrained applications. However, this large number of tokens may not be necessary, as not all tokens are equally important. In this paper, we investigate token pruning to accelerate inference for object detection and instance segmentation, extending prior works from image classification. Through extensive experiments, we offer four insights for dense tasks: (i) tokens should not be completely pruned and discarded, but rather preserved in the feature maps for later use; (ii) reactivating previously pruned tokens can further enhance model performance; (iii) a dynamic, image-dependent pruning rate is better than a fixed pruning rate; (iv) a lightweight 2-layer MLP can effectively prune tokens, achieving accuracy comparable to complex gating networks with a simpler design. We evaluate the impact of these design choices on the COCO dataset and present a method integrating these insights that outperforms prior token pruning models, significantly reducing the performance drop from ~1.5 mAP to ~0.3 mAP for both boxes and masks. Compared to the dense counterpart that uses all tokens, our method achieves up to 34% faster inference for the whole network and 46% faster for the backbone.
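To make insights (i), (iii) and (iv) more concrete, here is a minimal pure-Python sketch; it is not the repository's implementation. A hypothetical 2-layer MLP (MLPTokenSelector) scores each token, and a token is kept when its logit is positive. This threshold rule is an assumption that makes the number of kept tokens image-dependent. Only kept tokens pass through the block, while pruned tokens are copied unchanged into the output feature map. Running the selector again on the full map shows how previously pruned tokens could be reactivated (insight ii). The weights are random, toy_block is a hypothetical stand-in for a transformer block, and the real model works on batched tensors trained end to end.

```python
import math
import random
from typing import Callable, List

Token = List[float]  # one token = one feature vector


def linear(x: Token, weight: List[Token], bias: Token) -> Token:
    """Computes W @ x + b for a single token; weight is [out_dim][in_dim]."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


def gelu(x: Token) -> Token:
    """Exact (erf-based) GELU, applied element-wise."""
    return [0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0))) for v in x]


class MLPTokenSelector:
    """Hypothetical 2-layer MLP that maps each token to a single keep logit."""

    def __init__(self, dim: int, hidden: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.w1 = [[rng.gauss(0.0, 0.5) for _ in range(dim)] for _ in range(hidden)]
        self.b1 = [0.0] * hidden
        self.w2 = [[rng.gauss(0.0, 0.5) for _ in range(hidden)]]
        self.b2 = [0.0]

    def keep_mask(self, tokens: List[Token]) -> List[bool]:
        # Thresholding each logit at 0 (instead of a fixed top-k) lets the
        # number of kept tokens vary from image to image.
        return [linear(gelu(linear(t, self.w1, self.b1)), self.w2, self.b2)[0] > 0.0
                for t in tokens]


def sparse_block(tokens: List[Token], keep: List[bool],
                 block: Callable[[List[Token]], List[Token]]) -> List[Token]:
    """Runs `block` on kept tokens only; pruned tokens are copied unchanged,
    so the full feature map is preserved for later layers and dense heads."""
    updated = iter(block([t for t, k in zip(tokens, keep) if k]))
    return [next(updated) if k else t for t, k in zip(tokens, keep)]


def toy_block(tokens: List[Token]) -> List[Token]:
    """Hypothetical stand-in for a transformer block: adds 1 to every feature."""
    return [[v + 1.0 for v in t] for t in tokens]


rng = random.Random(1)
feature_map = [[rng.uniform(-1.0, 1.0) for _ in range(4)] for _ in range(10)]
selector = MLPTokenSelector(dim=4, hidden=8)

keep = selector.keep_mask(feature_map)
out = sparse_block(feature_map, keep, toy_block)
# Scoring the full map again at a later stage means a token pruned earlier
# can be selected again ("reactivated").
keep_later = selector.keep_mask(out)
```

Because pruned tokens stay in out, a dense detection or segmentation head still receives a complete feature map of the original size.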
Recommended environment:
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
pip install mmcv-full==1.7.0 -f https://download.openmmlab.com/mmcv/dist/cu116/torch1.13/index.html
pip install timm==0.4.12
pip install mmdet==2.28.1
pip install scipy
cd ops && sh make.sh # compile the deformable attention ops
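The following hypothetical helper, which is not part of the repository, can report whether the installed packages match these pins. It assumes Python 3.8+ (importlib.metadata) and that the distribution names are the ones used in the pip commands above.

```python
from importlib import metadata
from typing import Callable, Dict, Optional, Tuple

# Pins copied from the pip commands above (distribution names as on PyPI).
RECOMMENDED = {
    "torch": "1.13.1+cu116",
    "torchvision": "0.14.1+cu116",
    "torchaudio": "0.13.1",
    "mmcv-full": "1.7.0",
    "timm": "0.4.12",
    "mmdet": "2.28.1",
}


def version_matches(installed: str, pinned: str) -> bool:
    """Mimics PEP 440 '==' for local versions: a pin without a '+local'
    label also matches installed builds such as '0.13.1+cu116'."""
    if "+" in pinned:
        return installed == pinned
    return installed.split("+", 1)[0] == pinned


def check_versions(
    pins: Dict[str, str],
    get_version: Callable[[str], str] = metadata.version,
) -> Dict[str, Tuple[Optional[str], bool]]:
    """Returns {package: (installed version or None, matches pin)}."""
    report = {}
    for name, pinned in pins.items():
        try:
            installed: Optional[str] = get_version(name)
        except metadata.PackageNotFoundError:
            installed = None
        report[name] = (installed, installed is not None and version_matches(installed, pinned))
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_versions(RECOMMENDED).items():
        print(f"{name:12s} {installed or 'missing':16s} {'ok' if ok else 'MISMATCH'}")
```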
Please prepare COCO according to the guidelines in MMDetection.
Alternatively, download the download_dataset.py script and run:
python download_dataset.py --dataset-name coco2017
The repository should have the following folder structure (the COCO data may be symlinked into data/coco):
root_folder
├── mmcv_custom
├── mmdet_custom
├── configs
├── ops
├── data
│ ├── coco
│ │ ├── annotations
│ │ ├── train2017
│ │ ├── val2017
│ │ ├── test2017
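Before training, a quick check such as the hypothetical helper below, run from root_folder, lists any directory from the tree above that is missing. Symlinked directories count as present because Path.is_dir() follows symlinks.

```python
from pathlib import Path
from typing import Iterable, List

# Sub-directories from the tree above, relative to root_folder.
EXPECTED_DIRS = (
    "mmcv_custom", "mmdet_custom", "configs", "ops",
    "data/coco/annotations", "data/coco/train2017",
    "data/coco/val2017", "data/coco/test2017",
)


def missing_dirs(root: str, expected: Iterable[str] = EXPECTED_DIRS) -> List[str]:
    """Lists expected sub-directories that do not exist under `root`."""
    base = Path(root)
    return [rel for rel in expected if not (base / rel).is_dir()]


if __name__ == "__main__":
    print(missing_dirs(".") or "layout looks complete")
```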
The following models use Mask R-CNN with ViT-Adapter backbones. In this repo, ViT-Adapter adapts DeiT without windowed attention, since token pruning is incompatible with windows. Token pruning is introduced during finetuning, after the dense model has been trained for 36 epochs; the finetuning lasts 6 epochs for SViT-Adapter-T and 4 epochs for SViT-Adapter-S.
Backbone | Pre-train | Lr schd | box AP | mask AP | #Param | FPS | Config | Download | Logs |
---|---|---|---|---|---|---|---|---|---|
ViT-Adapter-T | DeiT-T | 3x+MS | 45.8 | 40.9 | 28M | 18.45 | config | model | log |
SViT-Adapter-T | ViT-Adapter-T | 0.5x+MS | 45.5 | 40.7 | 28M | 22.32 | config | model | log |
ViT-Adapter-S | DeiT-S | 3x+MS | 48.5 | 42.8 | 48M | 11.70 | config | model | log |
SViT-Adapter-S | ViT-Adapter-S | 0.33x+MS | 48.2 | 42.5 | 48M | 15.75 | config | model | log |
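The Lr schd column follows the MMDetection convention that a 1x schedule is 12 epochs, so 3x is 36 epochs, 0.5x is 6 epochs and 0.33x is 4 epochs after rounding. The sketch below uses hypothetical helper names and values copied from the table. It converts schedule strings to epochs and computes each sparse model's AP drop and FPS ratio relative to its dense counterpart.

```python
from typing import Dict, Tuple

EPOCHS_PER_1X = 12  # MMDetection convention: a "1x" schedule is 12 epochs


def schedule_to_epochs(lr_schd: str) -> int:
    """'3x+MS' -> 36, '0.5x+MS' -> 6, '0.33x+MS' -> 4 (rounded to whole epochs)."""
    multiplier = float(lr_schd.split("x", 1)[0])
    return round(multiplier * EPOCHS_PER_1X)


# (box AP, mask AP, FPS) copied from the table above.
RESULTS: Dict[str, Tuple[float, float, float]] = {
    "ViT-Adapter-T": (45.8, 40.9, 18.45),
    "SViT-Adapter-T": (45.5, 40.7, 22.32),
    "ViT-Adapter-S": (48.5, 42.8, 11.70),
    "SViT-Adapter-S": (48.2, 42.5, 15.75),
}


def compare(dense: str, sparse: str) -> Dict[str, float]:
    """AP drops (dense minus sparse) and FPS ratio (sparse / dense)."""
    d_box, d_mask, d_fps = RESULTS[dense]
    s_box, s_mask, s_fps = RESULTS[sparse]
    return {
        "box_ap_drop": round(d_box - s_box, 1),
        "mask_ap_drop": round(d_mask - s_mask, 1),
        "speedup": round(s_fps / d_fps, 2),
    }


if __name__ == "__main__":
    for size in ("T", "S"):
        print(size, compare(f"ViT-Adapter-{size}", f"SViT-Adapter-{size}"))
```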
To evaluate SViT-Adapter-S on COCO val2017 on a single node with 8 GPUs, run:
sh dist_test.sh configs/mask_rcnn/svit-adapter-s-0.33x-ftune.py pretrained/svit-adapter-s-0.33x.pth 8 --eval bbox segm
To train a dense ViT-Adapter-T with global attention (Mask R-CNN) on COCO train2017 on a single node with 4 GPUs for 36 epochs, run:
sh dist_train.sh configs/mask_rcnn/vit-adapter-t-3x.py 4
To train a dense ViT-Adapter-S with global attention (Mask R-CNN) on COCO train2017 on a single node with 8 GPUs for 36 epochs, run:
sh dist_train.sh configs/mask_rcnn/vit-adapter-s-3x.py 8
The number of GPUs multiplied by samples_per_gpu from the config file should equal 16, i.e., a total batch size of 16.
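For example, with 4 GPUs each GPU needs samples_per_gpu = 4, and with 8 GPUs it needs samples_per_gpu = 2. The hypothetical helper below derives the value from the GPU count. Assuming the configs follow the MMDetection 2.x layout, the current value can be read with mmcv.Config.fromfile(path).data.samples_per_gpu.

```python
TOTAL_BATCH_SIZE = 16  # required product of #GPUs and samples_per_gpu


def required_samples_per_gpu(num_gpus: int, total: int = TOTAL_BATCH_SIZE) -> int:
    """samples_per_gpu that keeps num_gpus * samples_per_gpu == total."""
    if num_gpus <= 0 or total % num_gpus != 0:
        raise ValueError(f"{num_gpus} GPUs cannot split a total batch of {total} evenly")
    return total // num_gpus


if __name__ == "__main__":
    for gpus in (4, 8):
        print(f"{gpus} GPUs -> samples_per_gpu = {required_samples_per_gpu(gpus)}")
```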
To finetune the sparse SViT-Adapter-T with pruned tokens (Mask R-CNN) on COCO train2017 on a single node with 4 GPUs for 6 epochs, run:
sh dist_train.sh configs/mask_rcnn/svit-adapter-t-0.5x-ftune.py 4
To finetune the sparse SViT-Adapter-S with pruned tokens (Mask R-CNN) on COCO train2017 on a single node with 8 GPUs for 4 epochs, run:
sh dist_train.sh configs/mask_rcnn/svit-adapter-s-0.33x-ftune.py 8
The number of GPUs multiplied by samples_per_gpu from the config file should equal 16, i.e., a total batch size of 16.
We provide a script to compare the models' inference speeds:
python speed_test.py
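The contents of speed_test.py are not reproduced here. As a rough illustration of how FPS numbers like those in the table are commonly measured, the sketch below runs a few warm-up iterations and then times a fixed number of single-image passes. The run_once callable, the batch size of 1 and the iteration counts are assumptions. On a GPU, pass torch.cuda.synchronize as sync so that asynchronous CUDA work is included in the timing.

```python
import time
from typing import Callable


def measure_fps(run_once: Callable[[], None], warmup: int = 10, iters: int = 50,
                sync: Callable[[], None] = lambda: None) -> float:
    """Images per second for a callable that processes one image per call.

    `sync` must block until queued device work has finished (e.g.
    torch.cuda.synchronize); otherwise GPU timings may only reflect
    kernel launches."""
    for _ in range(warmup):  # exclude one-off costs such as allocation
        run_once()
    sync()
    start = time.perf_counter()
    for _ in range(iters):
        run_once()
    sync()
    return iters / (time.perf_counter() - start)


if __name__ == "__main__":
    # Toy CPU workload standing in for a model forward pass.
    print(f"{measure_fps(lambda: sum(range(100_000))):.1f} it/s")
```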
We provide a script to visualize the token pruning process:
python seletor_demo.py data/coco/val2017/000000046252.jpg configs/mask_rcnn/demo-svit-adapter-s-0.33x-ftune.py pretrained/svit-adapter-s-0.33x.pth
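The sketch below illustrates how a token keep mask relates to image regions; it is not the logic of seletor_demo.py. The helpers are hypothetical and assume one token per 16×16 patch (the DeiT patch size) in row-major order. They return pixel boxes for pruned tokens, which any drawing library can shade, and print a text preview of the mask.

```python
from typing import List, Tuple

Box = Tuple[int, int, int, int]  # (x0, y0, x1, y1) in pixels, end exclusive


def pruned_patch_boxes(keep: List[bool], grid_w: int, patch: int = 16) -> List[Box]:
    """Pixel boxes of pruned tokens, assuming row-major token order."""
    if grid_w <= 0 or len(keep) % grid_w:
        raise ValueError("len(keep) must be a positive multiple of grid_w")
    boxes = []
    for idx, kept in enumerate(keep):
        if not kept:
            row, col = divmod(idx, grid_w)
            boxes.append((col * patch, row * patch, (col + 1) * patch, (row + 1) * patch))
    return boxes


def ascii_mask(keep: List[bool], grid_w: int) -> str:
    """'#' marks a kept token and '.' a pruned one, one text row per patch row."""
    rows = [keep[i:i + grid_w] for i in range(0, len(keep), grid_w)]
    return "\n".join("".join("#" if k else "." for k in row) for row in rows)


if __name__ == "__main__":
    mask = [True, False, False, True, True, False]
    print(ascii_mask(mask, grid_w=3))
    print(pruned_patch_boxes(mask, grid_w=3))
```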
If this work is helpful for your research, please consider citing the following BibTeX entry:
@InProceedings{Liu_2024_WACV,
title={Revisiting Token Pruning for Object Detection and Instance Segmentation},
author={Liu, Yifei and Gehrig, Mathias and Messikommer, Nico and Cannici, Marco and Scaramuzza, Davide},
booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
year={2024}
}
This repository is released under the Apache 2.0 license as found in the LICENSE file.
This project has used code from the following projects: