TKKim93 / DivMatch

Diversify and Match: A Domain Adaptive Representation Learning Paradigm for Object Detection
http://openaccess.thecvf.com/content_CVPR_2019/papers/Kim_Diversify_and_Match_A_Domain_Adaptive_Representation_Learning_Paradigm_for_CVPR_2019_paper.pdf
MIT License

Diversify and Match

Acknowledgment

The implementation is built on the pytorch implementation of Faster RCNN jwyang/faster-rcnn.pytorch

Preparation

  1. Clone the code and create a folder

    git clone https://github.com/TKKim93/DivMatch.git
    cd DivMatch && mkdir data
  2. Build the Cython modules

    cd DivMatch/lib
    sh make.sh

Prerequisites

Pretrained Model

You can download the pretrained VGG16 and ResNet101 models from jwyang's repository. The default location in my code is './data/pretrained_model/'.
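As a convenience, here is a minimal sketch that resolves and checks the default weight location described above. The specific filenames (`vgg16_caffe.pth`, `resnet101_caffe.pth`) are an assumption based on what jwyang's repository distributes; adjust them if your download differs.

```python
from pathlib import Path

# Default pretrained-model location used by this code.
PRETRAINED_DIR = Path("data/pretrained_model")

# ASSUMPTION: filenames as distributed by jwyang/faster-rcnn.pytorch.
EXPECTED = {"vgg16": "vgg16_caffe.pth", "res101": "resnet101_caffe.pth"}

def pretrained_path(net: str) -> Path:
    """Return the path where the training code expects the backbone weights."""
    return PRETRAINED_DIR / EXPECTED[net]

def missing_weights() -> list:
    """List backbones whose weight files have not been downloaded yet."""
    return [net for net in EXPECTED if not pretrained_path(net).exists()]
```

Running `missing_weights()` before training gives a quick hint about which downloads are still missing.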

Repository Structure

DivMatch
├── cfgs
├── data
│   ├── pretrained_model
├── datasets
│   ├── clipart
│   │   ├── Annotations
│   │   ├── ImageSets
│   │   ├── JPEGImages
│   ├── clipart_CP
│   ├── clipart_CPR
│   ├── clipart_R
│   ├── comic
│   ├── comic_CP
│   ├── comic_CPR
│   ├── comic_R
│   ├── Pascal
│   ├── watercolor_CP
│   ├── watercolor_CPR
│   ├── watercolor_R
├── lib
├── models (save location)

Example

Diversification stage

Here are the simplest ways to generate shifted domains via CycleGAN. Some of them perform unnecessary computations, so you may want to revise the image-to-image translation (I2I) code for faster training.

  1. CP shift

Change line 177 in models/cycle_gan_model.py to

loss_G = self.loss_G_A + self.loss_G_B + self.loss_idt_A + self.loss_idt_B

  2. R shift

Change line 177 in models/cycle_gan_model.py to

loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B

  3. CPR shift

Use the original CycleGAN model.
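The three shifts above differ only in which CycleGAN loss terms are kept. A minimal sketch of the distinction (the loss-term names follow models/cycle_gan_model.py; the per-term weighting applied inside CycleGAN is omitted here for brevity):

```python
# loss_G_A/B:     adversarial losses for the two generators
# loss_cycle_A/B: cycle-consistency losses
# loss_idt_A/B:   identity-mapping losses
def generator_loss(shift, loss_G_A, loss_G_B,
                   loss_cycle_A, loss_cycle_B,
                   loss_idt_A, loss_idt_B):
    adv = loss_G_A + loss_G_B
    if shift == "CP":   # drop cycle-consistency, keep identity
        return adv + loss_idt_A + loss_idt_B
    if shift == "R":    # drop identity, keep cycle-consistency
        return adv + loss_cycle_A + loss_cycle_B
    if shift == "CPR":  # original CycleGAN objective: keep everything
        return adv + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
    raise ValueError(f"unknown shift: {shift}")
```

In other words, CP removes the cycle-consistency constraint, R removes the identity constraint, and CPR keeps the full objective.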

Matching stage

Here is an example of adapting from Pascal VOC to Clipart1k:

  1. Prepare the Pascal VOC datasets from py-faster-rcnn and the Clipart1k dataset from cross-domain-detection, both in VOC data format.
  2. Shift the source domain through the domain shifter. Basically, I used a residual generator and a PatchGAN discriminator. As a shortcut, you can download some examples of shifted domains (Link) and put these datasets into the data folder.
  3. Train the object detector through MRL for the Pascal -> Clipart1k adaptation task:
    python train.py --dataset clipart --net vgg16 --cuda
  4. Test the model:
    python test.py --dataset clipart --net vgg16 --cuda
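Every dataset folder in the repository structure shown earlier follows the VOC layout (Annotations, ImageSets, JPEGImages). A quick sanity check like the following can catch misplaced downloads before training; `is_voc_layout` is a hypothetical helper, not part of this repo:

```python
import os

# The three subfolders the VOC data format requires, as shown in the
# repository-structure tree (e.g. datasets/clipart/...).
VOC_SUBDIRS = ("Annotations", "ImageSets", "JPEGImages")

def is_voc_layout(dataset_dir):
    """True if dataset_dir contains the three standard VOC subfolders."""
    return all(os.path.isdir(os.path.join(dataset_dir, d))
               for d in VOC_SUBDIRS)
```

For example, `is_voc_layout("datasets/clipart")` should return True once the Clipart1k download is unpacked in place.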

Downloads