youthhoo / AFA_For_Few_shot_learning

19 stars 1 forks source link

AFA_For_Few-Shot_Learning


Please cite our paper if you find the code or dataset useful for your research.

Adversarial Feature Augmentation for Cross-domain Few-shot Classification European Conference on Computer Vision (ECCV), 2022

@inproceedings{hu2022adversarial,
  title={Adversarial Feature Augmentation for Cross-domain Few-Shot Classification},
  author={Hu, Yanxu and Ma, Andy J},
  booktitle={European Conference on Computer Vision},
  pages={20--37},
  year={2022},
  organization={Springer}
}


Yanxu Hu, Andy Ma

ECCV 2022

Citation

If you use this code for your research, please cite our paper

Dependencies

Datasets:

Refers to CDFSL-ATA(https://github.com/Haoqing-Wang/CDFSL-ATA)

Train

1. train the baseline

python train.py --model ResNet10 --method GNN --n_shot 5 --name GNN_5s --train_aug
python train.py --model ResNet10 --method TPN --n_shot 5 --name TPN --train_aug

2.Train meta-learning models with feature-wise transformations.

python train_FT.py --model ResNet10 --method GNN --n_shot 5 --name GNN_FWT_5s --train_aug
python train_FT.py --model ResNet10 --method TPN --n_shot 5 --name TPN_FWT --train_aug

3.Explanation-guided train meta-learning models.

python train.py --model ResNet10 --method GNNLRP --n_shot 5 --name GNN_LRP_5s --train_aug
python train.py --model ResNet10 --method RelationNetLRP --n_shot 5 --name RelationNet_LRP --train_aug

4.Train meta-learning models with Adversarial Task Augmentation.

python train_ATA.py --model ResNet10 --method GNN --max_lr 80. --T_max 5 --prob 0.5 --n_shot 5 --name GNN_ATA_5s --train_aug
python train_ATA.py --model ResNet10 --method TPN --max_lr 20. --T_max 5 --prob 0.6 --n_shot 5 --name TPN_ATA --train_aug

5.train with AFA

python train_ND.py --model ResNet10 --method GNN --n_shot 5 --name GNN_ND --train_aug
python train_ND.py --model ResNet10 --method TPN --n_shot 5 --name TPN_ND --train_aug

6.Ablation experiments

python train_NND.py --model ResNet10 --method GNN --n_shot 5 --name GNN_ND --train_aug # 'worst-case feature distribution'
python train_nonlin.py --model ResNet10 --method GNN --n_shot 5 --name GNN_ND --train_aug #non-linear transformation

Evaluation and Fine-tuning

1.Test the trained model on the unseen domains.

python test.py --dataset cub --n_shot 5 --model ResNet10 --method GNN --name GNN_ND
python test.py --dataset cub --n_shot 5 --model ResNet10 --method GNN --name GNN_ATA_5s

2.Fine-tuning with linear classifier.

To get the results of traditional pre-training and fine-tuning, run

python finetune.py --dataset cub --n_shot 5 --finetune_epoch 50 --model ResNet10 --name model_name

3.Fine-tuning the meta-learning models.

python finetune_ml.py --dataset cub --method GNN --n_shot 5 --finetune_epoch 50 --model ResNet10 --name model_name

Note