AhmedAyad89 / Consitent-Prototypical-Networks-Semi-Supervised-Few-Shot-Learning


Prototypical Random Walk Networks (PRWN)

Code for the paper "Semi-Supervised Few-Shot Learning with Prototypical Random Walks" [Arxiv] [Poster] [Slides]

Overview

In this paper we design a semi-supervised loss to leverage unlabeled data in the few-shot setting. Our base model is a prototypical network, and we add the Prototypical Random Walk loss to leverage the unlabeled data during episodic meta-training. The loss is designed to train Prototypical Networks to produce embeddings in which the points of each class form a tight cluster around the class prototype. We find that PRWN outperforms the prior state of the art on all 9 benchmark tests we run, across a variety of semi-supervised few-shot learning settings. In some tests the improvement over the prior SOTA is large: for example, PRWN obtains 69.65% in one test compared to 64.59% for the prior SOTA. Remarkably, PRWN also outperforms the fully supervised prototypical network in one test, obtaining 50.89% compared to 49.4% for the baseline.
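To make the clustering objective concrete, here is a minimal pure-Python sketch of a prototypical random walk loss. It is not the repository's implementation: it assumes prototypes are class means of the labeled support embeddings (as in standard Prototypical Networks), similarity is negative squared Euclidean distance, and the walk is a single round trip from a prototype to an unlabeled point and back to a prototype. The paper's full loss may include extra walk steps and additional terms that are omitted here, and all function names are hypothetical.

```python
import math


def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def prototypes(support, labels, num_classes):
    """Class prototypes as the mean of each class's labeled support embeddings.

    Assumes every class has at least one labeled support point.
    """
    dim = len(support[0])
    sums = [[0.0] * dim for _ in range(num_classes)]
    counts = [0] * num_classes
    for x, y in zip(support, labels):
        counts[y] += 1
        for d in range(dim):
            sums[y][d] += x[d]
    return [[s / counts[k] for s in sums[k]] for k in range(num_classes)]


def softmax(row):
    """Numerically stable softmax over a list of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def random_walk_loss(protos, unlabeled):
    """Mean negative log-probability that a walk returns to its start prototype.

    Sketch assumption: similarity = -squared distance, one round trip
    prototype -> unlabeled point -> prototype.
    """
    sim = [[-squared_distance(p, u) for u in unlabeled] for p in protos]
    # Step 1: each prototype moves to an unlabeled point (softmax over points).
    p_to_u = [softmax(row) for row in sim]
    # Step 2: each unlabeled point moves to a prototype (softmax over prototypes).
    u_to_p = [softmax([sim[k][m] for k in range(len(protos))])
              for m in range(len(unlabeled))]
    loss = 0.0
    for k in range(len(protos)):
        # Probability that a walk starting at prototype k lands back on k.
        back = sum(p_to_u[k][m] * u_to_p[m][k] for m in range(len(unlabeled)))
        loss -= math.log(back)
    return loss / len(protos)
```

Unlabeled points that sit close to a single prototype make the return trip almost certain, so the loss approaches zero; points halfway between prototypes send walks to the wrong class and raise the loss. Minimizing it therefore pushes embeddings toward tight per-class clusters.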

Dependencies

Our code is tested on Ubuntu 16.04.

Setup

First, designate a folder to be your data root:

export DATA_ROOT={DATA_ROOT}

Then, set up the datasets following the instructions in the subsections.

Omniglot

[Google Drive] (9.3 MB)

# Download and place "omniglot.tar.gz" in "$DATA_ROOT/omniglot".
mkdir -p $DATA_ROOT/omniglot
cd $DATA_ROOT/omniglot
mv ~/Downloads/omniglot.tar.gz .
tar -xzvf omniglot.tar.gz
rm -f omniglot.tar.gz

miniImageNet

[Google Drive] (1.1 GB)

# Download and place "mini-imagenet.tar.gz" in "$DATA_ROOT/mini-imagenet".
mkdir -p $DATA_ROOT/mini-imagenet
cd $DATA_ROOT/mini-imagenet
mv ~/Downloads/mini-imagenet.tar.gz .
tar -xzvf mini-imagenet.tar.gz
rm -f mini-imagenet.tar.gz
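Before training, it can help to confirm that both dataset folders exist under the data root. The sketch below assumes only the folder names used in the setup commands above; it checks that the directories are present, not what the archives extracted into them, since that layout is not documented here. The helper name is hypothetical.

```python
import os

# Dataset sub-folders created by the setup steps above (names taken from this README).
EXPECTED_DATASETS = ("omniglot", "mini-imagenet")


def missing_datasets(data_root):
    """Return the expected dataset folders that are not present under data_root."""
    return [name for name in EXPECTED_DATASETS
            if not os.path.isdir(os.path.join(data_root, name))]
```

For example, `missing_datasets(os.environ["DATA_ROOT"])` should return an empty list once both datasets are set up.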

Core Experiments

Please run the following scripts to reproduce the core experiments.

# First, place the data_root folder inside the provided code folder.

# To train a model.
python run_exp.py --data_root $DATA_ROOT             \
                  --dataset {DATASET}                \
                  --label_ratio {LABEL_RATIO}        \
                  --model {MODEL}                    \
                  --results {SAVE_CKPT_FOLDER}       \
                  [--disable_distractor]             \
                  [--nshot {NSHOT}]                  \
                  [--nclasses_train {NCLASSES_TRAIN}]

# To test a model.
python run_exp.py --data_root $DATA_ROOT             \
                  --dataset {DATASET}                \
                  --label_ratio {LABEL_RATIO}        \
                  --model {MODEL}                    \
                  --results {SAVE_CKPT_FOLDER}       \
                  --eval --pretrain {MODEL_ID}       \
                  [--num_unlabel {NUM_UNLABEL}]      \
                  [--num_test {NUM_TEST}]            \
                  [--disable_distractor]             \
                  [--use_test]
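The training flags above can also be driven from a script, for example to sweep several label ratios. The sketch below only assembles the `run_exp.py` command lines documented in this README; the helper names, the per-ratio results sub-folders, and the example ratios are assumptions for illustration. With `dry_run=True` it prints the commands instead of running them.

```python
import os
import subprocess


def train_command(data_root, dataset, label_ratio, model, results,
                  nshot=None, disable_distractor=False):
    """Assemble a run_exp.py training command from the flags documented above."""
    cmd = ["python", "run_exp.py",
           "--data_root", data_root,
           "--dataset", dataset,
           "--label_ratio", str(label_ratio),
           "--model", model,
           "--results", results]
    if nshot is not None:
        cmd += ["--nshot", str(nshot)]
    if disable_distractor:
        cmd.append("--disable_distractor")
    return cmd


def sweep_label_ratios(label_ratios, results_root, dry_run=True, **kwargs):
    """Build (and optionally run) one training command per label ratio."""
    commands = []
    for ratio in label_ratios:
        # Hypothetical per-ratio sub-folder, used here only to keep runs apart.
        results = os.path.join(results_root, f"label_ratio_{ratio}")
        cmd = train_command(label_ratio=ratio, results=results, **kwargs)
        commands.append(cmd)
        if dry_run:
            print(" ".join(cmd))
        else:
            subprocess.run(cmd, check=True)
    return commands
```

Passing the argument list to `subprocess.run` avoids shell quoting issues, and `check=True` stops the sweep at the first failed run.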

Simple Baselines for Few-Shot Classification

Please run the following script to reproduce a suite of baseline results.

python run_baseline_exp.py --data_root $DATA_ROOT    \
                           --dataset {DATASET}

Run SOTA PRWN models

To train or test the state-of-the-art PRWN and reproduce the results in the paper, set the hyperparameters as specified in the paper and use the basic-RW model.

For example, to train a PRWN on 5-shot miniImageNet:

python run_exp.py --data_root $DATA_ROOT            \
                  --dataset mini-imagenet           \
                  --label_ratio 0.4                 \
                  --model basic-RW                  \
                  --nshot 5                         \
                  --num_unlabel 10                  \
                  [--disable_distractor]

To test:

python run_exp.py --data_root $DATA_ROOT            \
                  --dataset mini-imagenet           \
                  --model basic-RW                  \
                  --results {SAVE_CKPT_FOLDER}      \
                  --eval --pretrain {MODEL_ID}      \
                  [--num_unlabel {NUM_UNLABEL}]     \
                  [--num_test {NUM_TEST}]           \
                  [--disable_distractor]            \
                  [--use_test]

To test PRWN with semi-supervised inference:

python run_exp.py --data_root $DATA_ROOT             \
                  --dataset mini-imagenet            \
                  --model kmeans-refine              \
                  --results {SAVE_CKPT_FOLDER}       \
                  --eval --pretrain {MODEL_ID}       \
                  [--num_unlabel {NUM_UNLABEL}]      \
                  [--num_test {NUM_TEST}]            \
                  [--disable_distractor]             \
                  [--use_test]

To test PRWN with semi-supervised inference and distractor filtering:

python run_exp.py --data_root $DATA_ROOT             \
                  --dataset mini-imagenet            \
                  --model kmeans-filter              \
                  --results {SAVE_CKPT_FOLDER}       \
                  --eval --pretrain {MODEL_ID}       \
                  [--num_unlabel {NUM_UNLABEL}]      \
                  [--num_test {NUM_TEST}]            \
                  [--disable_distractor]             \
                  [--use_test]
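The three test commands above differ only in `--model`. Assuming, as those commands suggest, that all three inference variants can load the same trained checkpoint via `--pretrain`, a sketch like the following builds and runs them in one go. The helper names and the variant labels are hypothetical; the flags are the ones documented above.

```python
import subprocess

# Inference variants from this README; each evaluates the same pretrained checkpoint.
VARIANTS = {
    "PRWN": "basic-RW",
    "PRWN + semi-supervised inference": "kmeans-refine",
    "PRWN + semi-supervised inference + distractor filtering": "kmeans-filter",
}


def eval_commands(data_root, results, model_id, dataset="mini-imagenet",
                  num_unlabel=None, use_test=False):
    """Build one run_exp.py --eval command per inference variant."""
    commands = {}
    for name, model in VARIANTS.items():
        cmd = ["python", "run_exp.py",
               "--data_root", data_root,
               "--dataset", dataset,
               "--model", model,
               "--results", results,
               "--eval", "--pretrain", model_id]
        if num_unlabel is not None:
            cmd += ["--num_unlabel", str(num_unlabel)]
        if use_test:
            cmd.append("--use_test")
        commands[name] = cmd
    return commands


def run_all(commands):
    """Run each evaluation in turn, stopping at the first failure."""
    for name, cmd in commands.items():
        print(f"== {name}")
        subprocess.run(cmd, check=True)
```

Keeping the variants in one table makes it easy to report all three numbers for a single checkpoint side by side.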

Acknowledgements

This code is based on [https://github.com/renmengye/few-shot-ssl-public], which accompanies the paper: