This repository contains code to reproduce the results of our paper on novel pruning techniques with robust training, along with checkpoints for the compressed networks. It supports all four robust training objectives: iterative adversarial training, randomized smoothing, MixTrain, and CROWN-IBP.
Following is a snippet of key results, where we show that accounting for the robust training objective in the pruning strategy can lead to large gains in the robustness of pruned networks.
In particular, the improvement arises from letting the robust training objective control which connections to prune. In almost all cases, it prefers to prune certain high-magnitude weights while preserving other small-magnitude weights, which is contrary to the strategy of the well-established least-weight-magnitude (LWM) based pruning.
April 30, 2020: Checkpoints for WRN-28-10, a common network for benchmarking adversarial robustness, 90% pruned with the proposed technique. Benign test accuracy = 88.97%, PGD-50 test accuracy = 62.24%.
May 23, 2020: Our WRN-28-10 network with 90% connection pruning ranks second on the AutoAttack robustness benchmark.
Let's start by installing all dependencies.
pip install -r requirement.txt
We will use train.py for all our experiments on the CIFAR-10 and SVHN datasets. For ImageNet, we will use train_imagenet.py. It provides the flexibility to work with the pre-training, pruning, and fine-tuning steps along with different training objectives.
Arguments:
- --exp-mode: select from pretrain, prune, finetune
- --trainer: benign (base), iterative adversarial training (adv), randomized smoothing (smooth), MixTrain (mixtrain), CROWN-IBP (crown-ibp)
- --dataset: cifar10, svhn, imagenet

Following this work, we modify the convolution layer to have an internal mask. We can use a masked convolution layer with --layer-type=subnet. The argument k refers to the fraction of non-pruned connections (e.g., k=0.01 keeps 1% of the connections, i.e., a 99% pruning ratio).
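The masked layer can be pictured with the minimal PyTorch sketch below. It is not the repository's implementation: the class names TopKMask and MaskedConv2d, the parameter name scores, and the straight-through gradient are assumptions made for illustration. Each weight gets a learnable importance score, and the forward pass keeps only the k fraction of weights whose scores have the largest magnitude.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TopKMask(torch.autograd.Function):
    """Binary mask that keeps the top-k fraction of scores by magnitude (hypothetical)."""

    @staticmethod
    def forward(ctx, scores, k):
        flat = scores.abs().flatten()
        n_keep = max(1, int(k * flat.numel()))
        mask = torch.zeros_like(flat)
        mask[flat.topk(n_keep).indices] = 1.0
        return mask.view_as(scores)

    @staticmethod
    def backward(ctx, grad_output):
        # Assumed straight-through estimator: pass the mask gradient to the scores
        # unchanged; k is a plain float and gets no gradient.
        return grad_output, None


class MaskedConv2d(nn.Conv2d):
    """Hypothetical masked convolution computing conv(x, weight * mask(scores, k))."""

    def __init__(self, *args, k=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.k = k
        # One importance score per weight entry.
        self.scores = nn.Parameter(0.01 * torch.randn_like(self.weight))

    def forward(self, x):
        mask = TopKMask.apply(self.scores, self.k)
        return F.conv2d(x, self.weight * mask, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


# Keep 10% of the 16 * 3 * 3 * 3 = 432 weights, i.e., 43 connections.
conv = MaskedConv2d(3, 16, kernel_size=3, padding=1, k=0.1)
out = conv(torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 16, 32, 32])
```

Because the mask is recomputed from the scores on every forward pass, whether the scores or the weights are trained depends only on which parameters are left trainable.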
In pre-training, we train the networks with k=1, i.e., without pruning. The following example pre-trains a WRN-28-4 network with adversarial training:
python train.py --arch wrn_28_4 --exp-mode pretrain --configs configs/configs.yml --trainer adv --val_method adv --k 1.0
In the pruning step, we freeze the weights of the network and only update the importance scores. The following command prunes the pre-trained WRN-28-4 network to a 99% pruning ratio (k=0.01):
python train.py --arch wrn_28_4 --exp-mode prune --configs configs.yml --trainer adv --val_method adv --k 0.01 --scaled-score-init --source-net pretrained_net_checkpoint_path --epochs 20 --save-dense
It runs 20 epochs of importance-score optimization, starting from the proposed scaled initialization of the importance scores. With --save-dense, it also saves a checkpoint of the pruned network with dense layers, i.e., each layer's mask is multiplied into its weights and then discarded. These dense checkpoints are helpful because they can be loaded directly into a model built from standard torch.nn layers.
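For a single layer, the dense conversion can be sketched as follows. This is an illustrative PyTorch sketch, not the repository's conversion code: the function name to_dense_conv is hypothetical, and it assumes an ungrouped convolution whose binary mask has already been computed from the importance scores.

```python
import torch
import torch.nn as nn


def to_dense_conv(weight, mask, bias=None, **conv_kwargs):
    """Hypothetical helper: fold a binary mask into the weights of a plain nn.Conv2d.

    Assumes groups=1, so weight has shape (out_channels, in_channels, kh, kw).
    """
    out_ch, in_ch, kh, kw = weight.shape
    conv = nn.Conv2d(in_ch, out_ch, (kh, kw), bias=bias is not None, **conv_kwargs)
    with torch.no_grad():
        conv.weight.copy_(weight * mask)  # pruned entries become exact zeros
        if bias is not None:
            conv.bias.copy_(bias)
    return conv


w = torch.randn(8, 4, 3, 3)
m = (torch.rand_like(w) > 0.9).float()  # keep roughly 10% of the connections
dense = to_dense_conv(w, m, padding=1)
# The resulting state_dict holds only standard keys (no mask or score tensors).
print(sorted(dense.state_dict().keys()))  # ['weight']
```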
In the fine-tuning step, we update the non-pruned weights but freeze the importance scores. For correct results, we must select the same pruning ratio (k) as in the pruning step.
python train.py --arch wrn_28_4 --exp-mode finetune --configs configs.yml --trainer adv --val_method adv --k 0.01 --source-net pruned_net_checkpoint_path --save-dense --lr 0.01
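The split between the pruning and fine-tuning steps comes down to which parameters receive gradients. The sketch below illustrates this with a hypothetical helper, set_trainable; it assumes importance-score parameters have names ending in "scores", which may not match the repository's naming.

```python
import torch
import torch.nn as nn


def set_trainable(model: nn.Module, exp_mode: str) -> None:
    """Hypothetical helper: choose which parameters receive gradients.

    prune:    train importance scores only; weights and biases are frozen
    finetune: train weights and biases; importance scores are frozen
    """
    for name, param in model.named_parameters():
        is_score = name.endswith("scores")  # assumed naming convention
        if exp_mode == "prune":
            param.requires_grad = is_score
        elif exp_mode == "finetune":
            param.requires_grad = not is_score
        else:
            raise ValueError(f"unsupported exp_mode: {exp_mode}")


# Toy model: a linear layer with an attached (hypothetical) score tensor.
layer = nn.Linear(4, 4)
layer.scores = nn.Parameter(torch.ones_like(layer.weight))
model = nn.Sequential(layer)

set_trainable(model, "finetune")
# Hand only the trainable parameters to the optimizer.
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.01)
print([(n, p.requires_grad) for n, p in model.named_parameters()])
# [('0.weight', True), ('0.bias', True), ('0.scores', False)]
```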
For the LWM baseline, we use a single-shot pruning approach, where we prune the desired number of connections after pre-training in a single step. After that, the network is fine-tuned with a configuration similar to the one above.
python train.py --arch wrn_28_4 --exp-mode finetune --configs configs.yml --trainer adv --val_method adv --k 0.01 --source-net pretrained_net_checkpoint_path --save-dense --lr 0.01 --scaled-score-init
The only difference from the previous fine-tuning command is that we now start from the pre-trained checkpoint and initialize the importance scores with the proposed scaling (--scaled-score-init). This scheme initially prunes the connections with the lowest magnitude. Since the importance scores are not updated during fine-tuning, this effectively implements LWM-based pruning.
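Why frozen, scaled scores reproduce LWM can be checked numerically. The NumPy sketch below assumes that the scaled initialization sets each score proportional to its weight, s = sqrt(6 / fan_in) * w / max|w| (the exact constant is an assumption here), and that the mask keeps the k fraction of entries with the largest |score|. The helper names are hypothetical.

```python
import numpy as np


def scaled_score_init(w, fan_in):
    # Assumed form: scores proportional to the weights, scaled by 1 / max|w|.
    return np.sqrt(6.0 / fan_in) * w / np.abs(w).max()


def topk_mask(x, k):
    """Keep the k fraction of entries with the largest magnitude."""
    flat = np.abs(x).ravel()
    n_keep = max(1, int(k * flat.size))
    mask = np.zeros(flat.size)
    mask[np.argsort(flat)[-n_keep:]] = 1.0
    return mask.reshape(x.shape)


rng = np.random.default_rng(0)
w = rng.standard_normal((16, 3, 3, 3))  # conv weight with fan_in = 3 * 3 * 3
scores = scaled_score_init(w, fan_in=27)

# The scaling is one positive constant per layer, so ranking by |score| is the
# same as ranking by |weight|: the initial mask is exactly the LWM mask.
assert np.array_equal(topk_mask(scores, 0.01), topk_mask(w, 0.01))
```

The proposed technique differs only in that the scores are then optimized during the pruning step, which is what lets it move away from this magnitude-based mask.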
We can use the following scripts to obtain compact networks from both LWM and the proposed pruning technique.
- get_compact_net_adv_train.sh: Compact networks with iterative adversarial training.
- get_compact_net_rand_smoothing.sh: Compact networks with randomized smoothing.
- get_compact_net_mixtrain.sh: Compact networks with MixTrain.
- get_compact_net_crown-ibp.sh: Compact networks with CROWN-IBP.

It is natural to ask whether pruning certain connections can by itself induce robustness in a network. In particular, given a non-robust network, does there exist a highly robust subnetwork? We find that such subnetworks do exist and achieve a non-trivial amount of robustness. Here is an example to reproduce these results:
python train.py --arch wrn_28_4 --configs configs.yml --trainer adv --val_method adv --k 0.5 --source-net pretrained_non-robust-net_checkpoint_path
Given the checkpoint path of a non-robust network, this command searches for a subnetwork with half of the connections (k=0.5) that has high empirical robust accuracy. We can similarly optimize for verifiable robust accuracy by selecting --trainer from smooth | mixtrain | crown-ibp, using the appropriate configs for each.
We are releasing pruned models for all three pruning ratios (90%, 95%, 99%) for all three datasets used in the paper. In case you want to compare an additional property of pruned models with a baseline, we are also releasing the non-pruned, i.e., pre-trained, networks. Note that we use input normalization only for the ImageNet dataset. For each model, we release two checkpoints: one with masked layers and the other with dense layers. The numbers from these checkpoints might differ slightly from those reported in the paper.
Dataset | Architecture | Pre-trained (0%) | 90% pruned | 95% pruned | 99% pruned |
---|---|---|---|---|---|
CIFAR10 | VGG16 | ckpt | ckpt | ckpt | ckpt |
CIFAR10 | WRN-28-4 | ckpt | ckpt | ckpt | ckpt |
SVHN | VGG16 | ckpt | ckpt | ckpt | ckpt |
SVHN | WRN-28-4 | ckpt | ckpt | ckpt | ckpt |
Dataset | Architecture | Pre-trained (0%) | 90% pruned | 95% pruned | 99% pruned |
---|---|---|---|---|---|
CIFAR10 | VGG16 | ckpt | ckpt | ckpt | ckpt |
CIFAR10 | WRN-28-4 | ckpt | ckpt | ckpt | ckpt |
SVHN | VGG16 | ckpt | ckpt | ckpt | ckpt |
SVHN | WRN-28-4 | ckpt | ckpt | ckpt | ckpt |
Pre-trained (0%) | 95% pruned | 99% pruned |
---|---|---|
ckpt | ckpt | ckpt |
Some of the code in this repository is based on the following amazing works.
If you find this work helpful, consider citing it.
@article{sehwag2020hydra,
title={Hydra: Pruning adversarially robust neural networks},
author={Sehwag, Vikash and Wang, Shiqi and Mittal, Prateek and Jana, Suman},
journal={Advances in Neural Information Processing Systems},
volume={33},
year={2020}
}