EkdeepSLubana / OrthoReg

MIT License

OrthoReg: Robust Network Pruning Using Orthogonality Regularization

Code release for OrthoReg. The main paper can be found at https://arxiv.org/abs/2009.05014.

Brief Summary

OrthoReg is a regularization strategy that makes the importance estimates used by prior state-of-the-art loss-preservation pruning methods (e.g., Fisher pruning) more reliable and robust at large pruning ratios. We provide code both for extracting Early-Bird Tickets from, and for iteratively pruning, VGG-13, MobileNet-V1, and ResNet-34.
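To make the idea of an orthogonality regularizer concrete, here is a minimal sketch of a generic soft orthogonality penalty, the squared Frobenius norm of W W^T - I, where each row of W is one flattened filter. This is an illustration only: the exact formulation used by OrthoReg is given in the paper and is not reproduced here, and the function names `gram` and `orthogonality_penalty` are hypothetical, not part of this repository.

```python
def gram(rows):
    """Return the Gram matrix G[i][j] = <rows[i], rows[j]>."""
    return [[sum(a * b for a, b in zip(r_i, r_j)) for r_j in rows] for r_i in rows]


def orthogonality_penalty(filters):
    """Squared Frobenius norm of (W W^T - I), one flattened filter per row.

    Generic soft-orthogonality penalty (an assumption for illustration),
    not necessarily the regularizer implemented in this repository.
    """
    g = gram(filters)
    n = len(filters)
    return sum(
        (g[i][j] - (1.0 if i == j else 0.0)) ** 2
        for i in range(n)
        for j in range(n)
    )


# Two orthonormal filters: the Gram matrix is the identity, so no penalty.
orthonormal = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
# Two identical filters: the off-diagonal Gram entries are 1, so they are penalized.
correlated = [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]

penalty_orthonormal = orthogonality_penalty(orthonormal)
penalty_correlated = orthogonality_penalty(correlated)
print(penalty_orthonormal, penalty_correlated)
```

In a training loop, a penalty of this kind would typically be added to the task loss with a small weight, so that filters are pushed toward being decorrelated before importance is estimated.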

The code requires:

Requirements

To install requirements (uses pip):

./requirements.sh

Organization

The provided modules serve the following purpose:

Store pretrained models in the directory pretrained; pruned models are saved to pruned_nets. We provide minimally trained models for the Early-Bird Tickets experiments and fully trained models for the iterative pruning experiments (see below).

Extracting Early-Bird Tickets

To extract Early-Bird Tickets from a model (e.g., VGG-13) on CIFAR-100, run the following command:

python main.py --ebt=True --model=vgg --pretrained=False --pruning_type=orthoreg --prune_percent=25

Iterative Pruning

To iteratively prune a model (e.g., VGG-13) on CIFAR-100, run the following command:

python main.py --model=vgg --pretrained=False --pruning_type=orthoreg --prune_percent=50 --n_rounds=2

Summary of available options

--ebt=<extract_early_bird_tickets>

--model=<model_name>

--pretrained=<use_pretrained_model>

--data_path=<path_to_data>

--pruning_type=<how_to_estimate_importance>

--prune_percent=<how_much_percent_filters_to_prune>

--n_rounds=<number_of_pruning_rounds>

--thresholds=<manual_thresholds_for_pruning>

--seed=<change_random_seed>

--only_train=<only_train_do_not_prune>
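When running several configurations, it can be convenient to assemble the command line programmatically. The sketch below is a hypothetical helper (`build_command` is not part of this repository) that turns keyword arguments into the `--name=value` form used by the commands shown earlier; it reproduces the Early-Bird Ticket command and only prints it.

```python
import shlex


def build_command(script, **options):
    """Build an argv list such as ['python', 'main.py', '--model=vgg', ...]."""
    argv = ["python", script]
    for name, value in options.items():
        # Every option is passed in the --name=value form used in this README.
        argv.append(f"--{name}={value}")
    return argv


ebt_cmd = build_command(
    "main.py",
    ebt=True,
    model="vgg",
    pretrained=False,
    pruning_type="orthoreg",
    prune_percent=25,
)
print(shlex.join(ebt_cmd))

# To actually launch the run from the repository root (not executed here):
# import subprocess
# subprocess.run(ebt_cmd, check=True)
```

Keeping the command as a list rather than a single string avoids shell-quoting problems if a value such as a data path contains spaces.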

Training Settings: To change the number of epochs or the learning rate schedule used to train the base models or the pruned models, edit the hyperparameters in config.py. By default, models are trained using SGD with momentum (0.9).
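For readers unfamiliar with the optimizer, the following sketch shows a single SGD-with-momentum update in the form commonly used by deep learning frameworks (velocity first, then parameters). Only the momentum value 0.9 comes from this README; the learning rate, the toy objective, and the function name `sgd_momentum_step` are assumptions for illustration, and the real settings live in config.py.

```python
def sgd_momentum_step(params, grads, velocity, lr, momentum=0.9):
    """One update: v <- momentum * v + g, then p <- p - lr * v."""
    new_velocity = [momentum * v + g for v, g in zip(velocity, grads)]
    new_params = [p - lr * v for p, v in zip(params, new_velocity)]
    return new_params, new_velocity


# Toy problem (assumption): minimize f(p) = p^2, whose gradient is 2p.
params, velocity = [1.0], [0.0]
for _ in range(200):
    grads = [2.0 * p for p in params]
    params, velocity = sgd_momentum_step(params, grads, velocity, lr=0.1)

print(params[0])  # close to the minimizer at 0
```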

Evaluation

To evaluate a model (e.g., a pruned VGG-13 model), use:

python eval.py --model vgg --pruned True --model_path <path_to_model_file> --test_acc True

Summary of available options for evaluating models:

--model=<model_name>

--pruned=<evaluating_a_pruned_model>

--model_path=<path_to_model>

--data_path=<path_to_dataset>

--train_acc=<evaluate_train_accuracy>

--test_acc=<evaluate_test_accuracy>

--flops=<evaluate_flops_in_model>

--compression=<evaluate_compression_ratio>

--eval_ortho=<evaluate_degree_of_orthogonality>
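As a rough guide to what the FLOPs and compression options report, the sketch below counts parameters and multiply-accumulates for a single convolution before and after filter pruning. The definitions used here (compression ratio as original parameters divided by remaining parameters, cost measured in multiply-accumulates) and the layer sizes are assumptions for illustration; the exact definitions used by eval.py are not shown in this README.

```python
def conv_params(c_in, c_out, k):
    """Number of weights in a k x k convolution without bias."""
    return c_in * c_out * k * k


def conv_macs(c_in, c_out, k, h_out, w_out):
    """Multiply-accumulates for one forward pass over an h_out x w_out output map."""
    return conv_params(c_in, c_out, k) * h_out * w_out


# Hypothetical layer: 64 -> 128 channels, 3x3 kernel, 32x32 output map.
dense_params = conv_params(64, 128, 3)
dense_macs = conv_macs(64, 128, 3, 32, 32)

# Remove 50% of the output filters of this layer.
pruned_params = conv_params(64, 64, 3)
pruned_macs = conv_macs(64, 64, 3, 32, 32)

compression_ratio = dense_params / pruned_params
mac_reduction = dense_macs / pruned_macs
print(dense_params, pruned_params, compression_ratio, mac_reduction)
```

Note that removing filters from one layer also removes the matching input channels of the next layer, so whole-network savings are usually larger than a single-layer count suggests.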


Results

We provide sample results for our code. The following shows the performance of Early-Bird Tickets drawn using different pruning methods (OrthoReg, Fisher pruning, and BN-scale based pruning) from VGG-13, MobileNet-V1, and ResNet-34 models on the CIFAR-100 dataset:

To replicate these, use seed 0.

Pre-trained Models

We provide fully trained and minimally trained models that can be pruned using our code.

For Early-Bird Tickets experiments, minimally trained models can be found here:

For iterative pruning, fully trained models can be found here:

To use these models:

Note