Pushing the Limits of Simple Pipelines for Few-Shot Learning: External Data and Fine-Tuning Make a Difference
Shell Xu Hu, Da Li, Jan Stühmer, Minyoung Kim and Timothy Hospedales
[Project page] [blog] [Arxiv] [Gradio demo]
If you find our project helpful, please consider cite our paper:
@inproceedings{hu2022pmf,
author = {Hu, Shell Xu
and Li, Da
and St\"uhmer, Jan
and Kim, Minyoung
and Hospedales, Timothy M.},
title = {Pushing the Limits of Simple Pipelines for Few-Shot Learning:
External Data and Fine-Tuning Make a Difference},
booktitle = {CVPR},
year = {2022}
}
08/11/2022 We released the checkpoints meta-trained on Meta-Dataset with pre-trained DINO.
18/04/2022 We released a Gradio demo on Huggingface Space for few-shot learning where the support set is created by text-to-image retrieval, making it a cheap version of CLIP-like zero-shot learning.
08/12/2021 An early version of P>M>F won 2nd place in NeurIPS-MetaDL-2021 with Henry Gouk.
We release three checkpoints meta-trained on Meta-Dataset with pre-trained DINO for reproducibility.
pip install -r requirements.txt
The code was tested with Python 3.8.1 and Pytorch >= 1.7.0.
We provide dataset classes and DDP dataloaders for CIFAR-FS, Mini-ImageNet and Meta-Dataset, and also adapted the CDFSL datasets to our pipeline.
The overall structure of the datasets
folder is the following:
datasets/
├── cdfsl/ # CDFSL datasets
├── episodic_dataset.py # CIFAR-FS & Mini-ImageNet
├── __init__.py # summary & interface
├── meta_dataset/ # code adapted from Google's meta-dataset and pytorch-meta-dataset
├── meta_h5_dataset.py # meta-dataset class to sample episodes and fetch data from h5 files
├── meta_val_dataset.py # meta-dataset class for validation with fixed val episodes
We unify the dataset definition, distributed sampler and randomness control in datasets/__init__.py, but in general the only functions you may need to pay attention are
cd scripts
sh download_cifarfs.sh
sh download_miniimagenet.sh
To use these two datasets, set --dataset cifar_fs
or --dataset mini_imagenet
.
We implement a pytorch version of Meta-Dataset.
Our implementation is based on mboudiaf's pytorch-meta-dataset.
The major change is we replace the tfrecords
to h5
files to largely reduce IO latency.
This also enables efficient DDP data-loading otherwise tfrecords leads to streaming dataset which is less easy to DDP.
The dataset has 10 domains, 4000+ classes. Episodes are formed in various-way-various-shot fashion, where an episode can have 900+ images.
The images are stored class-wise in h5 files (converted from the origianl tfrecords, one for each class).
To train and test on this dataset, set --dataset meta_dataset
and --data_path /path/to/meta_dataset
.
To download the h5 files,
git clone https://huggingface.co/datasets/hushell/meta_dataset_h5
You can also generate h5 files by yourself following these steps:
Download 10 domains (e.g., cu_birds) listed in google-research/meta-dataset, which will create tfrecords
files for each class and dataset_spec.json
for each domain.
Once done, you should get a folder meta-dataset/tf_records
with 11 sub-folders of *.tfrecords
files (including ilsvrc_2012 and ilsvrc_2012_v2).
Generate the index files of tfrecords with existing tool:
export RECORDS='path/to/tfrecords'
for source in omniglot aircraft cu_birds dtd quickdraw vgg_flower traffic_sign mscoco ilsvrc_2012; do \
source_path=${RECORDS}/${source} ;\
find ${source_path} -name '*.tfrecords' -type f -exec sh -c 'python3 datasets/meta_dataset/tfrecord/tools/tfrecord2idx.py $2 ${2%.tfrecords}.index' sh ${source_path} {} \; ;\
done ;\
This command will create for each tfrecords file an index file with the same name. E.g., 0.tfrecords -> 0.index
.
Convert tfrecords
to h5
by calling
python scripts/convert_tfrecord_to_h5.py /path/to/meta-dataset/tf_records
This command will create for each tfrecords file an h5 file with the same name. E.g., 0.tfrecords -> 0.h5
.
Generate 120 validation tasks per domain on a set of reserved classes by calling
python scripts/generate_val_episodes_to_h5.py --data-path /path/to/meta-dataset/tf_records
This goes into a single h5
file for each domain. E.g., cu_birds/val_ep120_img128.h5
. This is to remove randomness in validation.
The purpose of this benchmark is to evaluate how model trained on Mini-ImageNet (source domain) performs on cross-domain meta-test tasks.
So we only need to download the target domains, and extract the files into ./data/
.
You'll need to have these 4 sub-folders:
./data/ChestX
./data/CropDiseases
./data/EuroSAT/2750
./data/ISIC
Check get_bscd_loader() for the data loader details.
We support multiple pretrained foundation models. E.g.,
DINO ViT-base
, DINO ViT-small
, DINO ResNet50
, BeiT ViT-base
, CLIP ViT-base
, CLIP ResNet50
and so on.
For options of args.arch
, please check get_backbone().
It is recommended to run on a single GPU first by specifying args.device = cuda:i
, where i is the GPU id to be used.
We use args.nSupport
to set the number of shots. For example, 5-way-5-shot training command of CIFAR-FS writes as
python main.py --output outputs/your_experiment_name --dataset cifar_fs --epoch 100 --lr 5e-5 --arch dino_small_patch16 --device cuda:0 --nSupport 5 --fp16
Because at least one episode has to be hosted on the GPU, the program is quite memory hungry. Mixed precision (--fp16
) is recommended.
Since each class is stored in a h5 file, training will open many files. The following command is required before launching the code:
ulimit -n 100000 # may need to check `ulimit -Hn` first to know the hard limit
For various-way-various-shot training (#ways = 5-50, max #query = 10, max #supp = 500, max #supp per class = 100), the following command yields a P(DINO ViT-small) -> M(ProtoNet) pipeline to update backbone:
python main.py --output outputs/your_experiment_name --dataset meta_dataset --data-path /path/to/meta-dataset/ --num_workers 4 --base_sources aircraft cu_birds dtd ilsvrc_2012 omniglot fungi vgg_flower quickdraw --epochs 100 --lr 5e-4 --arch dino_small_patch16 --dist-eval --device cuda:0 --fp16
The minimum GPU memory is 24GB. The logging file outputs/your_experiment_name/log.txt
can be used to monitor model performance (as you can check it remotely).
Just replace --base_sources ...
by --base_sources ilsvrc_2012
.
First, setting up the following environmental variables:
export RANK=0 # machine id
export WORLD_SIZE=1 # total number of machines
For example, if you got 8 GPUs, run this command will accumulate gradients from 8 parallel episodes:
python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --output outputs/your_experiment_name --dataset meta_dataset --data-path /path/to/h5/files/ --num_workers 4 --base_sources aircraft cu_birds dtd ilsvrc_2012 omniglot fungi vgg_flower quickdraw --epochs 100 --lr 5e-4 --arch dino_small_patch16 --dist-eval --fp16
Copy the same command for training, which can be found in outputs/your_experiment_name/log.txt
(should be the first line or search keyword main.py),
and add --eval
.
When domain shift exists between meta-training and meta-testing, we enable different model deployment modes: vanilla
(ProtoNet classification) and finetune
(the backbone will be updated on support set). Check get_model() for the actual implementation.
For finetune
, a few hyper-parameters are introduced: args.ada_steps
, args.ada_lr
, args.aug_prob
, args.aug_types
, among which args.ada_lr
is the more sensitive one and requires validation (e.g., --ada_lr 0.001
). We also recommend to play with args.ada_steps
(e.g., --ada_steps 50
). The good news is that empirically we find good performance can be achieved by domain-wise hyper-parameter search. As recommended in our paper, 3-5 episodes with labeled query set is sufficient to tell what the best hyper-parameters are, which makes model deployment practical for a novel domain with a few annotated examples.
A meta-testing command example for Meta-Dataset with fine-tuning is
python -m torch.distributed.launch --nproc_per_node=8 --use_env test_meta_dataset.py --data-path /path/to/meta_dataset/ --dataset meta_dataset --arch dino_small_patch16 --deploy finetune --output outputs/your_experiment_name --resume outputs/your_experiment_name/best.pth --dist-eval --ada_steps 100 --ada_lr 0.0001 --aug_prob 0.9 --aug_types color translation
To meta-test without fine-tuning, just replace --deploy finetune
with --deploy vanilla
.
If you don't want to enable DDP for testing, just replacing --dist-eval
by --device cuda:0
and remove torch.distributed.launch
part.
By default, all 10 domains will be evaluated, but you may meta-test only a subset by specifying which domains should be executed with --test_sources
. Check utils/args.py:48 for domain names.
Below are the results on Meta-Dataset test-set for DINO checkpoints:
Method | ILSVRC (test) | Omniglot | Aircraft | Birds | Textures | QuickDraw | Fungi | VGG Flower | Traffic signs | MSCOCO |
---|---|---|---|---|---|---|---|---|---|---|
md_full_128x128_dinosmall_fp16_lr5e-5 | 73.52±0.80 (lr=0.0001) | 92.17±0.57 (lr=0.0001) | 89.49±0.52 (lr=0.001) | 91.04±0.37 (lr=0.0001) | 85.73±0.62 (lr=0.001) | 79.43±0.67 (lr=0.0001) | 74.99±0.94 (lr=0) | 95.30±0.44 (lr=0.001) | 89.85±0.76 (lr=0.01) | 59.69±1.02 (lr=0.001) |
md_inet_128x128_dinosmall_fp16_lr2e-4 | 75.51±0.72 (lr=0.001) | 82.81±1.10 (lr=0.01) | 78.38±1.09 (lr=0.01) | 85.18±0.77 (lr=0.001) | 86.95±0.60 (lr=0.001) | 74.47±0.83 (lr=0.01) | 55.16±1.09 (lr=0) | 94.66±0.48 (lr=0) | 90.04±0.81 (lr=0.01) | 62.60±0.96 (lr=0.001) |
Meta-testing CDFSL is almost the same as described in previous section for Meta-Dataset. However, we create another script test_bscdfsl.py to fit CDFSL's original data loaders.
An meta-testing command example for CDFSL with fine-tuning is
python test_bscdfsl.py --test_n_way 5 --n_shot 5 --device cuda:0 --arch dino_small_patch16 --deploy finetune --output outputs/your_experiment_name --resume outputs/your_experiment_name/best.pth --ada_steps 100 --ada_lr 0.0001 --aug_prob 0.9 --aug_types color translation
Changing --n_shot
to 20 or 50 to evaluate other settings.