# mdistiller

The official implementation of [Decoupled Knowledge Distillation (CVPR 2022)](https://arxiv.org/abs/2203.08679) and [DOT: A Distillation-Oriented Trainer (ICCV 2023)](https://openaccess.thecvf.com/content/ICCV2023/papers/Zhao_DOT_A_Distillation-Oriented_Trainer_ICCV_2023_paper.pdf).
This repo is (1) a PyTorch library that provides classical knowledge distillation algorithms on mainstream CV benchmarks, (2) the official implementation of the CVPR-2022 paper [Decoupled Knowledge Distillation](https://arxiv.org/abs/2203.08679), and (3) the official implementation of the ICCV-2023 paper [DOT: A Distillation-Oriented Trainer](https://openaccess.thecvf.com/content/ICCV2023/papers/Zhao_DOT_A_Distillation-Oriented_Trainer_ICCV_2023_paper.pdf).

# DOT: A Distillation-Oriented Trainer

### Framework
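In brief, DOT optimizes the usual cross-entropy + KD objective but keeps separate momentum buffers for the two losses, giving the distillation gradients a larger momentum (`momentum + delta`) and the task gradients a smaller one (`momentum - delta`), so that the distillation term is optimized more thoroughly. Below is a minimal sketch of that idea; the class name, defaults, and the omission of weight decay are illustrative assumptions, not the trainer shipped in this repo.

```python
import torch


class TwoMomentumSGD:
    """Sketch of DOT's core idea: separate momentum buffers for the task (CE)
    gradients and the distillation (KD) gradients, with a larger momentum for
    the KD buffer and a smaller one for the CE buffer."""

    def __init__(self, params, lr=0.05, momentum=0.9, delta=0.075):  # defaults are illustrative
        self.params = [p for p in params if p.requires_grad]
        self.lr, self.mu, self.delta = lr, momentum, delta
        self.buf_ce = [torch.zeros_like(p) for p in self.params]
        self.buf_kd = [torch.zeros_like(p) for p in self.params]

    @torch.no_grad()
    def step(self, ce_grads, kd_grads):
        # ce_grads / kd_grads: per-parameter gradients of the two losses,
        # obtained separately (e.g. via torch.autograd.grad on each loss).
        for p, b_ce, b_kd, g_ce, g_kd in zip(
            self.params, self.buf_ce, self.buf_kd, ce_grads, kd_grads
        ):
            b_ce.mul_(self.mu - self.delta).add_(g_ce)  # task loss: smaller momentum
            b_kd.mul_(self.mu + self.delta).add_(g_kd)  # distillation loss: larger momentum
            p.add_(b_ce + b_kd, alpha=-self.lr)
```

In a training step, compute the CE and KD losses separately, obtain their gradients with `torch.autograd.grad` (retaining the graph for the first call), and pass both gradient lists to `step`.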
### Main Benchmark Results

On CIFAR-100:

| Teacher <br> Student | ResNet32x4 <br> ResNet8x4 | VGG13 <br> VGG8 | ResNet32x4 <br> ShuffleNet-V2 |
|:---------------:|:-----------------:|:-----------------:|:-----------------:|
| KD | 73.33 | 72.98 | 74.45 |
| **KD+DOT** | **75.12** | **73.77** | **75.55** |

On Tiny-ImageNet:

| Teacher <br> Student | ResNet18 <br> MobileNet-V2 | ResNet18 <br> ShuffleNet-V2 |
|:---------------:|:-----------------:|:-----------------:|
| KD | 58.35 | 62.26 |
| **KD+DOT** | **64.01** | **65.75** |

On ImageNet:

| Teacher <br> Student | ResNet34 <br> ResNet18 | ResNet50 <br> MobileNet-V1 |
|:---------------:|:-----------------:|:-----------------:|
| KD | 71.03 | 70.50 |
| **KD+DOT** | **71.72** | **73.09** |

# Decoupled Knowledge Distillation

### Framework & Performance
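DKD reformulates the classical KD loss as a weighted sum of a target-class term (TCKD) and a non-target-class term (NCKD), so the two parts can be balanced independently via alpha and beta. The snippet below is a minimal sketch of that loss for reference only: the default hyper-parameters and the large-negative masking trick are illustrative, and the repo's actual distiller implementation lives under `mdistiller/distillers/`.

```python
import torch
import torch.nn.functional as F


def dkd_loss_sketch(logits_student, logits_teacher, target, alpha=1.0, beta=8.0, temperature=4.0):
    """Split the classical KD loss into a target-class term (TCKD) and a
    non-target-class term (NCKD), weighted independently by alpha and beta."""
    T = temperature
    # One-hot mask of the ground-truth class for each sample.
    gt_mask = torch.zeros_like(logits_student).scatter_(1, target.unsqueeze(1), 1.0)

    # TCKD: KL divergence between the binary (target vs. all other classes) probabilities.
    p_s = F.softmax(logits_student / T, dim=1)
    p_t = F.softmax(logits_teacher / T, dim=1)
    p_s_bin = torch.stack([(p_s * gt_mask).sum(1), (p_s * (1.0 - gt_mask)).sum(1)], dim=1)
    p_t_bin = torch.stack([(p_t * gt_mask).sum(1), (p_t * (1.0 - gt_mask)).sum(1)], dim=1)
    tckd = F.kl_div(p_s_bin.log(), p_t_bin, reduction="batchmean") * (T ** 2)

    # NCKD: KL divergence among the non-target classes only (target logit suppressed).
    nckd = F.kl_div(
        F.log_softmax(logits_student / T - 1000.0 * gt_mask, dim=1),
        F.softmax(logits_teacher / T - 1000.0 * gt_mask, dim=1),
        reduction="batchmean",
    ) * (T ** 2)

    return alpha * tckd + beta * nckd
```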
### Main Benchmark Results

On CIFAR-100:

| Teacher <br> Student | ResNet56 <br> ResNet20 | ResNet110 <br> ResNet32 | ResNet32x4 <br> ResNet8x4 | WRN-40-2 <br> WRN-16-2 | WRN-40-2 <br> WRN-40-1 | VGG13 <br> VGG8 |
|:---------------:|:-----------------:|:-----------------:|:-----------------:|:------------------:|:------------------:|:--------------------:|
| KD | 70.66 | 73.08 | 73.33 | 74.92 | 73.54 | 72.98 |
| **DKD** | **71.97** | **74.11** | **76.32** | **76.23** | **74.81** | **74.68** |

| Teacher <br> Student | ResNet32x4 <br> ShuffleNet-V1 | WRN-40-2 <br> ShuffleNet-V1 | VGG13 <br> MobileNet-V2 | ResNet50 <br> MobileNet-V2 | ResNet32x4 <br> MobileNet-V2 |
|:---------------:|:-----------------:|:-----------------:|:-----------------:|:------------------:|:------------------:|
| KD | 74.07 | 74.83 | 67.37 | 67.35 | 74.45 |
| **DKD** | **76.45** | **76.70** | **69.71** | **70.35** | **77.07** |

On ImageNet:

| Teacher <br> Student | ResNet34 <br> ResNet18 | ResNet50 <br> MobileNet-V1 |
|:---------------:|:-----------------:|:-----------------:|
| KD | 71.03 | 70.50 |
| **DKD** | **71.70** | **72.05** |

# MDistiller

### Introduction

MDistiller supports the following distillation methods on CIFAR-100, ImageNet and MS-COCO:

| Method | Paper Link | CIFAR-100 | ImageNet | MS-COCO |
|:---:|:---:|:---:|:---:|:---:|
| KD | | ✓ | ✓ | |
| FitNet | | ✓ | | |
| AT | | ✓ | ✓ | |
| NST | | ✓ | | |
| PKT | | ✓ | | |
| KDSVD | | ✓ | | |
| OFD | | ✓ | ✓ | |
| RKD | | ✓ | | |
| VID | | ✓ | | |
| SP | | ✓ | | |
| CRD | | ✓ | ✓ | |
| ReviewKD | | ✓ | ✓ | ✓ |
| DKD | | ✓ | ✓ | ✓ |

### Installation

Environments:

- Python 3.6
- PyTorch 1.9.0
- torchvision 0.10.0

Install the package:

```
sudo pip3 install -r requirements.txt
sudo python3 setup.py develop
```

### Getting started

0. Wandb as the logger

   - Register a wandb account first.
   - If you don't want wandb as your logger, set `CFG.LOG.WANDB` to `False` in `mdistiller/engine/cfg.py`.

1. Evaluation

   - You can evaluate the performance of our models or models trained by yourself.
   - Please download our released checkpoints to `./download_ckpts`.
   - If testing the models on ImageNet, please download the dataset and put it in `./data/imagenet`.

   ```bash
   # evaluate teachers
   python3 tools/eval.py -m resnet32x4 # resnet32x4 on cifar100
   python3 tools/eval.py -m ResNet34 -d imagenet # ResNet34 on imagenet

   # evaluate students
   python3 tools/eval.py -m resnet8x4 -c download_ckpts/dkd_resnet8x4 # dkd-resnet8x4 on cifar100
   python3 tools/eval.py -m MobileNetV1 -c download_ckpts/imgnet_dkd_mv1 -d imagenet # dkd-mv1 on imagenet
   python3 tools/eval.py -m model_name -c output/your_exp/student_best # your checkpoints
   ```

2. Training on CIFAR-100

   - Download `cifar_teachers.tar` and untar it to `./download_ckpts` via `tar xvf cifar_teachers.tar`.

   ```bash
   # for instance, our DKD method.
   python3 tools/train.py --cfg configs/cifar100/dkd/res32x4_res8x4.yaml

   # you can also change settings at the command line
   python3 tools/train.py --cfg configs/cifar100/dkd/res32x4_res8x4.yaml SOLVER.BATCH_SIZE 128 SOLVER.LR 0.1
   ```

3. Training on ImageNet

   - Download the dataset and put it in `./data/imagenet`.

   ```bash
   # for instance, our DKD method.
   python3 tools/train.py --cfg configs/imagenet/r34_r18/dkd.yaml
   ```

4. Training on MS-COCO

   - See [detection.md](detection/README.md).

5. Extension: Visualizations

   - Jupyter notebooks: [tsne](tools/visualizations/tsne.ipynb) and [correlation_matrices](tools/visualizations/correlation.ipynb)

### Custom Distillation Method

1. Create a Python file at `mdistiller/distillers/` and define the distiller:

   ```python
   from ._base import Distiller

   class MyDistiller(Distiller):
       def __init__(self, student, teacher, cfg):
           super(MyDistiller, self).__init__(student, teacher)
           self.hyper1 = cfg.MyDistiller.hyper1
           ...

       def forward_train(self, image, target, **kwargs):
           # return the output logits and a Dict of losses
           ...

       # rewrite the get_learnable_parameters function if there are more nn modules for distillation.
       # rewrite the get_extra_parameters function if you want to report the extra cost.
   ```

2. Register the distiller in `distiller_dict` at `mdistiller/distillers/__init__.py` (see the sketch after this list).
3. Register the corresponding hyper-parameters at `mdistiller/engine/cfg.py` (see the sketch after this list).
4. Create a new config file and test it.
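As a concrete illustration of steps 2 and 3, the sketch below registers the hypothetical `MyDistiller` from step 1 and adds a default for its hyper-parameter. It assumes the config object is a yacs-style `CfgNode` (consistent with `CFG.LOG.WANDB` above); the exact names and layout should be checked against the existing files.

```python
# Step 2 -- add the new distiller to distiller_dict in mdistiller/distillers/__init__.py:
from .MyDistiller import MyDistiller  # hypothetical module/class created in step 1

distiller_dict["MyDistiller"] = MyDistiller

# Step 3 -- register its hyper-parameters in mdistiller/engine/cfg.py,
# assuming the config is a yacs-style CfgNode named CFG:
from yacs.config import CfgNode as CN

CFG.MyDistiller = CN()
CFG.MyDistiller.hyper1 = 1.0  # default for the hyper-parameter read in MyDistiller.__init__
```

A new YAML config under `configs/` that selects the registered name and overrides its hyper-parameters then completes step 4.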
# Citation

If this repo is helpful for your research, please consider citing the papers:

```BibTeX
@article{zhao2022dkd,
  title={Decoupled Knowledge Distillation},
  author={Zhao, Borui and Cui, Quan and Song, Renjie and Qiu, Yiyu and Liang, Jiajun},
  journal={arXiv preprint arXiv:2203.08679},
  year={2022}
}

@article{zhao2023dot,
  title={DOT: A Distillation-Oriented Trainer},
  author={Zhao, Borui and Cui, Quan and Song, Renjie and Liang, Jiajun},
  journal={arXiv preprint arXiv:2307.08436},
  year={2023}
}
```

# License

MDistiller is released under the MIT license. See [LICENSE](LICENSE) for details.

# Acknowledgement

- Thanks to CRD and ReviewKD. This library is built on the [CRD codebase](https://github.com/HobbitLong/RepDistiller) and the [ReviewKD codebase](https://github.com/dvlab-research/ReviewKD).
- Thanks to Yiyu Qiu and Yi Shi for their code contributions during their internships at MEGVII Technology.
- Thanks to Xin Jin for the discussion about DKD.