A benchmark of data-free knowledge distillation from the paper "Contrastive Model Inversion for Data-Free Knowledge Distillation"
Authors: Gongfan Fang, Jie Song, Xinchao Wang, Chengchao Shen, Xingen Wang, Mingli Song
[Figure: synthetic samples from CMI (this work), DeepInv, ZSKT and DFQ]
Results on CIFAR-10 (test accuracy, %). Columns are teacher → student pairs:

Method | resnet-34 → resnet-18 | vgg-11 → resnet-18 | wrn-40-2 → wrn-16-1 | wrn-40-2 → wrn-40-1 | wrn-40-2 → wrn-16-2 |
---|---|---|---|---|---|
T. Scratch | 95.70 | 92.25 | 94.87 | 94.87 | 94.87 |
S. Scratch | 95.20 | 95.20 | 91.12 | 93.94 | 93.95 |
DAFL | 92.22 | 81.10 | 65.71 | 81.33 | 81.55 |
ZSKT | 93.32 | 89.46 | 83.74 | 86.07 | 89.66 |
DeepInv | 93.26 | 90.36 | 83.04 | 86.85 | 89.72 |
DFQ | 94.61 | 90.84 | 86.14 | 91.69 | 92.01 |
CMI | 94.84 | 91.13 | 90.01 | 92.78 | 92.52 |
Results on CIFAR-100 (test accuracy, %). Columns are teacher → student pairs:

Method | resnet-34 → resnet-18 | vgg-11 → resnet-18 | wrn-40-2 → wrn-16-1 | wrn-40-2 → wrn-40-1 | wrn-40-2 → wrn-16-2 |
---|---|---|---|---|---|
T. Scratch | 78.05 | 71.32 | 75.83 | 75.83 | 75.83 |
S. Scratch | 77.10 | 77.01 | 65.31 | 72.19 | 73.56 |
DAFL | 74.47 | 57.29 | 22.50 | 34.66 | 40.00 |
ZSKT | 67.74 | 34.72 | 30.15 | 29.73 | 28.44 |
DeepInv | 61.32 | 54.13 | 53.77 | 61.33 | 61.34 |
DFQ | 77.01 | 68.32 | 54.77 | 62.92 | 59.01 |
CMI | 77.04 | 70.56 | 57.91 | 68.88 | 68.75 |
To synthesize a batch of images for visualization, run:
bash scripts/cmi/cmi_cifar10_for_vis.sh
Results will be saved as checkpoints/datafree-cmi/synthetic-cmi_for_vis.png.
Note: This repo was refactored from our experimental code and is still under development. I'm still struggling to find appropriate hyper-parameters for every method (°ー°〃). So far, we only provide the hyper-parameters to reproduce the CIFAR-10 results for wrn-40-2 => wrn-16-1. You may need to tune the hyper-parameters for other models and datasets. More resources will be uploaded in future updates.
To reproduce our results, please download the pre-trained teacher models from Dropbox-Models (266 MB) and extract them to checkpoints/pretrained. A pre-inverted dataset with ~50k samples is also available for the wrn-40-2 teacher on CIFAR-10; download it from Dropbox-Data-Preinverted (133 MB) and extract it to run/cmi-preinverted-wrn402/.
Non-adversarial CMI: you can train a student model directly on the pre-inverted data. It should reach an accuracy of ~87.38% on CIFAR-10, as reported in Figure 3.
bash scripts/cmi/nonadv_cmi_cifar10_wrn402_wrn161.sh
Adversarial CMI: alternatively, you can apply adversarial distillation on top of the pre-inverted data from the non-adversarial process above. About 10k (256x40) new samples will be generated to improve the student, which should reach an accuracy of ~90.01% on CIFAR-10, as reported in Table 1. A simplified sketch of the adversarial loop follows the commands below.
bash scripts/cmi/adv_cmi_cifar10_wrn402_wrn161.sh
Scratch CMI: you can also run the CMI algorithm without any pre-inverted data, though the student may overfit to the early synthesized samples. In this case, the student should reach an accuracy of ~88.82% on CIFAR-10, slightly worse than adversarial CMI (90.01%).
bash scripts/cmi/scratch_cmi_cifar10_wrn402_wrn161.sh
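For intuition, here is a minimal sketch of one adversarial round; it is not the repository's exact implementation, and all names and hyper-parameters are placeholders. The synthesizer (see the BaseSynthesizer interfaces described later in this README) crafts a fresh batch, and the student is then distilled on it with a temperature-softened KL loss.

```python
import torch
import torch.nn.functional as F

def adversarial_round(synthesizer, teacher, student, optimizer, T=4.0):
    """One hypothetical round of adversarial data-free KD."""
    # 1) Craft new samples. Internally, an adversarial synthesizer typically
    #    maximizes teacher/student disagreement plus image priors (bn, oh, cr, ...).
    synthesizer.synthesize()

    # 2) Distill the student on a freshly sampled synthetic batch.
    teacher.eval()
    student.train()
    images = synthesizer.sample()
    with torch.no_grad():
        t_logits = teacher(images)
    s_logits = student(images)
    loss = F.kl_div(F.log_softmax(s_logits / T, dim=1),
                    F.softmax(t_logits / T, dim=1),
                    reduction='batchmean') * (T ** 2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```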
# Train a model (e.g., the wrn-40-2 teacher) from scratch
python train_scratch.py --model wrn40_2 --dataset cifar10 --batch-size 256 --lr 0.1 --epoch 200 --gpu 0
# KD with original training data (beta>0 to use hard targets)
python vanilla_kd.py --teacher wrn40_2 --student wrn16_1 --dataset cifar10 --transfer_set cifar10 --beta 0.1 --batch-size 128 --lr 0.1 --epoch 200 --gpu 0
# KD with unlabeled data
python vanilla_kd.py --teacher wrn40_2 --student wrn16_1 --dataset cifar10 --transfer_set cifar100 --beta 0 --batch-size 128 --lr 0.1 --epoch 200 --gpu 0
# KD with unlabeled data from a specified folder
python vanilla_kd.py --teacher wrn40_2 --student wrn16_1 --dataset cifar10 --transfer_set run/cmi --beta 0 --batch-size 128 --lr 0.1 --epoch 200 --gpu 0
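For reference, here is a minimal sketch of the objective such a vanilla KD run typically minimizes, assuming the conventional formulation (the exact loss in vanilla_kd.py may differ): a temperature-softened KL term on the teacher/student logits plus an optional hard-label cross-entropy weighted by beta.

```python
import torch.nn.functional as F

def vanilla_kd_loss(s_logits, t_logits, targets=None, T=4.0, beta=0.0):
    """Soft-label distillation plus optional hard-label supervision.
    With beta=0 (e.g., an unlabeled transfer set) only the KL term is used."""
    soft = F.kl_div(F.log_softmax(s_logits / T, dim=1),
                    F.softmax(t_logits / T, dim=1),
                    reduction='batchmean') * (T ** 2)
    if beta > 0 and targets is not None:
        return soft + beta * F.cross_entropy(s_logits, targets)
    return soft
```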
# Reproduce other data-free KD methods with their scripts
bash scripts/xxx/xxx.sh # e.g. scripts/zskt/zskt_cifar10_wrn402_wrn161.sh
Hyper-parameters used by different methods (an illustrative sketch of how these loss terms might be combined follows the table):
Method | adv | bn | oh | balance | act | cr | GAN | Example |
---|---|---|---|---|---|---|---|---|
DAFL | - | - | √ | √ | √ | - | √ | scripts/dafl_cifar10.sh |
ZSKT | √ | - | - | - | - | - | √ | scripts/zskt_cifar10.sh |
DeepInv | √ | √ | √ | - | - | - | - | scripts/deepinv_cifar10.sh |
DFQ | √ | √ | √ | √ | - | - | √ | scripts/dfq_cifar10.sh |
CMI | √ | √ | √ | - | - | √ | √ | scripts/cmi_cifar10_scratch.sh |
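The adv/bn/oh/balance/act/cr columns refer to loss terms commonly used for image synthesis: adversarial teacher-student disagreement, BatchNorm-statistics matching (DeepInversion), cross-entropy toward the teacher's predictions (one-hot), class balance, activation magnitude (DAFL), and CMI's contrastive regularization. The sketch below is an illustrative composition of such terms under hypothetical weights and precomputed inputs; it is not the repository's implementation.

```python
import torch
import torch.nn.functional as F

def synthesis_loss(t_logits, s_logits, bn_loss, feats, cr_loss,
                   adv=0., bn=0., oh=0., balance=0., act=0., cr=0.):
    """Hypothetical weighted sum of the loss terms listed in the table.
    bn_loss (from BN hooks) and cr_loss (from a contrastive head) are assumed
    to be computed elsewhere; feats are penultimate-layer features."""
    loss = torch.zeros((), device=t_logits.device)
    if adv:      # adversarial: maximize teacher/student disagreement
        loss = loss - adv * F.kl_div(F.log_softmax(s_logits, dim=1),
                                     F.softmax(t_logits, dim=1),
                                     reduction='batchmean')
    if bn:       # match running BatchNorm statistics (DeepInversion prior)
        loss = loss + bn * bn_loss
    if oh:       # push samples toward confident (one-hot) teacher predictions
        loss = loss + oh * F.cross_entropy(t_logits, t_logits.argmax(dim=1))
    if balance:  # encourage a class-balanced batch (maximize label entropy)
        p = F.softmax(t_logits, dim=1).mean(dim=0)
        loss = loss + balance * (p * torch.log(p + 1e-6)).sum()
    if act:      # DAFL-style: maximize penultimate activation magnitude
        loss = loss - act * feats.abs().mean()
    if cr:       # CMI's contrastive regularization on feature embeddings
        loss = loss + cr * cr_loss
    return loss
```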
You can register your own models and datasets in registry.py by modifying NORMALIZE_DICT, MODEL_DICT and get_dataset, and then run the above commands to train them. Note that DAFL requires intermediate features from the penultimate layer, so for DAFL your model should accept a return_features=True parameter and return a (logits, features) tuple.
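As a concrete illustration, a toy model satisfying the return_features interface is sketched below; the registration lines are shown as comments because the exact layout of NORMALIZE_DICT, MODEL_DICT and get_dataset in registry.py may differ from what is assumed here.

```python
import torch.nn as nn

class MyNet(nn.Module):
    """Toy model exposing the interface described above: forward() accepts
    return_features=True and returns a (logits, penultimate-features) tuple."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten())
        self.fc = nn.Linear(32, num_classes)

    def forward(self, x, return_features=False):
        feats = self.backbone(x)         # penultimate-layer features
        logits = self.fc(feats)
        return (logits, feats) if return_features else logits

# Hypothetical registration in registry.py (exact dict layouts may differ):
# MODEL_DICT['mynet'] = MyNet
# NORMALIZE_DICT['mydataset'] = dict(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
# get_dataset should also be extended to return train/val splits for 'mydataset'.
```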
Your algorithm should inherit from datafree.synthesis.BaseSynthesizer and implement two interfaces: 1) BaseSynthesizer.synthesize, which takes several steps to craft new samples and returns an image dict for visualization; 2) BaseSynthesizer.sample, which fetches a batch of training data for KD.
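A skeletal custom synthesizer implementing these two interfaces might look like the sketch below; the constructor arguments and the exact BaseSynthesizer signature are assumptions, so check datafree/synthesis before copying.

```python
import torch
from datafree.synthesis import BaseSynthesizer

class NoiseSynthesizer(BaseSynthesizer):
    """Toy example: 'crafts' random images instead of optimizing them,
    purely to illustrate the synthesize()/sample() contract."""
    def __init__(self, teacher, student, img_size=(3, 32, 32), batch_size=256):
        super(NoiseSynthesizer, self).__init__(teacher, student)
        self.img_size, self.batch_size = img_size, batch_size
        self.data = []

    def synthesize(self):
        # A real synthesizer would run several optimization steps here,
        # minimizing inversion losses w.r.t. the input images.
        x = torch.randn(self.batch_size, *self.img_size)
        self.data.append(x)
        return {'synthetic': x}          # image dict for visualization

    def sample(self):
        # Fetch a batch of training data for KD.
        return self.data[-1]
```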
If you find this work useful for your research, please cite our paper:
@article{fang2021contrastive,
title={Contrastive Model Inversion for Data-Free Knowledge Distillation},
author={Fang, Gongfan and Song, Jie and Wang, Xinchao and Shen, Chengchao and Wang, Xingen and Song, Mingli},
journal={arXiv preprint arXiv:2105.08584},
year={2021}
}