haofeixu / aanet

[CVPR'20] AANet: Adaptive Aggregation Network for Efficient Stereo Matching
Apache License 2.0

AANet

PyTorch implementation of our paper:

AANet: Adaptive Aggregation Network for Efficient Stereo Matching, CVPR 2020

Authors: Haofei Xu and Juyong Zhang

11/15/2022 Update: Check out our new work, Unifying Flow, Stereo and Depth Estimation, and its code, unimatch, for stereo matching with our new GMStereo model. Unlike AANet, it does not require a custom CUDA op. We also release 10 pretrained GMStereo models with different speed-accuracy trade-offs. Check out our Colab and HuggingFace demos to play with GMStereo in your browser!

We propose a sparse-points-based intra-scale cost aggregation (ISA) module and a cross-scale cost aggregation (CSA) module for efficient and accurate stereo matching.
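To make sparse-points-based aggregation more concrete, here is a deliberately simplified 1-D sketch. It is not the repo's implementation, which applies a modulated deformable convolution to full cost volumes. The sketch builds a correlation cost volume for one scanline, then aggregates each pixel's cost from K sample points located at p + p_k + Δp_k, weighted by shared kernel weights w_k and per-pixel modulation m_k, with the same sampling points used for every disparity. In AANet the offsets and modulation are predicted by the network; here they are hand-picked inputs, and the function names are hypothetical.

```python
import math


def build_cost_volume(left, right, max_disp):
    """Correlation cost volume for one scanline of feature vectors.

    left/right: lists of per-pixel feature vectors (lists of floats).
    Returns cost[d][x] = dot(left[x], right[x - d]) / channels, and 0.0
    where x - d falls outside the image.
    """
    width, channels = len(left), len(left[0])
    cost = [[0.0] * width for _ in range(max_disp)]
    for d in range(max_disp):
        for x in range(d, width):
            dot = sum(a * b for a, b in zip(left[x], right[x - d]))
            cost[d][x] = dot / channels
    return cost


def sample_linear(row, pos):
    """Linearly interpolate a row at a fractional position (zero outside)."""
    x0 = math.floor(pos)
    frac = pos - x0

    def at(i):
        return row[i] if 0 <= i < len(row) else 0.0

    return (1.0 - frac) * at(x0) + frac * at(x0 + 1)


def aggregate_sparse(cost, base_offsets, learned_offsets, weights, modulation):
    """Sparse-point aggregation for every (d, p):

    agg[d][p] = sum_k weights[k]
                * cost[d] sampled at p + base_offsets[k] + learned_offsets[p][k]
                * modulation[p][k]

    weights are shared by all pixels (like a convolution kernel); offsets and
    modulation are per pixel and shared across all disparities.
    """
    max_disp, width = len(cost), len(cost[0])
    agg = [[0.0] * width for _ in range(max_disp)]
    for d in range(max_disp):
        for p in range(width):
            total = 0.0
            for k, (pk, wk) in enumerate(zip(base_offsets, weights)):
                pos = p + pk + learned_offsets[p][k]
                total += wk * sample_linear(cost[d], pos) * modulation[p][k]
            agg[d][p] = total
    return agg
```

With all offsets at zero and modulation at one, the sum reduces to a fixed 3-tap filter over the cost (for example `base_offsets=[-1, 0, 1]`, `weights=[0.25, 0.5, 0.25]`). Non-zero offsets let each pixel move its sampling points, for instance away from an object boundary, which is the intuition behind adaptive aggregation.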

This repo also includes the implementation of an improved version, AANet+, with stronger performance and slightly faster speed.

Highlights

Installation

Our code is based on PyTorch 1.2.0, CUDA 10.0, and Python 3.7.

We recommend using conda for installation:

conda env create -f environment.yml

After installing the dependencies, build the deformable convolution op:

cd nets/deform_conv && bash build.sh

Dataset Preparation

Download the Scene Flow, KITTI 2012, and KITTI 2015 datasets.

Our folder structure is as follows:

data
├── KITTI
│   ├── kitti_2012
│   │   └── data_stereo_flow
│   ├── kitti_2015
│   │   └── data_scene_flow
└── SceneFlow
    ├── Driving
    │   ├── disparity
    │   └── frames_finalpass
    ├── FlyingThings3D
    │   ├── disparity
    │   └── frames_finalpass
    └── Monkaa
        ├── disparity
        └── frames_finalpass
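Before launching a long job, it can help to confirm that the data root matches the folder structure above. The helper below is a hypothetical convenience, not part of the repo; it only checks the directories listed in the tree, relative to the data root.

```python
from pathlib import Path

# Sub-directories taken from the folder structure above, relative to data/.
EXPECTED_DIRS = [
    "KITTI/kitti_2012/data_stereo_flow",
    "KITTI/kitti_2015/data_scene_flow",
    "SceneFlow/Driving/disparity",
    "SceneFlow/Driving/frames_finalpass",
    "SceneFlow/FlyingThings3D/disparity",
    "SceneFlow/FlyingThings3D/frames_finalpass",
    "SceneFlow/Monkaa/disparity",
    "SceneFlow/Monkaa/frames_finalpass",
]


def missing_dirs(data_root, expected=EXPECTED_DIRS):
    """Return the expected sub-directories that do not exist under data_root."""
    root = Path(data_root)
    return [rel for rel in expected if not (root / rel).is_dir()]
```

For example, `missing_dirs("data")` returns an empty list when every directory is in place; `Path.is_dir()` follows symlinks, so it also works with the symlinked data root.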

If you would like to use the pseudo ground truth supervision introduced in our paper, you can download the pre-computed disparities for the KITTI 2012 and KITTI 2015 training sets here: KITTI 2012, KITTI 2015.

For KITTI 2012, place the unzipped disp_occ_pseudo_gt folder under kitti_2012/data_stereo_flow/training.

For KITTI 2015, place the unzipped disp_occ_0_pseudo_gt folder under kitti_2015/data_scene_flow/training.
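KITTI's ground truth is sparse, so a natural use of dense pseudo ground truth is to supervise only the pixels the real ground truth leaves empty. The sketch below assumes exactly that, treats 0 as "missing", uses plain L1 for brevity, and introduces a hypothetical `pseudo_weight`; the masking and loss used in the repo's training code may differ.

```python
def supervision_masks(gt, pseudo_gt, max_disp):
    """Split pixels into real-GT and pseudo-GT supervision sets.

    gt, pseudo_gt: flat lists of disparities where 0 marks a missing value.
    A pixel uses the real GT when 0 < gt < max_disp; otherwise it falls back
    to the pseudo GT when 0 < pseudo_gt < max_disp.
    """
    gt_mask = [0 < g < max_disp for g in gt]
    pseudo_mask = [(0 < p < max_disp) and not m for p, m in zip(pseudo_gt, gt_mask)]
    return gt_mask, pseudo_mask


def masked_l1(pred, target, mask):
    """Mean absolute error over pixels where mask is True (0.0 if none)."""
    errs = [abs(p - t) for p, t, m in zip(pred, target, mask) if m]
    return sum(errs) / len(errs) if errs else 0.0


def pseudo_supervised_loss(pred, gt, pseudo_gt, max_disp, pseudo_weight=1.0):
    """Hypothetical total loss: L1 on real GT plus weighted L1 on pseudo GT."""
    gt_mask, pseudo_mask = supervision_masks(gt, pseudo_gt, max_disp)
    return (masked_l1(pred, gt, gt_mask)
            + pseudo_weight * masked_l1(pred, pseudo_gt, pseudo_mask))
```

Keeping the two masks disjoint means the real ground truth always takes precedence wherever it exists.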

It is recommended to symlink your dataset root to $AANET/data:

ln -s $YOUR_DATASET_ROOT data

Otherwise, you may need to change the corresponding paths in the scripts.

Model Zoo

All pretrained models are available in the model zoo.

We assume the downloaded weights are located under the pretrained directory.

Otherwise, you may need to change the corresponding paths in the scripts.

Inference

To generate predictions on the test sets of the Scene Flow and KITTI datasets, run scripts/aanet_inference.sh.

The inference results on KITTI can be submitted directly to the online evaluation server for benchmarking.
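The KITTI stereo devkit stores disparity maps as 16-bit PNGs holding disparity × 256, with 0 meaning "invalid". The inference script is expected to handle this already; the sketch below only shows the encoding convention so you can sanity-check or reproduce submission files yourself. The function names are hypothetical, and writing the values to an actual 16-bit PNG (with any image library that supports uint16) is left out.

```python
KITTI_DISP_SCALE = 256.0  # KITTI stores disparity * 256 in 16-bit PNGs
UINT16_MAX = 65535


def encode_kitti_disparity(disp):
    """Map float disparities (flat list) to KITTI's uint16 encoding.

    Non-positive values are treated as invalid and encoded as 0; large values
    are clipped to the uint16 range.
    """
    out = []
    for d in disp:
        if d <= 0:
            out.append(0)
        else:
            out.append(min(int(round(d * KITTI_DISP_SCALE)), UINT16_MAX))
    return out


def decode_kitti_disparity(values):
    """Inverse mapping: uint16 values back to float disparities (0 = invalid)."""
    return [v / KITTI_DISP_SCALE for v in values]
```

For example, a disparity of 1.5 px is stored as 384.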

Prediction

We also support prediction on arbitrary rectified stereo pairs; scripts/aanet_predict.sh shows an example usage.
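Arbitrary image sizes usually have to be padded to a multiple of the network's total downsampling stride before inference and cropped back afterwards. The sketch below follows a common convention for stereo networks (pad the top and right, crop the prediction from the bottom-left); the factor of 48 is an assumption, so check the repo's scripts for the value and padding scheme it actually uses.

```python
def pad_amounts(height, width, factor=48):
    """Return (top_pad, right_pad) making height and width multiples of factor."""
    return (-height) % factor, (-width) % factor


def pad_image(img, factor=48):
    """Zero-pad a 2-D image (list of rows) on the top and right."""
    h, w = len(img), len(img[0])
    top, right = pad_amounts(h, w, factor)
    padded_rows = [row + [0] * right for row in img]
    return [[0] * (w + right) for _ in range(top)] + padded_rows


def crop_prediction(pred, height, width, factor=48):
    """Undo pad_image: drop the top padding rows and the right padding columns."""
    top, _ = pad_amounts(height, width, factor)
    return [row[:width] for row in pred[top:top + height]]
```

For a typical 375 × 1242 KITTI image, `pad_amounts(375, 1242)` gives `(9, 6)`, i.e. a 384 × 1248 input. Padding only the top and right keeps the original pixels at fixed indices, so cropping needs just the original size.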

Training

All training scripts for the Scene Flow and KITTI datasets are provided in scripts/aanet_train.sh.

Note that we train on 4 NVIDIA V100 GPUs (32 GB) with a batch size of 64; you may need to adjust the batch size to fit your hardware.
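Assuming the global batch is split evenly across GPUs (as in data-parallel training), the reference setup means 64 / 4 = 16 samples per GPU. If your GPUs hold fewer samples, gradient accumulation is one general way to approach the same effective batch size; it is not necessarily supported by the repo's training scripts, and the helpers below are hypothetical.

```python
import math

REFERENCE_BATCH = 64  # global batch size reported above (4 x V100)
REFERENCE_GPUS = 4


def per_gpu_batch(global_batch=REFERENCE_BATCH, num_gpus=REFERENCE_GPUS):
    """Samples each GPU processes per step when the batch is split evenly."""
    if global_batch % num_gpus:
        raise ValueError("global batch must be divisible by the number of GPUs")
    return global_batch // num_gpus


def accumulation_steps(num_gpus, max_per_gpu, target_batch=REFERENCE_BATCH):
    """Return (steps, effective_batch) needed to reach at least target_batch
    when each GPU can hold at most max_per_gpu samples per step."""
    per_step = num_gpus * max_per_gpu
    steps = math.ceil(target_batch / per_step)
    return steps, steps * per_step
```

For example, `accumulation_steps(num_gpus=2, max_per_gpu=12)` returns `(3, 72)`, overshooting the target because 64 is not a multiple of 24. Note that accumulation matches the gradient batch size but not the batch-normalization statistics, which are still computed over the smaller per-GPU batch.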

We support using TensorBoard to monitor and visualize the training process. First start a TensorBoard session with

tensorboard --logdir checkpoints

and then access http://localhost:6006 in your browser.

Evaluation

To enable fast experimentation, evaluation runs on the fly without saving intermediate results.

We provide two types of evaluation settings:

Check scripts/aanet_evaluate.sh for an example usage.
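Stereo evaluation commonly reports the end-point error (EPE, mean absolute disparity error) and the KITTI D1 outlier rate, where a pixel counts as an outlier if its error exceeds both 3 px and 5% of the ground-truth disparity. The sketch below computes both over valid pixels, assuming 0 marks missing ground truth and a maximum disparity of 192; `stereo_metrics` is hypothetical, and the repo's evaluation code may differ in details such as averaging per image versus over the whole dataset.

```python
def stereo_metrics(pred, gt, max_disp=192):
    """End-point error (EPE) and KITTI-style D1 error over valid pixels.

    pred, gt: flat lists of disparities; pixels with gt <= 0 or
    gt >= max_disp are ignored.
    Returns (epe, d1) with d1 as a fraction, or (nan, nan) if no pixel is valid.
    """
    errors, outliers = [], 0
    for p, g in zip(pred, gt):
        if not (0 < g < max_disp):
            continue
        err = abs(p - g)
        errors.append(err)
        # D1 outlier: error larger than 3 px AND larger than 5% of the GT.
        if err > 3.0 and err > 0.05 * g:
            outliers += 1
    if not errors:
        return float("nan"), float("nan")
    return sum(errors) / len(errors), outliers / len(errors)
```

The relative 5% term makes D1 more tolerant at large disparities (nearby objects): a 4 px error at a ground truth of 100 px is not an outlier, while a 3.5 px error at 50 px is.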

Citation

If you find our work useful in your research, please consider citing our paper:

@inproceedings{xu2020aanet,
  title={AANet: Adaptive Aggregation Network for Efficient Stereo Matching},
  author={Xu, Haofei and Zhang, Juyong},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={1959--1968},
  year={2020}
}

Acknowledgements

Part of the code is adapted from previous works: PSMNet, GwcNet and GA-Net. We thank the original authors for their awesome repos. The deformable convolution op is taken from mmdetection. The FLOPs counting code is modified from pytorch-OpCounter. The code structure is partially inspired by mmdetection and our previous work rdn4depth.