This repo is
(1) a PyTorch library that provides classical knowledge distillation algorithms on mainstream CV benchmarks,
(2) the official implementation of the CVPR-2022 paper: [Decoupled Knowledge Distillation](https://arxiv.org/abs/2203.08679), and
(3) the official implementation of the ICCV-2023 paper: [DOT: A Distillation-Oriented Trainer](https://openaccess.thecvf.com/content/ICCV2023/papers/Zhao_DOT_A_Distillation-Oriented_Trainer_ICCV_2023_paper.pdf).
# DOT: A Distillation-Oriented Trainer
### Framework
### Main Benchmark Results
On CIFAR-100:
| Teacher <br> Student | ResNet32x4 <br> ResNet8x4 | VGG13 <br> VGG8 | ResNet32x4 <br> ShuffleNet-V2 |
|:---------------:|:-----------------:|:-----------------:|:-----------------:|
| KD | 73.33 | 72.98 | 74.45 |
| **KD+DOT** | **75.12** | **73.77** | **75.55** |
On Tiny-ImageNet:
| Teacher <br> Student | ResNet18 <br> MobileNet-V2 | ResNet18 <br> ShuffleNet-V2 |
|:---------------:|:-----------------:|:-----------------:|
| KD | 58.35 | 62.26 |
| **KD+DOT** | **64.01** | **65.75** |
On ImageNet:
| Teacher <br> Student | ResNet34 <br> ResNet18 | ResNet50 <br> MobileNet-V1 |
|:---------------:|:-----------------:|:-----------------:|
| KD | 71.03 | 70.50 |
| **KD+DOT** | **71.72** | **73.09** |
# Decoupled Knowledge Distillation
### Framework & Performance
### Main Benchmark Results
On CIFAR-100:
| Teacher <br> Student | ResNet56 <br> ResNet20 | ResNet110 <br> ResNet32 | ResNet32x4 <br> ResNet8x4 | WRN-40-2 <br> WRN-16-2 | WRN-40-2 <br> WRN-40-1 | VGG13 <br> VGG8 |
|:---------------:|:-----------------:|:-----------------:|:-----------------:|:------------------:|:------------------:|:--------------------:|
| KD | 70.66 | 73.08 | 73.33 | 74.92 | 73.54 | 72.98 |
| **DKD** | **71.97** | **74.11** | **76.32** | **76.23** | **74.81** | **74.68** |
| Teacher <br> Student | ResNet32x4 <br> ShuffleNet-V1 | WRN-40-2 <br> ShuffleNet-V1 | VGG13 <br> MobileNet-V2 | ResNet50 <br> MobileNet-V2 | ResNet32x4 <br> MobileNet-V2 |
|:---------------:|:-----------------:|:-----------------:|:-----------------:|:------------------:|:------------------:|
| KD | 74.07 | 74.83 | 67.37 | 67.35 | 74.45 |
| **DKD** | **76.45** | **76.70** | **69.71** | **70.35** | **77.07** |
On ImageNet:
| Teacher <br> Student | ResNet34 <br> ResNet18 | ResNet50 <br> MobileNet-V1 |
|:---------------:|:-----------------:|:-----------------:|
| KD | 71.03 | 70.50 |
| **DKD** | **71.70** | **72.05** |
# MDistiller
### Introduction
MDistiller supports the following distillation methods on CIFAR-100, ImageNet and MS-COCO:
|Method|Paper Link|CIFAR-100|ImageNet|MS-COCO|
|:---:|:---:|:---:|:---:|:---:|
|KD| |✓|✓| |
|FitNet| |✓| | |
|AT| |✓|✓| |
|NST| |✓| | |
|PKT| |✓| | |
|KDSVD| |✓| | |
|OFD| |✓|✓| |
|RKD| |✓| | |
|VID| |✓| | |
|SP| |✓| | |
|CRD| |✓|✓| |
|ReviewKD| |✓|✓|✓|
|DKD| |✓|✓|✓|
### Installation
Environments:
- Python 3.6
- PyTorch 1.9.0
- torchvision 0.10.0
Install the package:
```bash
sudo pip3 install -r requirements.txt
sudo python3 setup.py develop
```
### Getting started
0. Wandb as the logger
- The registration: .
- If you don't want to use wandb as your logger, set `CFG.LOG.WANDB` to `False` in `mdistiller/engine/cfg.py` (a minimal sketch follows).
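A minimal sketch of that toggle, assuming `CFG` is the global config object defined in `mdistiller/engine/cfg.py`:
```python
# mdistiller/engine/cfg.py -- sketch of the relevant default (assumption, not the full file)
CFG.LOG.WANDB = False  # disable wandb logging for all experiments
```
The same key can presumably also be overridden per run on the command line, in the same way `SOLVER.BATCH_SIZE` and `SOLVER.LR` are overridden in the CIFAR-100 training example below.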
1. Evaluation
- You can evaluate the performance of our models or your own trained models.
- Our models are available at ; please download the checkpoints to `./download_ckpts`.
- To test models on ImageNet, please download the dataset at and put it in `./data/imagenet`.
```bash
# evaluate teachers
python3 tools/eval.py -m resnet32x4 # resnet32x4 on cifar100
python3 tools/eval.py -m ResNet34 -d imagenet # ResNet34 on imagenet
# evaluate students
python3 tools/eval.py -m resnet8x4 -c download_ckpts/dkd_resnet8x4 # dkd-resnet8x4 on cifar100
python3 tools/eval.py -m MobileNetV1 -c download_ckpts/imgnet_dkd_mv1 -d imagenet # dkd-mv1 on imagenet
python3 tools/eval.py -m model_name -c output/your_exp/student_best # your checkpoints
```
2. Training on CIFAR-100
- Download `cifar_teachers.tar` at and untar it into `./download_ckpts` via `tar xvf cifar_teachers.tar`.
```bash
# for instance, our DKD method.
python3 tools/train.py --cfg configs/cifar100/dkd/res32x4_res8x4.yaml
# you can also override settings on the command line
python3 tools/train.py --cfg configs/cifar100/dkd/res32x4_res8x4.yaml SOLVER.BATCH_SIZE 128 SOLVER.LR 0.1
```
3. Training on ImageNet
- Download the dataset at and put it in `./data/imagenet`.
```bash
# for instance, our DKD method.
python3 tools/train.py --cfg configs/imagenet/r34_r18/dkd.yaml
```
4. Training on MS-COCO
- see [detection.md](detection/README.md)
5. Extension: Visualizations
- Jupyter notebooks: [tsne](tools/visualizations/tsne.ipynb) and [correlation_matrices](tools/visualizations/correlation.ipynb)
### Custom Distillation Method
1. Create a Python file in `mdistiller/distillers/` and define the distiller:
```python
from ._base import Distiller
class MyDistiller(Distiller):
    def __init__(self, student, teacher, cfg):
        super(MyDistiller, self).__init__(student, teacher)
        self.hyper1 = cfg.MyDistiller.hyper1
        ...

    def forward_train(self, image, target, **kwargs):
        # return the output logits and a dict of named losses
        ...

    # override get_learnable_parameters if the distiller introduces extra nn modules
    # override get_extra_parameters if you need to report the cost of the extra parameters
```
2. Register the distiller in `distiller_dict` in `mdistiller/distillers/__init__.py`.
3. Register the corresponding hyper-parameters in `mdistiller/engine/cfg.py`.
4. Create a new config file and test it. (An end-to-end sketch of steps 1-3 is given below.)
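As a concrete illustration of steps 1-3, here is a minimal sketch of a vanilla-KD-style distiller. It is not part of the library: the hyper-parameter names (`TEMPERATURE`, `KD_WEIGHT`), the registry key `MY_DISTILLER`, and the assumption that the bundled models return a `(logits, features)` tuple are illustrative and should be adapted to the actual interfaces in `mdistiller`.
```python
# mdistiller/distillers/MyDistiller.py -- illustrative sketch, not library code
import torch
import torch.nn.functional as F

from ._base import Distiller


class MyDistiller(Distiller):
    """A vanilla-KD-style distiller used only to show the expected interface."""

    def __init__(self, student, teacher, cfg):
        super(MyDistiller, self).__init__(student, teacher)
        # hyper-parameters registered in mdistiller/engine/cfg.py (names are illustrative)
        self.temperature = cfg.MyDistiller.TEMPERATURE
        self.kd_weight = cfg.MyDistiller.KD_WEIGHT

    def forward_train(self, image, target, **kwargs):
        # assumes the bundled models return (logits, features); adapt if yours differ
        logits_student, _ = self.student(image)
        with torch.no_grad():
            logits_teacher, _ = self.teacher(image)
        # cross-entropy on the labels plus a temperature-scaled KL term to the teacher
        loss_ce = F.cross_entropy(logits_student, target)
        loss_kd = (
            self.kd_weight
            * (self.temperature ** 2)
            * F.kl_div(
                F.log_softmax(logits_student / self.temperature, dim=1),
                F.softmax(logits_teacher / self.temperature, dim=1),
                reduction="batchmean",
            )
        )
        # return the student logits and a dict of named losses
        return logits_student, {"loss_ce": loss_ce, "loss_kd": loss_kd}
```
Steps 2 and 3 then amount to a few extra lines in the two files named above (again a sketch; `CN` is assumed to be the yacs-style `CfgNode` constructor already used in `cfg.py`):
```python
# mdistiller/distillers/__init__.py -- step 2 (sketch)
from .MyDistiller import MyDistiller
distiller_dict["MY_DISTILLER"] = MyDistiller

# mdistiller/engine/cfg.py -- step 3 (sketch)
CFG.MyDistiller = CN()
CFG.MyDistiller.TEMPERATURE = 4.0
CFG.MyDistiller.KD_WEIGHT = 1.0
```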
# Citation
If this repo is helpful for your research, please consider citing our papers:
```BibTeX
@article{zhao2022dkd,
title={Decoupled Knowledge Distillation},
author={Zhao, Borui and Cui, Quan and Song, Renjie and Qiu, Yiyu and Liang, Jiajun},
journal={arXiv preprint arXiv:2203.08679},
year={2022}
}
@article{zhao2023dot,
title={DOT: A Distillation-Oriented Trainer},
author={Zhao, Borui and Cui, Quan and Song, Renjie and Liang, Jiajun},
journal={arXiv preprint arXiv:2307.08436},
year={2023}
}
```
# License
MDistiller is released under the MIT license. See [LICENSE](LICENSE) for details.
# Acknowledgement
- Thanks to CRD and ReviewKD. This library is built on the [CRD codebase](https://github.com/HobbitLong/RepDistiller) and the [ReviewKD codebase](https://github.com/dvlab-research/ReviewKD).
- Thanks to Yiyu Qiu and Yi Shi for their code contributions during their internships at MEGVII Technology.
- Thanks to Xin Jin for the discussions about DKD.