PyTorch implementation for our cross-domain few-shot classification method, which is built on the proposed learned feature-wise transformation layers.
Contact: Hung-Yu Tseng (htseng6@ucmerced.edu)
Please cite our paper if you find the code or dataset useful for your research.
Cross-Domain Few-Shot Classification via Learned Feature-Wise Transformation
Hung-Yu Tseng, Hsin-Ying Lee, Jia-Bin Huang, Ming-Hsuan Yang
International Conference on Learning Representations (ICLR), 2020 (spotlight)
@inproceedings{crossdomainfewshot,
author = {Tseng, Hung-Yu and Lee, Hsin-Ying and Huang, Jia-Bin and Yang, Ming-Hsuan},
booktitle = {International Conference on Learning Representations},
title = {Cross-Domain Few-Shot Classification via Learned Feature-Wise Transformation},
year = {2020}
}
Use the requirements.txt file we provide to set up the environment via Anaconda.
conda create --name py36 python=3.6
conda install pytorch torchvision -c pytorch
pip3 install -r requirements.txt
Clone this repository:
git clone https://github.com/hytseng0509/CrossDomainFewShot.git
cd CrossDomainFewShot
Download the 5 datasets separately with the following commands. Set DATASET_NAME to cars, cub, miniImagenet, places, or plantae.
cd filelists
python3 process.py DATASET_NAME
cd ..
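The per-dataset download step above can be scripted. The following is a minimal sketch, not part of the repository: the helper names `build_download_commands` and `download_all` are hypothetical, and it assumes `process.py` is run from inside the `filelists` directory exactly as in the commands above. It defaults to a dry run that only prints the commands.

```python
import subprocess

# Dataset names listed above as valid values for DATASET_NAME.
DATASETS = ["cars", "cub", "miniImagenet", "places", "plantae"]


def build_download_commands(datasets=DATASETS):
    """Return one `python3 process.py NAME` argument list per dataset."""
    return [["python3", "process.py", name] for name in datasets]


def download_all(filelists_dir="filelists", dry_run=True):
    """Print each command; run it inside filelists_dir unless dry_run is True."""
    for cmd in build_download_commands():
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops at the first failing download (assumed desirable).
            subprocess.run(cmd, cwd=filelists_dir, check=True)


if __name__ == "__main__":
    download_all(dry_run=True)
```

Call `download_all(dry_run=False)` from the repository root to actually run the downloads.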
We adopt baseline++ for MatchingNet, and baseline from CloserLookFewShot for other metric-based frameworks.
cd output/checkpoints
python3 download_encoder.py
cd ../..
To train the feature encoder yourself instead, set PRETRAIN to baseline++ or baseline.
python3 train_baseline.py --method PRETRAIN --dataset miniImagenet --name PRETRAIN --train_aug
Baseline training w/o feature-wise transformations.
METHOD: metric-based framework matchingnet, relationnet_softmax, or gnnnet.
TESTSET: unseen domain cars, cub, places, or plantae.
python3 train_baseline.py --method METHOD --dataset multi --testset TESTSET --name multi_TESTSET_ori_METHOD --warmup PRETRAIN --train_aug
Training w/ learned feature-wise transformations (learning-to-learn).
python3 train.py --method METHOD --dataset multi --testset TESTSET --name multi_TESTSET_lft_METHOD --warmup PRETRAIN --train_aug
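The two training commands differ only in the script and in the `ori`/`lft` part of the experiment name, so the placeholders can be filled in programmatically. This is a minimal sketch under stated assumptions: `pick_pretrain` and `build_train_command` are hypothetical helpers, and the rule "baseline++ for matchingnet, baseline otherwise" is taken from the encoder note above and assumed to apply to `--warmup` as well.

```python
METHODS = ("matchingnet", "relationnet_softmax", "gnnnet")
TESTSETS = ("cars", "cub", "places", "plantae")


def pick_pretrain(method):
    """Assumed mapping: baseline++ for MatchingNet, baseline for the others."""
    return "baseline++" if method == "matchingnet" else "baseline"


def build_train_command(method, testset, use_lft=True):
    """Build the argument list for train.py (lft) or train_baseline.py (ori)."""
    if method not in METHODS:
        raise ValueError("unknown METHOD: " + method)
    if testset not in TESTSETS:
        raise ValueError("unknown TESTSET: " + testset)
    script = "train.py" if use_lft else "train_baseline.py"
    tag = "lft" if use_lft else "ori"
    name = "multi_{}_{}_{}".format(testset, tag, method)
    return [
        "python3", script,
        "--method", method,
        "--dataset", "multi",
        "--testset", testset,
        "--name", name,
        "--warmup", pick_pretrain(method),
        "--train_aug",
    ]


if __name__ == "__main__":
    # Print both variants for one method and one unseen domain.
    print(" ".join(build_train_command("gnnnet", "cub", use_lft=False)))
    print(" ".join(build_train_command("gnnnet", "cub", use_lft=True)))
```

Building the command as a list keeps the experiment name consistent between training and later evaluation, since the same pattern produces both.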
Test the metric-based framework METHOD on the unseen domain TESTSET.
Specify the saved model to evaluate with --name (e.g., --name multi_TESTSET_lft_METHOD from the above example).
python3 test.py --method METHOD --name NAME --dataset TESTSET
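To evaluate one framework on every unseen domain, the test command can be generated per TESTSET. The sketch below is illustrative only: `build_test_command` and `build_test_sweep` are hypothetical helpers, and it assumes each model was saved under the `multi_TESTSET_lft_METHOD` (or `ori`) naming pattern used in the training examples.

```python
def build_test_command(method, testset, tag="lft"):
    """Build the test.py argument list for a model named multi_TESTSET_TAG_METHOD."""
    name = "multi_{}_{}_{}".format(testset, tag, method)
    return [
        "python3", "test.py",
        "--method", method,
        "--name", name,
        "--dataset", testset,
    ]


def build_test_sweep(method, testsets=("cars", "cub", "places", "plantae"), tag="lft"):
    """Map each unseen domain to its evaluation command."""
    return {t: build_test_command(method, t, tag) for t in testsets}


if __name__ == "__main__":
    for testset, cmd in build_test_sweep("matchingnet").items():
        print(testset + ": " + " ".join(cmd))
```

Pass `tag="ori"` to evaluate the baseline models trained w/o feature-wise transformations.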
--n_shot.
gnnnet approach w/ learned feature-wise transformations (learning-to-learn).