This project is a utility wrapper around PyTorch and Brevitas that lets you specify the network, dataset and parameters to train with.
You can run the training of LeNet on the MNIST dataset as follows:
$ PYTORCH_JIT=1 python nn_benchmark/main.py --network LeNet --dataset MNIST --epochs 3
Or its quantized version QuantLeNet on the CIFAR-10 dataset with bit-widths of (4,4,8) (corresponding to activation, weight and input):
$ PYTORCH_JIT=1 python nn_benchmark/main.py --network QuantLeNet --dataset CIFAR10 --epochs 3 \
  --acq 4 --weq 4 --inq 8
The results can be observed under the experiments folder.
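The --acq, --weq and --inq values are the bit widths handed over to the Brevitas quantized layers. Below is a minimal, illustrative sketch of how a (4,4,8) configuration could map onto the Brevitas API; the actual QuantLeNet definition in this project may be organised differently.

import torch.nn as nn
import brevitas.nn as qnn

class TinyQuantNet(nn.Module):
    # Sketch only: shows where the activation, weight and input bit widths go.
    def __init__(self, act_bit_width=4, weight_bit_width=4, in_bit_width=8):
        super().__init__()
        # Input quantized with a higher precision (8 bits here)
        self.quant_in = qnn.QuantIdentity(bit_width=in_bit_width)
        # Convolution with weights quantized to 4 bits
        self.conv = qnn.QuantConv2d(1, 6, kernel_size=5,
                                    weight_bit_width=weight_bit_width)
        # Activation quantized to 4 bits
        self.relu = qnn.QuantReLU(bit_width=act_bit_width)

    def forward(self, x):
        return self.relu(self.conv(self.quant_in(x)))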
You can then evaluate your network with the following command:
$ python nn_benchmark/main.py --network LeNet --dataset MNIST --evaluate --resume ./experiments/<your_folder>/checkpoints/best.tar
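If you want to inspect a saved checkpoint outside of the CLI, a small sketch such as the one below can help. The key names inside the checkpoint ("state_dict", "epoch", ...) are assumptions and may not match what the Trainer actually stores.

import torch

# Load the checkpoint produced by the training run (path placeholder as above)
checkpoint = torch.load("experiments/<your_folder>/checkpoints/best.tar",
                        map_location="cpu")
print(checkpoint.keys())  # inspect what was actually saved
# model.load_state_dict(checkpoint["state_dict"])  # if a state dict is present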
The following networks are supported: LeNet, LeNet5, MobilenetV1, VGG11, VGG13, VGG16 and VGG19.
Their quantized counterparts are available as well, prefixed with Quant (e.g. QuantLeNet).
The following datasets can be used: MNIST, CIFAR-10 and GTSRB.
I worked on the project in a virtual environment created with virtualenvwrapper and I highly recommend doing so as well. However, whether or not you are in a virtual environment, the installation proceeds as follows:
To download and install the source code of the project:
$ cd <directory you want to install to>
$ git clone https://github.com/QDucasse/nn_benchmark
$ cd nn_benchmark
$ python setup.py install
To download and install the source code of the project in a new virtual environment:
Download of the source code & Creation of the virtual environment
$ cd <directory you want to install to>
$ git clone https://github.com/QDucasse/nn_benchmark
$ cd nn_benchmark
$ mkvirtualenv -a . -r requirements.txt VIRTUALENV_NAME
Launch of the environment & installation of the project
$ workon VIRTUALENV_NAME
$ pip install -e .
Finally, whether you chose the first or the second option, you will need Brevitas if you want to use quantized networks. The installation is best performed from source and can be done as follows (in your native or virtual environment):
$ git clone https://github.com/Xilinx/brevitas.git
$ cd brevitas
$ pip install -e .
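Once this is done, a quick sanity check from Python confirms that Brevitas and its quantized layers are available (nothing project-specific here):

import brevitas
import brevitas.nn as qnn

print(brevitas.__file__)   # where Brevitas was installed
print(qnn.QuantConv2d)     # one of the quantized layer classes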
Quick presentation of the different modules of the project:
- Core logic: the CLI, Logger, Plotter and Trainer.
- PyTorch extensions: this package contains specific modules (e.g. TensorNorm), datasets (e.g. GTSRB) or loss functions (e.g. SquaredHinge) that can be used as drop-in replacements for their PyTorch homologues (see the sketch after this list).
- Networks: selected through the --network flag. If the prefix Quant is present, the network is quantized and the three following precisions can be specified: weight_bit_width, the precision of the weights; act_bit_width, the precision of the activation functions; in_bit_width, the input precision (useful to keep a higher precision at the beginning).
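As mentioned in the list above, here is a minimal, purely illustrative sketch of what a drop-in loss such as SquaredHinge could look like as a regular PyTorch module. It is not the project's actual implementation, only an example of the kind of extension the package provides.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SquaredHingeLoss(nn.Module):
    # Illustrative only: NOT the project's SquaredHinge implementation.
    def forward(self, outputs, targets):
        # Turn integer class labels into +1/-1 one-vs-all targets
        one_hot = F.one_hot(targets, num_classes=outputs.size(1)).float()
        signs = 2.0 * one_hot - 1.0
        # Squared hinge: mean of max(0, 1 - y * f(x))^2
        margins = torch.clamp(1.0 - signs * outputs, min=0.0)
        return (margins ** 2).mean()

Such a module is used exactly like torch.nn.CrossEntropyLoss, e.g. criterion = SquaredHingeLoss(); loss = criterion(logits, labels).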
This project uses several external libraries, notably PyTorch and Brevitas. If you installed the project as specified above, the requirements are listed in the requirements.txt file and therefore automatically installed. However, you can install each of them separately with the following command (except for brevitas, which should be installed from source as described at the end of the installation paragraph):
$ pip install <library>
- Trainer/Logger logic
- LeNet network training on MNIST
- LeNet5, MobilenetV1, VGG11, VGG13, VGG16 and VGG19 network architectures
- PyTorch extensions with the GTSRB dataset and another loss function (SqrHinge)
- Plotter
All tests are written to work with nose and/or pytest. Just run pytest or nosetests from the project root. Every test file can still be launched individually by executing the test file itself:
$ python nn_benchmark/tests/chosentest.py
$ pytest