
Knapsack Pruning with Inner Distillation

This code has been adapted from the excellent Ross Wightman repository (pytorch-image-models), which we used to train our models, and it reuses several features from that repository.

It is the implementation of our paper, available on arXiv, with several improvements.

For now, our code supports the pruning of the following networks:

Below, we describe how to use this repository.

Train the base model

To train the base model, you can use the train.py file; all the instructions can be found on the main page of the Ross Wightman repository, so we skip this part here.

Pruning the model

The code to prune the model is in the file train_pruning.py. We will go over each of its parameters. Let's start with a command that reproduces the pruning of 41% of the FLOPS of ResNet-50, reaching 78.54% final accuracy. For this, we start from the latest Ross Wightman checkpoint, which achieves 79% accuracy on the test set. For now, our code supports only PyTorch distributed data parallel training, not PyTorch data parallel.

python -u -m torch.distributed.launch --nproc_per_node=8 \
--nnodes=1 \
--node_rank=0 \
./train_pruning.py \
/data/imagenet/ \
-b=192 \
--amp \
--model=resnet50 \
--lr=0.02 \
--sched=cosine \
-bp=128 \
--pruning_ratio=0.27 \
--prune \
--prune_skip \
--gamma_knowledge=20 \
--epochs=50

Let's go over the parameters:

- -b: the per-GPU batch size used for training.
- --amp: enables automatic mixed-precision training.
- --model: the architecture to prune.
- --lr: the learning rate.
- --sched=cosine: uses a cosine learning-rate schedule.
- -bp: the per-GPU batch size used during the pruning phase.
- --pruning_ratio: the target pruning ratio.
- --prune: enables pruning.
- --prune_skip: also allows pruning of the skip connections.
- --gamma_knowledge: the weight of the inner-distillation loss.
- --epochs: the number of training epochs.

Loading a pruned model

Suppose that you have trained and pruned a model, and would like to fine-tune it or load it in another repository. The function load_module_from_ckpt, located in external.utils_pruning, can adapt an unpruned model to a pruned checkpoint. You need to provide the original model as the first parameter of the function, and the path of the pruned checkpoint as the second parameter. The function compares the number of channels of the convolutions, batch-norm layers, and fully connected layers of the unpruned model with those in the pruned checkpoint, and adapts the model accordingly. You can also prune an already-pruned model by using the parameter --initial-checkpoint-pruned in the train_pruning.py script.
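
For example, here is a minimal sketch of loading a pruned ResNet-50 checkpoint. It assumes the timm-style create_model factory that this repository is built on, that the function returns the adapted module, and the checkpoint path is a placeholder:

```python
from timm import create_model
from external.utils_pruning import load_module_from_ckpt

# Build the original, unpruned architecture.
model = create_model('resnet50', pretrained=False)

# Adapt the unpruned model to the channel counts stored in the pruned
# checkpoint and load its weights ('pruned_resnet50.pth' is a placeholder path).
model = load_module_from_ckpt(model, 'pruned_resnet50.pth')
```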

Pretrained checkpoints

All of the pretrained checkpoints for EfficientNet and ResNet are located in: TODO add checkpoints

In particular, for ResNet-50 you have four checkpoints:

Benchmark

The model checkpoints have been integrated into the repository; you can use the --pretrained option to load them.

For example, with --model=efficientnet_b1_pruned --pretrained, the model will be loaded with its pretrained weights. For EfficientNet-B0, the pretrained weights can be found at this link.
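
The pruned models can also be created directly in Python. Here is a minimal sketch using the timm-style create_model factory (the dummy input is only a sanity check):

```python
import torch
from timm import create_model

# Create the pruned EfficientNet-B1 and load its pretrained weights.
model = create_model('efficientnet_b1_pruned', pretrained=True)
model.eval()

# Sanity check: run a dummy ImageNet-sized batch through the model.
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```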

| Model to prune | Pruning ratio | Unpruned accuracy | Pruned accuracy |
|----------------|---------------|-------------------|-----------------|
| EfficientNet B0 | 46.00% | 77.30% | 75.50% |
| EfficientNet B1 | 44.28% | 79.20% | 78.30% |
| EfficientNet B2 | 30.00% | 80.30% | 79.90% |
| EfficientNet B3 | 44.00% | 81.70% | 80.80% |

In addition, we have added pruned ECA-ResNet-D models, which we have integrated into the original Ross Wightman repository.

| Model | Accuracy@top1 | Inference speed on V100 (img/sec) | Inference speed on P100 (img/sec) | FLOPS (G) | Model name |
|-------|---------------|-----------------------------------|-----------------------------------|-----------|------------|
| ECA-ResNet Light | 80.47% | 2915 | 862 | 4.11 | ecaresnetlight |
| ECA-ResNet-50D | 80.61% | 2400 | 718 | 4.35 | ecaresnet50d |
| ECA-ResNet-50D Pruned | 79.72% | 3600 | 1200 | 2.53 | ecaresnet50d_pruned |
| ECA-ResNet-101D | 82.19% | 1476 | 444 | 8.07 | ecaresnet101d |
| ECA-ResNet-101D Pruned | 80.86% | 2800 | 1010 | 3.47 | ecaresnet101d_pruned |

In order to reproduce the pruning of ECA-ResNet-50D based on inference time instead of FLOPS (the --use_time flag), we can use the following command:

python -u -m torch.distributed.launch --nproc_per_node=8 \
--nnodes=1 \
--node_rank=0 \
./train_pruning.py \
/data/imagenet/ \
-b=128 \
--amp \
--pretrained \
-j=8 \
--model=ecaresnet50d \
--lr=0.06 \
--sched=cosine \
-bp=100 \
--pruning_ratio=0.42 \
--use_time \
--prune \
--prune_skip \
--prune_conv1 \
--gamma_knowledge=30 \
--epochs=200 \
--smoothing=0