icoz69 / DeepEMD

Code for the paper "DeepEMD: Few-Shot Image Classification with Differentiable Earth Mover's Distance and Structured Classifiers", CVPR 2020
MIT License

DeepEMD: Differentiable Earth Mover's Distance for Few-Shot Learning

PyTorch implementation of

"DeepEMD: Few-Shot Image Classification with Differentiable Earth Mover's Distance and Structured Classifiers" (CVPR 2020, oral)

and

"DeepEMD v2: Differentiable Earth Mover's Distance for Few-Shot Learning" (TPAMI Extension).

DeepEMD achieves new state-of-the-art performance on five few-shot learning benchmarks, with margins of up to 7%. These results are obtained without using any extra data for training or testing (i.e., not in the transductive setting).

Check the few-shot classification leaderboard.

If you use the code in this repo for your work, please cite the following bib entries:

@InProceedings{Zhang_2020_CVPR,
author = {Zhang, Chi and Cai, Yujun and Lin, Guosheng and Shen, Chunhua},
title = {DeepEMD: Few-Shot Image Classification With Differentiable Earth Mover's Distance and Structured Classifiers},
booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2020}
}

and

@misc{zhang2020deepemdv2,
    title={DeepEMD: Differentiable Earth Mover's Distance for Few-Shot Learning},
    author={Chi Zhang and Yujun Cai and Guosheng Lin and Chunhua Shen},
    year={2020},
    eprint={2003.06777v3},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

If you have any questions regarding the paper, please send an email to chi007[at]e[dot]ntu[dot]edu[dot]sg.

Abstract

Deep learning has proved to be very effective in learning with a large amount of labelled data. Few-shot learning, in contrast, attempts to learn from only a few labelled examples. In this work, we develop methods for few-shot image classification from a new perspective of optimal matching between image regions. We employ the Earth Mover's Distance (EMD) as a metric to compute a structural distance between dense image representations to determine image relevance. The EMD generates the optimal matching flows between structural elements that have the minimum matching cost, which is used to calculate the image distance for classification. To generate the importance weights of elements in the EMD formulation, we design a cross-reference mechanism, which can effectively alleviate the adverse impact caused by cluttered backgrounds and large intra-class appearance variations. To handle k-shot classification, we propose to learn a structured fully connected layer that can directly classify dense image representations with the proposed EMD. Based on the implicit function theorem, the EMD can be inserted as a layer into the network for end-to-end training. Our extensive experiments validate the effectiveness of our algorithm, which outperforms state-of-the-art methods by a significant margin on four widely used few-shot classification benchmarks, namely miniImageNet, tieredImageNet, Fewshot-CIFAR100 (FC100) and Caltech-UCSD Birds-200-2011 (CUB).
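The two ingredients the abstract names, per-element weights from cross-referencing and a matching cost between local features, can be sketched in NumPy as below. This is one reading of the mechanism, not code from this repository. It assumes each image is a set of local feature vectors (an [n, c] array), that an element's weight is the ReLU of its dot product with the other image's mean feature rescaled so the weights sum to n, and that the cost is 1 minus cosine similarity. The function names are hypothetical.

```python
import numpy as np

def cross_reference_weights(u: np.ndarray, v: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Hypothetical helper: weights for the local features u ([n, c]) of one image,
    computed against the mean of the other image's features v ([m, c])."""
    v_mean = v.mean(axis=0)                  # average local feature of the other image
    s = np.maximum(u @ v_mean, 0.0)          # ReLU: elements unrelated to the other image get 0
    return s * len(u) / (s.sum() + eps)      # assumed normalisation: weights sum to n

def cosine_cost(u: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Hypothetical helper: cost matrix c_ij = 1 - cos(u_i, v_j)."""
    u_n = u / np.linalg.norm(u, axis=1, keepdims=True)
    v_n = v / np.linalg.norm(v, axis=1, keepdims=True)
    return 1.0 - u_n @ v_n.T

rng = np.random.default_rng(0)
query = rng.random((25, 64))     # e.g. a 5x5 feature map flattened to 25 vectors of dim 64
support = rng.random((25, 64))
w_q = cross_reference_weights(query, support)   # weights of query elements
w_s = cross_reference_weights(support, query)   # weights of support elements
cost = cosine_cost(query, support)              # 25 x 25 matching cost
print(w_q.sum(), cost.shape)
```

Feeding `cost`, `w_q` and `w_s` to an LP solver gives the optimal flows; the total matched cost is the image distance.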

Few-Shot Classification Results

Experimental results on few-shot learning datasets with a ResNet-12 backbone. We report average accuracy (%) over 5,000 randomly sampled episodes for 1-shot evaluation and 600 episodes for k-shot evaluation.
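Averaging per-episode accuracies is straightforward. Few-shot papers commonly also report a 95% confidence interval, which this README does not show, so treat that part as an assumption. A minimal sketch with a hypothetical helper name:

```python
import numpy as np

def summarize_episodes(accuracies):
    """Hypothetical helper: mean accuracy (%) and 95% confidence half-width over episodes."""
    acc = np.asarray(accuracies, dtype=float)
    mean = acc.mean()
    # 1.96 * standard error: normal approximation, reasonable for hundreds of episodes
    ci95 = 1.96 * acc.std(ddof=1) / np.sqrt(len(acc))
    return mean, ci95

mean, ci = summarize_episodes([60.0, 80.0])   # toy per-episode accuracies, not real results
print(f"{mean:.2f} +- {ci:.2f}")
```

With 5,000 episodes the interval is roughly three times narrower than with 600, which is one reason to use more episodes for the noisier 1-shot setting.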

MiniImageNet Dataset

Setup | 1-shot 5-way | 5-shot 5-way
Previous SOTA | 64.12 | 80.51
DeepEMD-FCN | 66.50 | 82.41
DeepEMD-Grid | 67.83 | 83.14
DeepEMD-Sampling | 68.77 | 84.13

TieredImageNet Dataset

Setup | 1-shot 5-way | 5-shot 5-way
Previous SOTA | 68.50 | 84.28
DeepEMD-FCN | 72.65 | 86.03
DeepEMD-Grid | 73.13 | 87.08
DeepEMD-Sampling | 74.29 | 86.98

Prerequisites

The following packages are required to run the scripts:

MiniImageNet Dataset

MiniImageNet is built upon the ImageNet dataset and contains 100 classes with 600 images in each class. The 100 classes are divided into 64, 16 and 20 for meta-training, meta-validation and meta-testing, respectively.
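Evaluation and meta-training operate on episodes drawn from one split. A minimal sketch of N-way K-shot episode sampling, assuming each split is given as a list of per-image class labels and that 15 query images per class are used (a common choice, not stated in this README); `sample_episode` is a hypothetical name, not this repository's sampler:

```python
import random
from collections import defaultdict

def sample_episode(labels, way=5, shot=1, query=15, rng=random):
    """Hypothetical sampler: return (support, queries) as lists of (image_index, episode_label)."""
    by_class = defaultdict(list)
    for idx, c in enumerate(labels):
        by_class[c].append(idx)
    classes = rng.sample(sorted(by_class), way)          # pick N classes for this episode
    support, queries = [], []
    for ep_label, c in enumerate(classes):
        picked = rng.sample(by_class[c], shot + query)   # distinct images per class
        support += [(i, ep_label) for i in picked[:shot]]
        queries += [(i, ep_label) for i in picked[shot:]]
    return support, queries

labels = [c for c in range(20) for _ in range(600)]       # e.g. a 20-class meta-test split, 600 images each
support, queries = sample_episode(labels, way=5, shot=1, rng=random.Random(0))
print(len(support), len(queries))
```

Episode labels are re-indexed to 0..N-1, so the classifier never sees the original class IDs.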

TieredImageNet Dataset

TieredImageNet is also a subset of ImageNet, containing 608 classes from 34 super-classes. Unlike miniImageNet, the meta-training (20), meta-validation (6) and meta-testing (8) splits are made at the super-class level, which enlarges the domain gap between the training and testing phases. The dataset also includes more images for training and evaluation (779,165 images in total).

CUB Dataset

CUB was originally proposed for fine-grained bird classification and contains 11,788 images from 200 classes. Following the splits in FEAT, the 200 classes are divided into 100, 50 and 50 for meta-training, meta-validation and meta-testing, respectively.

FC100 Dataset

FC100 is a few-shot classification dataset built on CIFAR100. We follow the split proposed in TADAM, where the 20 super-classes are divided into 12 (60 classes), 4 (20 classes) and 4 (20 classes) for meta-training, meta-validation and meta-testing, respectively. Each class contains 600 images.

CIFAR-FS dataset (not in paper)

CIFAR-FS, proposed by Bertinetto et al., is also built upon CIFAR100. It contains 64, 16 and 20 classes for meta-training, meta-validation and meta-testing, respectively.
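The splits described above can be collected into one table for reference. This sketch only restates the counts given in this README (tieredImageNet and FC100 are split by super-class); `SPLITS` is a hypothetical name, not a structure from the repository:

```python
# (meta-train, meta-val, meta-test) counts as stated in this README
SPLITS = {
    "miniImageNet":   {"unit": "classes",       "split": (64, 16, 20),  "total": 100},
    "tieredImageNet": {"unit": "super-classes", "split": (20, 6, 8),    "total": 34},
    "CUB":            {"unit": "classes",       "split": (100, 50, 50), "total": 200},
    "FC100":          {"unit": "super-classes", "split": (12, 4, 4),    "total": 20},
    "CIFAR-FS":       {"unit": "classes",       "split": (64, 16, 20),  "total": 100},
}

for name, info in SPLITS.items():
    # each split must partition the dataset's classes (or super-classes)
    assert sum(info["split"]) == info["total"], name

# CIFAR100 groups its 100 classes into 20 super-classes of 5 classes each
print([5 * n for n in SPLITS["FC100"]["split"]])
```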

Important Arguments

We list some important arguments of our networks.

Model Selection Arguments

Solver Selection Arguments

DeepEMD-FCN Related Arguments

DeepEMD-Grid Related Arguments

DeepEMD-Sampling Related Arguments

SFC Related Arguments

Some general tips

Getting started. Training has two stages: model pre-training (train_pretrain.py) and episodic meta-training (train_meta.py). You can also test trained models directly by running eval.py. Before running these scripts, set the dataset directory (-data_dir) and the pre-trained model directory (-pretrain_dir) in the arguments, or change the default directories in the code.

Solver selection. We provide two solvers for the linear programming (LP) problem in the EMD layer: OpenCV and QPTH. OpenCV is much faster than QPTH, so you can use OpenCV for validation and QPTH for training. You may also use OpenCV for training, which is much faster and uses less memory, but it omits the gradients through the constraints and compromises performance.
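The LP both solvers handle is the standard EMD transportation problem: minimise the total cost of the flows, subject to each element sending out exactly its weight and receiving exactly its weight. As a forward-only stand-in (neither the OpenCV nor the QPTH path used in this repository), it can be written with SciPy's `linprog`; the function name `emd` is hypothetical. Like the OpenCV option, this solve provides no gradients through the constraints.

```python
import numpy as np
from scipy.optimize import linprog

def emd(cost: np.ndarray, w1: np.ndarray, w2: np.ndarray):
    """Solve the EMD LP; returns (total cost, flow matrix). Assumes w1.sum() == w2.sum().

    minimise   sum_ij cost[i, j] * x[i, j]
    subject to sum_j x[i, j] = w1[i],  sum_i x[i, j] = w2[j],  x >= 0
    """
    n, m = cost.shape
    a_eq = np.zeros((n + m, n * m))            # x is flattened row-major: x[i, j] -> i*m + j
    for i in range(n):
        a_eq[i, i * m:(i + 1) * m] = 1.0       # mass leaving element i of the first image
    for j in range(m):
        a_eq[n + j, j::m] = 1.0                # mass arriving at element j of the second image
    b_eq = np.concatenate([w1, w2])
    res = linprog(cost.ravel(), A_eq=a_eq, b_eq=b_eq, bounds=(0, None), method="highs")
    return float(res.fun), res.x.reshape(n, m)

cost = np.array([[0.0, 1.0], [1.0, 0.0]])
distance, flow = emd(cost, np.array([0.5, 0.5]), np.array([1.0, 0.0]))
print(distance)
print(flow)
```

Here the second image puts all its mass on its first element, so half of the first image's mass must travel at cost 1, giving a distance of 0.5.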

About GPU memory. Many arguments affect GPU memory usage. You may adjust them to balance GPU memory against performance.

For DeepEMD-Sampling and DeepEMD-Grid, you may train with a small number of patches and test with a larger number of patches to improve performance.
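For DeepEMD-Grid, `-patch_list 2,3` denotes a two-level grid pyramid. A minimal sketch of what that implies, assuming the grids are cut from the input image (an 84x84 array here, also an assumption): 2x2 + 3x3 = 13 patches per image, so every extra pyramid level adds EMD elements and backbone passes, and hence memory. The helper name is hypothetical, and resizing each crop before the backbone is omitted.

```python
import numpy as np

def grid_pyramid_patches(image: np.ndarray, patch_list=(2, 3)):
    """Hypothetical helper: crop an H x W x C image into k x k grids for each k in patch_list."""
    h, w = image.shape[:2]
    patches = []
    for k in patch_list:
        ys = np.linspace(0, h, k + 1).astype(int)   # row boundaries of the k x k grid
        xs = np.linspace(0, w, k + 1).astype(int)   # column boundaries
        for r in range(k):
            for c in range(k):
                patches.append(image[ys[r]:ys[r + 1], xs[c]:xs[c + 1]])
    return patches

img = np.zeros((84, 84, 3))
patches = grid_pyramid_patches(img, patch_list=(2, 3))
print(len(patches), patches[0].shape, patches[4].shape)  # 13 (42, 42, 3) (28, 28, 3)
```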

Testing scripts for DeepEMD

Test DeepEMD-FCN with a trained model for 1-shot 5-way task on the miniImageNet dataset:

$ python eval.py  -deepemd fcn -gpu 0,1,2,3

Test DeepEMD-FCN with a trained model for 5-shot 5-way task on the miniImageNet dataset:

$ python eval.py  -deepemd fcn -shot 5 -test_episode 600 -gpu 0,1,2,3

Test DeepEMD-Grid-Pyramid (2,3) with a trained model for 1-shot 5-way task on the miniImageNet dataset:

$ python eval.py  -deepemd grid -patch_list 2,3  -gpu 0,1,2,3

Test DeepEMD-Sampling (9-patch) with a trained model for 1-shot 5-way task on the miniImageNet dataset:

$ python eval.py  -deepemd sampling -num_patch 9   -gpu 0,1,2,3

Training scripts for DeepEMD

Pre-train the model on the miniImageNet dataset:

$ python train_pretrain.py -dataset miniimagenet -gpu 0,1,2,3

Train DeepEMD-FCN with a pre-trained model for 1-shot 5-way task on the miniImageNet dataset:

#use opencv solver (about 8GB memory)
$ python train_meta.py -deepemd fcn -shot 1 -way 5 -solver opencv -gpu 0,1,2,3

#use QPTH solver (about 32GB memory)
$ python train_meta.py -deepemd fcn -shot 1 -way 5 -solver qpth -gpu 0,1,2,3

Train DeepEMD-Grid-Pyramid (2,3) with a pre-trained model for 1-shot 5-way task on the miniImageNet dataset:

#use opencv solver (about 45GB memory)
$ python train_meta.py  -deepemd grid -patch_list 2,3 -shot 1 -way 5 -solver opencv -gpu 0,1,2,3

Train DeepEMD-Sampling (9 patches) with a pre-trained model for 1-shot 5-way task on the miniImageNet dataset:

#use opencv solver (about 32GB memory)
$ python train_meta.py  -deepemd sampling -num_patch 9 -shot 1 -way 5 -solver opencv -gpu 0,1,2,3

Download Models

Pre-trained Models (or run bash download_pretrain_model.sh)

Meta-trained Models (or run bash download_trained_model.sh)

Acknowledgment

Our project references code from the following repositories.