# Lottery Ticket Hypothesis in Pytorch

This repository contains a PyTorch implementation of the paper *The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks* by Jonathan Frankle and Michael Carbin that can be easily adapted to any model/dataset.
## Requirements

```bash
pip3 install -r requirements.txt
```

## How to run the code

```bash
python3 main.py --prune_type=lt --arch_type=fc1 --dataset=mnist --prune_percent=10 --prune_iterations=35
```
### Options

- `--prune_type` : Type of pruning. Options: `lt` (Lottery Ticket Hypothesis), `reinit` (random reinitialization). Default: `lt`
- `--arch_type` : Type of architecture. Options: `fc1` (simple fully connected network), `lenet5` (LeNet5), `alexnet` (AlexNet), `resnet18` (Resnet18), `vgg16` (VGG16). Default: `fc1`
- `--dataset` : Choice of dataset. Options: `mnist`, `fashionmnist`, `cifar10`, `cifar100`. Default: `mnist`
- `--prune_percent` : Percentage of weights to be pruned after each cycle. Default: `10`
- `--prune_iterations` : Number of pruning cycles to perform. Default: `35`
- `--lr` : Learning rate. Default: `1.2e-3`
- `--batch_size` : Batch size. Default: `60`
- `--end_iter` : Number of epochs. Default: `100`
- `--print_freq` : Frequency of printing accuracy and loss. Default: `1`
- `--valid_freq` : Frequency of validation. Default: `1`
- `--gpu` : Which GPU the program should use. Default: `0`
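Since each pruning cycle removes `prune_percent`% of the *remaining* weights, the surviving fraction shrinks geometrically with `prune_iterations`. A quick sketch of the arithmetic for the default settings (plain Python, not part of the repository):

```python
# Fraction of weights remaining after iterative pruning:
# each cycle keeps (1 - prune_percent/100) of what is left.
prune_percent = 10      # default --prune_percent
prune_iterations = 35   # default --prune_iterations

remaining = (1 - prune_percent / 100) ** prune_iterations
print(f"Weights remaining after {prune_iterations} cycles: {remaining:.2%}")
# roughly 2.5% of the original weights survive
```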
## How to add a new architecture

Example: adding a `new_model` with `mnist` dataset compatibility.

- Go to the `/archs/mnist/` directory and create a file `new_model.py`.
- Write your model inside `new_model.py`. Make sure that the number of input and output nodes of `new_model.py` matches the corresponding dataset that you are adding it for (in this case, `mnist`).
- Open `main.py`, go to line 36, and look for the comment `# Data Loader`. Now find your corresponding dataset (in this case, `mnist`) and add `new_model` at the end of the line `from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet`.
- Go to line 82 and add the following to it:

```python
elif args.arch_type == "new_model":
    model = new_model.new_model_name().to(device)
```

Here, `new_model_name()` is the name of the model that you have given inside `new_model.py`.
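For illustration, a minimal `new_model.py` might look like the following. This is a hypothetical sketch, not code from the repository: the class name `new_model_name` and the layer sizes are placeholders chosen for MNIST's 28×28 single-channel inputs and 10 classes.

```python
import torch
import torch.nn as nn

class new_model_name(nn.Module):
    """Hypothetical fully connected model for MNIST (28x28 inputs, 10 classes)."""
    def __init__(self):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(28 * 28, 300),  # input nodes must match the dataset
            nn.ReLU(inplace=True),
            nn.Linear(300, 10),       # output nodes = number of classes
        )

    def forward(self, x):
        x = torch.flatten(x, 1)  # (N, 1, 28, 28) -> (N, 784)
        return self.classifier(x)
```

The number of input and output nodes is exactly what must agree with the dataset, as described above.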
## How to add a new dataset

Example: adding a `new_dataset` with `fc1` architecture compatibility.

- Go to `/archs` and create a directory named `new_dataset`.
- Inside it, add a file named `fc1.py`, or copy-paste it from an existing dataset folder. Make sure that the number of input and output nodes of the architecture matches the corresponding dataset that you are adding (in this case, `new_dataset`).
- Open `main.py`, go to line 58, and add the following to it:

```python
elif args.dataset == "new_dataset":
    traindataset = datasets.new_dataset('../data', train=True, download=True, transform=transform)
    testdataset = datasets.new_dataset('../data', train=False, transform=transform)
    from archs.new_dataset import fc1
```

Note that, as of now, you can only add datasets that are natively available in PyTorch.
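For context, the `traindataset`/`testdataset` objects above are standard `torch.utils.data.Dataset` instances that `main.py` then wraps in a `DataLoader`. A self-contained sketch of that pattern, using an in-memory `TensorDataset` as a stand-in so nothing needs to be downloaded (variable names here are illustrative, not copied from `main.py`):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Stand-in for a torchvision dataset: 100 fake 28x28 grayscale images with labels.
images = torch.randn(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
traindataset = TensorDataset(images, labels)

# main.py-style loader; batch_size mirrors the --batch_size option (default 60).
train_loader = DataLoader(traindataset, batch_size=60, shuffle=True)

batch_images, batch_labels = next(iter(train_loader))
print(batch_images.shape)  # torch.Size([60, 1, 28, 28])
```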
## How to combine the plots of various `prune_type`s

- Go to `combine_plots.py` and add/remove the datasets/archs whose combined plot you want to generate (assuming that you have already executed the `main.py` code for those datasets/archs and produced the weights).
- Run:

```bash
python3 combine_plots.py
```

- Go to `./plots/lt/combined_plots/` to see the graphs.

Kindly raise an issue if you have any problem with the instructions.
| | fc1 | LeNet5 | AlexNet | VGG16 | Resnet18 |
|---|---|---|---|---|---|
| MNIST | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| CIFAR10 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| FashionMNIST | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| CIFAR100 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
## Repository structure

```
Lottery-Ticket-Hypothesis-in-Pytorch
├── archs
│   ├── cifar10
│   │   ├── AlexNet.py
│   │   ├── densenet.py
│   │   ├── fc1.py
│   │   ├── LeNet5.py
│   │   ├── resnet.py
│   │   └── vgg.py
│   ├── cifar100
│   │   ├── AlexNet.py
│   │   ├── fc1.py
│   │   ├── LeNet5.py
│   │   ├── resnet.py
│   │   └── vgg.py
│   └── mnist
│       ├── AlexNet.py
│       ├── fc1.py
│       ├── LeNet5.py
│       ├── resnet.py
│       └── vgg.py
├── combine_plots.py
├── dumps
├── main.py
├── plots
├── README.md
├── requirements.txt
├── saves
└── utils.py
```
Parts of the code were borrowed from ktkth5.

Open a new issue or submit a pull request in case you are facing any difficulty with the codebase or if you want to contribute to it.