sandipan211 / ZSD-SC-Resolver

Resolving semantic confusions for improved zero-shot detection (BMVC 2022)
MIT License

Resolving Semantic Confusions for Improved Zero-Shot Detection (BMVC Oral Presentation, 2022)

:eyeglasses: At a glance

This repository contains the official PyTorch implementation of our BMVC 2022 paper: Resolving Semantic Confusions for Improved Zero-Shot Detection, a work done by Sandipan Sarma, Sushil Kumar and Arijit Sur at the Indian Institute of Technology Guwahati.

Figure (method overview): The primary novelty of our model lies in the incorporation of a triplet loss based on visual features, assisted by a cyclic-consistency loss.

:newspaper: News

:video_camera: Video

This paper was presented at the BMVC Orals, 2022.

:bullettrain_side: Training the model

1. :office: Creating the work environment

Our code is based on PyTorch and was developed on an NVIDIA DGX Station (V100, 32 GB). It uses mmdetection, which includes a Faster-RCNN implementation, as the base framework for object detection. Install Anaconda/Miniconda on your system and create a conda environment using the following command:

conda env create -f zsd_environment.yml

Once set up, activate the environment and do the following:

cd ./mmdetection/

# install mmdetection and bind it to your project
python setup.py develop

The following commands are shown for the MSCOCO dataset. For the PASCAL-VOC dataset, change the command-line arguments accordingly and run the corresponding scripts.

2. :hourglass_flowing_sand: Train Faster-RCNN detector on seen data

All settings for the training and testing pipelines are stored in a single configuration file. To change them, edit the file at:

./mmdetection/configs/faster_rcnn_r101_fpn_1x.py

In zero-shot detection, the object categories in a dataset are split into two sets - seen and unseen. Such sets are defined in previous works for both MSCOCO [1] and PASCAL-VOC [2] datasets. The splits can be found in splits.py.
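As a minimal sketch of what a seen/unseen split amounts to, the snippet below partitions a list of class names into two disjoint sets. The function name `split_classes` and the toy class list are hypothetical and only illustrate the idea; the actual splits used in this work are the ones defined in splits.py.

```python
from typing import Dict, List, Sequence


def split_classes(all_classes: Sequence[str], unseen: Sequence[str]) -> Dict[str, List[str]]:
    """Partition class names into disjoint seen and unseen lists, preserving order."""
    unseen_set = set(unseen)
    unknown = unseen_set - set(all_classes)
    if unknown:
        raise ValueError(f"unseen classes not in dataset: {sorted(unknown)}")
    return {
        "seen": [c for c in all_classes if c not in unseen_set],
        "unseen": [c for c in all_classes if c in unseen_set],
    }


# Toy class list for illustration only; NOT the actual MSCOCO or PASCAL-VOC split.
toy_classes = ["person", "bicycle", "car", "dog", "cat", "train"]
toy_split = split_classes(toy_classes, unseen=["cat", "train"])
print(toy_split["seen"])    # ['person', 'bicycle', 'car', 'dog']
print(toy_split["unseen"])  # ['cat', 'train']
```

The detector is trained only on images annotated with the seen classes, while the unseen classes are held out for evaluation.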

To train the Faster-RCNN on seen data, run:

cd ./mmdetection
./tools/dist_train.sh configs/faster_rcnn_r101_fpn_1x.py 1 --validate 

For reproducibility, we recommend using the pre-trained models provided in this repository. Create a directory named work_dirs inside the mmdetection folder, with separate subdirectories for MSCOCO and PASCAL-VOC, and store the weights of the trained Faster-RCNN in them. Our pre-trained models are named epoch_12.pth and epoch_4.pth, obtained after training Faster-RCNN on the seen data of MSCOCO and PASCAL-VOC respectively.

The pre-trained Faster-RCNN weights use a ResNet-101 backbone CNN that was pre-trained on ImageNet only after removing the overlapping classes [3]. This pre-trained ResNet is given here, and the Faster-RCNN weights are uploaded for both PASCAL-VOC and MSCOCO.
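The directory layout described above could be prepared and checked with a small helper like the one below. This is only a sketch: the `coco2014` folder name is taken from the `--load_from` paths used in this README, whereas the `voc` folder name and both function names are assumptions made for illustration.

```python
from pathlib import Path
from typing import Dict


def expected_checkpoints(mmdet_root: Path) -> Dict[str, Path]:
    """Return the checkpoint path expected for each dataset."""
    work_dirs = mmdet_root / "work_dirs"
    return {
        # "coco2014" matches the --load_from paths used in this README.
        "coco": work_dirs / "coco2014" / "epoch_12.pth",
        # Assumption: "voc" is a placeholder directory name.
        "voc": work_dirs / "voc" / "epoch_4.pth",
    }


def prepare_work_dirs(mmdet_root: Path) -> Dict[str, bool]:
    """Create the per-dataset directories and report which checkpoints exist."""
    status = {}
    for dataset, ckpt in expected_checkpoints(mmdet_root).items():
        ckpt.parent.mkdir(parents=True, exist_ok=True)
        status[dataset] = ckpt.is_file()
    return status
```

Calling `prepare_work_dirs(Path("./mmdetection"))` would create the folders and return a dictionary of booleans showing which weights still need to be placed there.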

3. :outbox_tray: Extract object features

Inside the data folder, MSCOCO and PASCAL-VOC image datasets should be stored in appropriate formats, before running the following:

cd ./mmdetection
python tools/zero_shot_utils.py configs/faster_rcnn_r101_fpn_1x.py --classes seen --load_from ./work_dirs/coco2014/epoch_12.pth --save_dir ./data --data_split train
python tools/zero_shot_utils.py configs/faster_rcnn_r101_fpn_1x.py --classes unseen --load_from ./work_dirs/coco2014/epoch_12.pth --save_dir ./data --data_split test

4. :left_right_arrow: Training a visual-semantic mapper

Train a visual-semantic mapper on the seen data to learn a function that maps the visual space to the semantic space. The trained mapper is used in the next step to compute the cyclic-consistency loss, which improves the feature-synthesis quality of the GAN. Run:

python train_regressor.py 

Weights will be saved in the appropriate paths. For VOC, run train_regressor_voc.py instead.
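To make the idea of a visual-semantic mapper concrete, here is a minimal sketch that fits a linear map from a toy visual space to a toy semantic space by minimizing squared error. It is written in plain Python for readability and is not the code in train_regressor.py: the linear form, the function names, the hyperparameters and the toy data are all assumptions, and the actual mapper may use a different architecture and loss.

```python
import random
from typing import List

Vector = List[float]
Matrix = List[Vector]


def apply_mapper(weights: Matrix, visual: Vector) -> Vector:
    """Linear map from visual space to semantic space: s_hat = W v."""
    return [sum(w * v for w, v in zip(row, visual)) for row in weights]


def mse(pred: Vector, target: Vector) -> float:
    """Mean squared error between two vectors of equal length."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


def train_mapper(features: List[Vector], semantics: List[Vector],
                 lr: float = 0.05, epochs: int = 200, seed: int = 0) -> Matrix:
    """Fit W with per-sample gradient descent on the squared error."""
    rng = random.Random(seed)
    vis_dim, sem_dim = len(features[0]), len(semantics[0])
    weights = [[rng.uniform(-0.1, 0.1) for _ in range(vis_dim)] for _ in range(sem_dim)]
    for _ in range(epochs):
        for visual, target in zip(features, semantics):
            pred = apply_mapper(weights, visual)
            for i in range(sem_dim):
                err = pred[i] - target[i]  # constant gradient factor folded into lr
                for j in range(vis_dim):
                    weights[i][j] -= lr * err * visual[j]
    return weights


# Toy data (hypothetical): 3-d "visual features" paired with 2-d "class semantics".
toy_features = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 0.0]]
toy_semantics = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]
mapper_weights = train_mapper(toy_features, toy_semantics)
final_error = sum(
    mse(apply_mapper(mapper_weights, f), s) for f, s in zip(toy_features, toy_semantics)
) / len(toy_features)
print(f"mean squared error after training: {final_error:.6f}")
```

Once trained on seen-class features, such a mapper is kept fixed and reused to check whether synthesized features map back to the semantics they were generated from.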

5. :factory: Train the generative model using extracted features

The extracted seen-class object features constitute the real data distribution on which a Conditional Wasserstein GAN is trained, with the class semantics of seen/unseen classes acting as the conditional variables. During GAN training, a triplet loss is computed on the synthesized object features, enforcing inter-class dissimilarity learning. Moreover, a cyclic-consistency loss between the synthesized features and their class semantics is computed, encouraging the GAN to generate visual features that correspond well to their own semantics. For training the GAN, run the script:

./script/train_coco_generator_65_15.sh
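The two auxiliary terms described in this step can be sketched as follows. This is a simplified illustration in plain Python, not the repository's implementation: the Euclidean distance, the fixed margin, the function names and the toy mapper are assumptions, and the way these terms are weighted against the Wasserstein GAN objective is not shown.

```python
import math
from typing import Callable, List

Vector = List[float]


def euclidean(a: Vector, b: Vector) -> float:
    """Euclidean distance between two feature vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def triplet_loss(anchor: Vector, positive: Vector, negative: Vector,
                 margin: float = 1.0) -> float:
    """Hinge-style triplet loss: max(0, d(a, p) - d(a, n) + margin)."""
    return max(0.0, euclidean(anchor, positive) - euclidean(anchor, negative) + margin)


def cycle_consistency_loss(synth_feature: Vector, class_semantics: Vector,
                           mapper: Callable[[Vector], Vector]) -> float:
    """Squared error between mapper(synthesized feature) and its class semantics."""
    recon = mapper(synth_feature)
    return sum((r - s) ** 2 for r, s in zip(recon, class_semantics)) / len(class_semantics)


# Anchor and positive share a class; the negative comes from a different class.
anchor, positive = [0.0, 0.0], [0.0, 1.0]
print(triplet_loss(anchor, positive, negative=[3.0, 4.0]))  # 0.0 (negative is far enough)
print(triplet_loss(anchor, positive, negative=[1.0, 0.0]))  # 1.0 (negative is too close)

# Hypothetical frozen mapper from a 2-d visual space to a 1-d semantic space.
toy_mapper = lambda f: [f[0] + f[1]]
print(cycle_consistency_loss([0.5, 0.5], [1.0], toy_mapper))  # 0.0
print(cycle_consistency_loss([1.0, 1.0], [1.0], toy_mapper))  # 1.0
```

The triplet term is zero once features of different classes are separated by at least the margin, and the cycle term is zero when a synthesized feature maps back exactly to the semantics it was conditioned on.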

6. :mag: Evaluation

cd mmdetection

#evaluation on zsd
./tools/dist_test.sh configs/faster_rcnn_r101_fpn_1x.py ./work_dirs/coco2014/epoch_12.pth 1 --dataset coco --out /workspace/arijit_ug/sushil/zsd/checkpoints/ab_st_final/coco_65_15_wgan_modeSeek_seen_cycSeenUnseen_tripletSeenUnseen_varMargin_try6/coco_65_15_wgan_modeSeek_seen_cycSeenUnseen_tripletSeenUnseen_varMargin_try6_zsd_result.pkl --zsd --syn_weights /workspace/arijit_ug/sushil/zsd/checkpoints/ab_st_final/coco_65_15_wgan_modeSeek_seen_cycSeenUnseen_tripletSeenUnseen_varMargin_try6/classifier_best_latest.pth

NOTE: Change the --zsd flag to --gzsd for evaluation in the generalized ZSD (GZSD) setting, and change the directory names accordingly. The classifier weights required in the evaluation step are given for VOC and MSCOCO.

:sweat_smile: Hard-coded argument: For GZSD evaluation, change the default 21 (for VOC) in this line to 81 if you want to test with MSCOCO.

For inference on a custom image, first put it inside the folder custom data. I have kept a few as examples. Obtain the model results using:

cd mmdetection
sh test_zsd_single_img.sh

It will follow Generalized ZSD for inference. Finally, to visualize the bounding boxes for detection, run:

python show_results_single_img.py

7. :trophy: Results

Log files are also uploaded for ZSD and GZSD.

:gift: Citation

If you use our work in your research, kindly star :star: our repository and consider citing our work using the following BibTeX:

@inproceedings{Sarma_2022_BMVC,
author    = {Sandipan Sarma and Sushil Kumar and Arijit Sur},
title     = {Resolving Semantic Confusions for Improved Zero-Shot Detection},
booktitle = {33rd British Machine Vision Conference 2022, {BMVC} 2022, London, UK, November 21-24, 2022},
publisher = {{BMVA} Press},
year      = {2022},
url       = {https://bmvc2022.mpi-inf.mpg.de/0347.pdf}
}

:scroll: References

[1] Shafin Rahman, Salman Khan, and Nick Barnes. Polarity loss for zero-shot object detection. arXiv preprint arXiv:1811.08982, 2018.

[2] Berkan Demirel, Ramazan Gokberk Cinbis, and Nazli Ikizler-Cinbis. Zero-shot object detection by hybrid region embedding. In BMVC, 2018.

[3] Yongqin Xian, Christoph H Lampert, Bernt Schiele, and Zeynep Akata. Zero-shot learning—a comprehensive evaluation of the good, the bad and the ugly. IEEE transactions on pattern analysis and machine intelligence, 41(9):2251–2265, 2018.