MIT License

Enhanced Training of Query-Based Object Detection via Selective Query Recollection (arXiv)
Fangyi Chen, Han Zhang, Kai Hu, Yu-Kai Huang, Chenchen Zhu, Marios Savvides
Carnegie Mellon University, Meta AI

📰 News

2023.07: We support SQR-DAB-DETR in the detrex codebase.
2023.06: We fixed a logical issue in the inference of SQR-Deformable DETR; it does not affect the final results.
2023.03: This work has been accepted to CVPR 2023.
2023.03: The experiments and code for SQR-Adamixer and SQR-Deformable DETR have been released.
2022.12: The code is now available.

🤔 Motivation

🌧 A phenomenon: query-based object detectors sometimes mispredict at the last decoding stage while predicting correctly at intermediate stages.

The decoding procedure of DETR implies that detections should be refined stage by stage in terms of IoU and confidence score. Indeed, this procedure empirically yields monotonically improving AP. However, when visualizing the stage-wise predictions, we surprisingly observe that the decoder makes mistakes in a considerable proportion of cases, where later stages degrade true positives and upgrade false positives from earlier stages.

⭕ Two limitations of training

  1. The responsibility each stage takes is unbalanced, yet all stages receive analogous supervision.
  2. Due to the sequential structure of the decoder, an intermediate query refined by a stage is cascaded to the following stage, regardless of whether that refinement had a positive or negative effect.

🚀 Selective Query Recollection

SQR is a training strategy that fits most query-based object detectors (the DETR family). As stages go deeper, it cumulatively collects intermediate queries and feeds the collected queries to downstream stages in addition to the regular sequential path.
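To make the recollection rule concrete, here is a minimal, framework-free sketch. It assumes our reading of the paper's selective rule: during training, each stage receives the outputs of the two preceding stages (the initial queries act as the output of a virtual stage 0), so the number of query groups per stage grows like the Fibonacci sequence. Setting `selective=False` gives dense recollection, where every earlier query is kept and the count doubles. `recollect_inputs` and `make_stage` are hypothetical names, not functions from this repo, and each "query" is a toy list recording which stages refined it rather than a tensor.

```python
from typing import Callable, List, Sequence, TypeVar

Q = TypeVar("Q")  # a query set: a tensor in a real detector, a toy list here


def recollect_inputs(initial: Q, stages: Sequence[Callable[[Q], Q]],
                     selective: bool = True) -> List[List[Q]]:
    """Return the query groups fed to each decoder stage during training.

    selective=True (assumed SQR rule): a stage gets the outputs of the previous
    stage and of the stage before it; the initial queries act as the output
    of a virtual stage 0.
    selective=False (dense recollection): a stage gets everything the previous
    stage received plus everything the previous stage produced.
    """
    stage_inputs: List[List[Q]] = []
    prev_out: List[Q] = []         # outputs of the stage two steps back
    last_out: List[Q] = [initial]  # outputs of the previous stage
    last_in: List[Q] = []          # inputs of the previous stage
    for i, stage in enumerate(stages):
        if i == 0:
            groups = [initial]
        elif selective:
            groups = last_out + prev_out
        else:
            groups = last_in + last_out
        # every group is refined (and, during training, supervised) by this stage
        outputs = [stage(q) for q in groups]
        stage_inputs.append(groups)
        prev_out, last_out, last_in = last_out, outputs, groups
    return stage_inputs


def make_stage(k: int) -> Callable[[List[int]], List[int]]:
    # hypothetical toy stage k: append k so each query records its refinement path
    return lambda path: path + [k]


stages = [make_stage(k) for k in range(1, 7)]
inputs = recollect_inputs([], stages)
print([len(g) for g in inputs])  # [1, 2, 3, 5, 8, 13]
print(inputs[2])                 # [[1, 2], [2], [1]]
print([len(g) for g in recollect_inputs([], stages, selective=False)])  # [1, 2, 4, 8, 16, 32]
```

With the ordering used here, group 0 at every stage is the ordinary sequential path, which is the only path that runs at test time. The growing number of groups per stage is also why processing all of them at once costs more GPU memory.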

➡️ Guide to Code

This repo provides the implementation of SQR-Adamixer and SQR-Deformable DETR. The code structure follows the MMDetection framework. Adamixer is a typical query-based object detector that converges quickly and achieves high AP. Deformable DETR is known for its deformable attention module, which mitigates the slow convergence and high complexity of DETR.

Config

Our config files are in the configs/sqr folder.

SQR-Adamixer

We provide two implementations of SQR-Adamixer in this repo. The first, in /mmdet/models/roi_heads/adamixer_decoder_Qrecycle.py, trains more slowly but requires less GPU memory (and its logic is easier to follow). The second, in /mmdet/models/roi_heads/adamixer_decoder_Qrecycle_optimize.py, is much faster than the first (and highly recommended) but requires more GPU memory.

SQR-Deformable DETR

Similarly, we provide two implementations of SQR-Deformable DETR in QRDeformableDetrTransformerDecoder in /mmdet/models/utils/transformer.py, named forward and forward_slow, respectively. Please also check this issue and QRDeformableDETRHead in /mmdet/models/dense_heads/QR_deformable_detr_head.py, where the recollected queries are aligned with their stages during training. Note that SQR is purely a training strategy; it does not change the testing pipeline.
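Because each stage receives a different number of recollected groups during training, the head has to know which stage each group of predictions belongs to before computing losses. The sketch below illustrates that bookkeeping under an assumed recollection rule (Fibonacci-like group counts for selective recollection, doubling for dense) and assumes, purely for illustration, that losses are averaged over the groups of each stage. `group_stage_ids` and `per_stage_loss` are hypothetical helpers; check QRDeformableDETRHead for how the repo actually aligns and weights them.

```python
from typing import Dict, List


def group_stage_ids(num_stages: int, selective: bool = True) -> List[int]:
    """0-based stage index of every recollected prediction group, flattened in
    stage order. Group counts follow an assumed rule: Fibonacci-like for
    selective recollection, doubling for dense recollection."""
    counts: List[int] = []
    for i in range(num_stages):
        if i == 0:
            counts.append(1)                      # only the initial queries
        elif i == 1 or not selective:
            counts.append(2 * counts[-1])
        else:
            counts.append(counts[-1] + counts[-2])
    return [stage for stage, n in enumerate(counts) for _ in range(n)]


def per_stage_loss(group_losses: List[float], stage_ids: List[int]) -> Dict[int, float]:
    """Average the losses of all groups fed to the same stage
    (averaging is an illustrative choice, not necessarily the repo's)."""
    if len(group_losses) != len(stage_ids):
        raise ValueError("expected one loss per recollected group")
    sums: Dict[int, float] = {}
    counts: Dict[int, int] = {}
    for loss, stage in zip(group_losses, stage_ids):
        sums[stage] = sums.get(stage, 0.0) + loss
        counts[stage] = counts.get(stage, 0) + 1
    return {stage: sums[stage] / counts[stage] for stage in sums}


ids = group_stage_ids(4)
print(ids)  # [0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]
print(per_stage_loss([float(i) for i in range(len(ids))], ids))
# {0: 0.0, 1: 1.5, 2: 4.0, 3: 8.0}
```

At test time no recollection happens, so each stage sees exactly one group and this alignment step is unnecessary.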

Installation

We tested our models with Python 3.7, PyTorch 1.9.1, CUDA 11.1, and mmcv-full 1.3.3.

NOTE: Please use mmcv-full==1.3.3 and PyTorch>=1.5.0 to reproduce our results.

  1. Clone this repo

    git clone https://github.com/Fangyi-Chen/SQR.git
    cd SQR
  2. Create a conda env and activate it

    conda create -n sqr-detr python=3.7 -y
    conda activate sqr-detr
  3. Install PyTorch and torchvision following the instructions at https://pytorch.org/get-started/locally/, for example:

    conda install pytorch==1.9.1 torchvision==0.10.1 torchaudio==0.9.1 cudatoolkit=11.1 -c pytorch -c conda-forge
  4. Install mmcv-full

    pip install mmcv-full==1.3.3 --no-cache-dir
  5. Install mmdet

    pip install -r requirements/build.txt
    pip install -v -e .  # or "python setup.py develop"
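After installation, a quick way to confirm that the environment matches the pinned versions (mmcv-full==1.3.3, PyTorch>=1.5.0) is a small check script. This is a minimal sketch, not part of the repo; the helper names are hypothetical, and it reads each package's `__version__` via `importlib`, which also works on Python 3.7.

```python
import importlib
import re
from typing import Dict, List, Optional, Tuple


def parse_version(v: str) -> Tuple[int, ...]:
    """Numeric major.minor.patch prefix, zero-padded: '1.9.1+cu111' -> (1, 9, 1)."""
    m = re.match(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?", v)
    if m is None:
        raise ValueError(f"unrecognized version string: {v!r}")
    return tuple(int(part or 0) for part in m.groups())


def check_env(versions: Dict[str, Optional[str]]) -> List[str]:
    """Compare found versions with the README pins: mmcv-full==1.3.3, torch>=1.5.0."""
    problems: List[str] = []
    mmcv_v, torch_v = versions.get("mmcv"), versions.get("torch")
    if mmcv_v is None:
        problems.append("mmcv is not importable; install mmcv-full==1.3.3")
    elif parse_version(mmcv_v) != (1, 3, 3):
        problems.append(f"mmcv {mmcv_v} found, mmcv-full 1.3.3 expected")
    if torch_v is None:
        problems.append("torch is not importable")
    elif parse_version(torch_v) < (1, 5, 0):
        problems.append(f"torch {torch_v} found, >=1.5.0 expected")
    return problems


def installed_versions(modules=("mmcv", "torch")) -> Dict[str, Optional[str]]:
    """Read __version__ from each module; None if it cannot be imported."""
    found: Dict[str, Optional[str]] = {}
    for name in modules:
        try:
            found[name] = importlib.import_module(name).__version__
        except ImportError:
            found[name] = None
    return found


if __name__ == "__main__":
    for line in check_env(installed_versions()) or ["environment matches the pinned versions"]:
        print(line)
```

Both mmcv and mmcv-full import as `mmcv`, so this check cannot tell them apart; `pip show mmcv-full` does.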

Getting Started

Please see get_started.md for the basic usage of MMDetection.

🧪 Main Results

Model                        #queries  AP    AP50  AP75  APs   APm   APl   Checkpoint  Config
SQR-Adamixer-R50             100       44.5  63.2  47.8  25.7  47.4  60.2  ckpt        cfg
SQR-Adamixer-R101-7stages    300       49.8  68.8  54.0  32.0  53.4  65.1  ckpt        cfg
SQR-Deformable-DETR          300       45.8  64.7  49.8  28.2  49.4  60.0  ckpt        cfg

✏️ Citation

If you find SQR useful, please use the following entry to cite us:

@InProceedings{Chen_2023_CVPR,
    author    = {Chen, Fangyi and Zhang, Han and Hu, Kai and Huang, Yu-Kai and Zhu, Chenchen and Savvides, Marios},
    title     = {Enhanced Training of Query-Based Object Detection via Selective Query Recollection},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2023},
    pages     = {23756-23765}
}

Original MMDetection README.md

The following is the original MMDetection README.md file.

News: We released the technical report on arXiv.

Documentation: https://mmdetection.readthedocs.io/

Introduction


MMDetection is an open source object detection toolbox based on PyTorch. It is a part of the OpenMMLab project.

The master branch works with PyTorch 1.3+. The old v1.x branch works with PyTorch 1.1 to 1.4, but v2.0 is strongly recommended for its faster speed, higher performance, better design, and friendlier usage.


Major features

Apart from MMDetection, we also released mmcv, a library for computer vision research on which this toolbox heavily depends.

License

The mmdetection project is released under the Apache 2.0 license.

Changelog

v2.12.0 was released on May 1, 2021. Please refer to changelog.md for details and release history. A comparison between the v1.x and v2.0 codebases can be found in compatibility.md.

Benchmark and model zoo

Results and models are available in the model zoo.

Supported backbones:

Supported methods:

Some other methods are also supported in projects using MMDetection.

Installation

Please refer to get_started.md for installation.

Getting Started

Please see get_started.md for the basic usage of MMDetection. We provide a Colab tutorial, as well as full guidance for beginners on quick runs with existing datasets and with new datasets. There are also tutorials on finetuning models, adding new datasets, designing data pipelines, customizing models, customizing runtime settings, and useful tools.

Please refer to FAQ for frequently asked questions.

Contributing

We appreciate all contributions to improve MMDetection. Please refer to CONTRIBUTING.md for the contributing guideline.

Acknowledgement

MMDetection is an open source project with contributions from researchers and engineers at various colleges and companies. We appreciate all the contributors who implement their methods or add new features, as well as users who give valuable feedback. We hope that the toolbox and benchmark can serve the growing research community by providing a flexible toolkit for reimplementing existing methods and developing new detectors.

Citation

If you use this toolbox or benchmark in your research, please cite this project.

@article{mmdetection,
  title   = {{MMDetection}: Open MMLab Detection Toolbox and Benchmark},
  author  = {Chen, Kai and Wang, Jiaqi and Pang, Jiangmiao and Cao, Yuhang and
             Xiong, Yu and Li, Xiaoxiao and Sun, Shuyang and Feng, Wansen and
             Liu, Ziwei and Xu, Jiarui and Zhang, Zheng and Cheng, Dazhi and
             Zhu, Chenchen and Cheng, Tianheng and Zhao, Qijie and Li, Buyu and
             Lu, Xin and Zhu, Rui and Wu, Yue and Dai, Jifeng and Wang, Jingdong
             and Shi, Jianping and Ouyang, Wanli and Loy, Chen Change and Lin, Dahua},
  journal = {arXiv preprint arXiv:1906.07155},
  year    = {2019}
}

Projects in OpenMMLab