Code release for OrthoReg. The main paper can be found at https://arxiv.org/abs/2009.05014.
OrthoReg is a regularization strategy that makes the importance estimates in prior state-of-the-art loss-preservation-based pruning methods (e.g., Fisher pruning) more reliable and robust at large pruning ratios. We provide code for extracting Early-Bird Tickets from, and for iteratively pruning, VGG-13, MobileNet-V1, and ResNet-34.
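Loss-preservation methods such as Fisher pruning score each filter by the first-order change in the loss that removing it would cause, typically via an activation-times-gradient term. The following NumPy sketch illustrates such a score; the function name and tensor shapes are our assumptions for illustration, not the actual API of imp_estimator.py:

```python
import numpy as np

def fisher_importance(activations, grads):
    """Per-channel loss-preservation (Fisher-style) importance score.

    activations, grads: arrays of shape (batch, channels, height, width)
    holding a layer's outputs and the loss gradients w.r.t. those outputs.
    """
    # First-order change in loss from zeroing a channel: sum_{h,w} a * dL/da
    delta = (activations * grads).sum(axis=(2, 3))   # (batch, channels)
    # Square and average over the batch; channels with low scores are
    # the candidates for pruning.
    return (delta ** 2).mean(axis=0)                 # (channels,)
```

Filters whose removal barely perturbs the loss get small scores, which is what makes this family of estimators "loss-preserving."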
The code requires:
Python 3.6 or higher
PyTorch 1.4 or higher
To install requirements (uses pip):

```shell
./requirements.sh
```
The provided modules serve the following purpose:
main.py: Provides functions for training pruned networks in general, including Early-Bird Tickets.
eval.py: Calculates train accuracy, test accuracy, layerwise degree of orthogonality, FLOPs, and compression ratio.
imp_estimator.py: Importance estimators for different methods (Fisher, BN, TFO, GraSP, RDT, L1, and SFP).
pruner.py: Pruning engine (includes pruned networks' classes).
models.py: Model classes for VGG-13, MobileNet-V1, ResNet-34.
config.py: Hyperparameters and progress bar for training models.
Pretrained models should be stored in the directory pretrained; pruned models will be saved in pruned_nets. We provide minimally trained models for the Early-Bird Tickets experiments and fully trained models for the iterative pruning experiments (see below).
To extract Early-Bird Tickets from a model (e.g., VGG-13) on CIFAR-100, run the following command:

```shell
python main.py --ebt=True --model=vgg --pretrained=False --pruning_type=orthoreg --prune_percent=25
```
To iteratively prune a model (e.g., VGG-13) on CIFAR-100, run the following command:

```shell
python main.py --model=vgg --pretrained=False --pruning_type=orthoreg --prune_percent=50 --n_rounds=2
```
Summary of available options for pruning models:
--ebt=<extract_early_bird_tickets>
--model=<model_name>
--pretrained=<use_pretrained_model>
--data_path=<path_to_data>
--pruning_type=<how_to_estimate_importance>
--prune_percent=<percent_of_filters_to_prune>
--n_rounds=<number_of_pruning_rounds>
--thresholds=<manual_thresholds_for_pruning>
--seed=<change_random_seed>
--only_train=<only_train_do_not_prune>
Training Settings: To change number of epochs or the learning rate schedule for training the base models or the pruned models, change the hyperparameters in config.py. By default, models are trained using SGD with momentum (0.9).
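For reference, the SGD-with-momentum update used by default (momentum 0.9) follows the textbook rule sketched below; the function and variable names are ours for illustration, not code from config.py:

```python
def sgd_momentum_step(param, grad, velocity, lr=0.1, momentum=0.9):
    """One classic SGD-with-momentum update for a single parameter."""
    # Accumulate a decaying running sum of gradients...
    velocity = momentum * velocity + grad
    # ...then step the parameter along the accumulated direction.
    param = param - lr * velocity
    return param, velocity
```

In PyTorch this corresponds to constructing the optimizer as torch.optim.SGD(model.parameters(), lr=..., momentum=0.9).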
To evaluate a model (e.g., a pruned VGG-13 model), use:

```shell
python eval.py --model vgg --pruned True --model_path <path_to_model_file> --test_acc True
```
Summary of available options for evaluating models:
--model=<model_name>
--pruned=<evaluating_a_pruned_model>
--model_path=<path_to_model>
--data_path=<path_to_dataset>
--train_acc=<evaluate_train_accuracy>
--test_acc=<evaluate_test_accuracy>
--flops=<evaluate_flops_in_model>
--compression=<evaluate_compression_ratio>
--eval_ortho=<evaluate_degree_of_orthogonality>
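The degree of orthogonality reported by --eval_ortho can be understood as how far a layer's flattened, row-normalized filter matrix is from having an identity Gram matrix. The NumPy sketch below is our own illustrative formulation, not necessarily the exact metric computed in eval.py:

```python
import numpy as np

def orthogonality_gap(weight):
    """Distance of a layer's filters from mutual orthogonality.

    weight: array of shape (out_channels, ...), e.g. a conv kernel.
    Returns the Frobenius norm of (W_hat @ W_hat.T - I), where W_hat has
    unit-norm rows; 0 means the filters are perfectly orthogonal.
    """
    w = weight.reshape(weight.shape[0], -1)
    w = w / np.linalg.norm(w, axis=1, keepdims=True)  # normalize each filter
    gram = w @ w.T                                    # pairwise cosine similarities
    return np.linalg.norm(gram - np.eye(w.shape[0]))
```

A smaller gap indicates filters that carry less redundant information, which is the property OrthoReg's regularizer encourages.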
We provide sample results for our code. Below are the accuracies of Early-Bird Tickets extracted using different pruning methods (OrthoReg, Fisher pruning, and BN-scale based pruning) from VGG-13, MobileNet-V1, and ResNet-34 on the CIFAR-100 dataset:
To replicate these, use seed 0.
Original models:
| Model name | Accuracy |
|---|---|
| ResNet-34 | 73.4% |
| VGG-13 | 65.5% |
| MobileNet-V1 | 67.0% |
OrthoReg:
| Model name | % Pruned | Accuracy | Model name | % Pruned | Accuracy | Model name | % Pruned | Accuracy |
|---|---|---|---|---|---|---|---|---|
| ResNet-34 | 25% | 77.4% | VGG-13 | 25% | 71.4% | MobileNet-V1 | 25% | 67.8% |
| | 50% | 76.7% | | 50% | 71.2% | | 50% | 67.4% |
| | 75% | 74.8% | | 75% | 67.5% | | 75% | 65.8% |
Fisher pruning:
| Model name | % Pruned | Accuracy | Model name | % Pruned | Accuracy | Model name | % Pruned | Accuracy |
|---|---|---|---|---|---|---|---|---|
| ResNet-34 | 25% | 72.7% | VGG-13 | 25% | 65.4% | MobileNet-V1 | 25% | 67.9% |
| | 50% | 72.3% | | 50% | 67.1% | | 50% | 67.6% |
| | 75% | 71.3% | | 75% | 65.3% | | 75% | 65.9% |
BN-scale based:
| Model name | % Pruned | Accuracy | Model name | % Pruned | Accuracy | Model name | % Pruned | Accuracy |
|---|---|---|---|---|---|---|---|---|
| ResNet-34 | 25% | 72.9% | VGG-13 | 25% | 65.5% | MobileNet-V1 | 25% | 67.7% |
| | 50% | 72.1% | | 50% | 65.6% | | 50% | 68.3% |
| | 75% | 70.1% | | 75% | 64.8% | | 75% | 65.5% |
We provide fully trained and minimally trained models that can be pruned using our code.
For Early-Bird Tickets experiments, minimally trained models can be found here:
For iterative pruning, fully trained models can be found here:
To use these models:
Store the fully trained models in the directory pretrained/iterative.
Store the minimally trained models in the directory pretrained/ebt.