This is my implementation of the approach to neural network interpretability described in the paper "Distilling a Neural Network Into a Soft Decision Tree" by Nicholas Frosst and Geoffrey Hinton.
My attempt to replicate the results reported in the paper, along with a demonstration of how this implementation can be used on the MNIST dataset for training an NN model, distilling it into a Soft Binary Decision Tree (SBDT) model, and visualizing it, can be found in mnist.ipynb.
The remaining content is documented in the table below (and, hopefully, some of it also by itself).
Location | Content description |
---|---|
`models/tree.py` | Implementation of the SBDT model in `tf.keras`, with all details as stated in the paper. Some parts, such as the loss regularization term, are computed in pure TensorFlow; the rest is encapsulated in Keras custom layers. |
 | Because of Keras' limited flexibility, the SBDT model cannot be saved with Keras' serialization methods, so `tf.Saver` is used instead. This also means that the Keras `ModelCheckpoint` callback won't work with this implementation (unless the model is rewritten to avoid using `tf.Tensor` objects as Keras `Layer` arguments). |
 | Because moving averages are used to compute the penalty terms, a custom two-step initialization of model parameters is required, and model training (evaluation of the `tree.loss` TensorFlow graph op) depends on `batch_size`. Consequently, `train_data_size % batch_size == 0` must hold; otherwise a shape mismatch occurs at the end of a training epoch (Keras feeds the remainder as a smaller minibatch). |
`models/convnet.py` | Implementation of a basic convolutional NN model in `tf.keras`, following the Keras MNIST-CNN basic example. |
`models/utils.py` | Utility functions for resetting the TensorFlow session and visualizing model parameters. |
`makegif.sh` | Converts a directory of images into an animation and labels frames based on the folder name and file names. See `mnist.ipynb` for example usage. |
`assets` | Saved model checkpoints (for easier reproducibility) and illustrative images / animations. |
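As background for the model implemented in `models/tree.py`, the following is a minimal NumPy sketch of the core mechanism from the paper: each inner node of the soft decision tree learns a filter and bias, routes an input to its right child with a sigmoid probability, and each leaf's path probability is the product of the routing probabilities along its path. The function names and breadth-first node indexing here are illustrative, not the repository's actual API.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def routing_probability(x, w, b, beta=1.0):
    """Probability of taking the right branch at an inner node.

    Each inner node learns a filter w and bias b; the paper also
    scales the logit by an inverse temperature beta to sharpen
    otherwise overly soft decisions.
    """
    return sigmoid(beta * (x @ w + b))

def leaf_path_probabilities(x, filters, biases, depth):
    """Path probability of every leaf in a full binary tree.

    filters/biases are indexed by inner node in breadth-first
    order: node i has children 2*i + 1 and 2*i + 2.
    """
    n_leaves = 2 ** depth
    probs = np.ones(n_leaves)
    for leaf in range(n_leaves):
        node = 0
        for level in range(depth):
            # read the leaf's path from its bit pattern, root first
            go_right = (leaf >> (depth - 1 - level)) & 1
            p_right = routing_probability(x, filters[node], biases[node])
            probs[leaf] *= p_right if go_right else (1.0 - p_right)
            node = 2 * node + 1 + go_right
    # path probabilities over all leaves sum to 1
    return probs
```

At prediction time the paper uses the distribution stored at the leaf with the highest path probability, so the tree's decision can be explained by a single root-to-leaf path.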
Git Large File Storage (git-lfs) is used to store model checkpoints and large GIF animation files in the repo. To install it, run

```
git lfs install
```
The code was tested on Python 3.5 with TensorFlow 1.10.0 and matplotlib 2.1.0.
To install the required Python packages, run

```
pip3 install -r requirements.txt
```
For GPU support, use

```
pip3 install -r requirements_gpu.txt
```
The table below summarizes the results produced and presented in mnist.ipynb. No exhaustive hyperparameter search was performed, so there is room for improvement.
Model | Depth | Labels | Batch size | Epochs | Accuracy |
---|---|---|---|---|---|
ConvNet | - | hard | 16 | 12 | 99.29% |
Tree (normal) | 4 | hard | 4 | 40 | 80.88% |
Tree (distill) | 4 | soft | 4 | 40 | 90.71% |
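The "soft" labels in the distilled run are the trained ConvNet's predicted class distributions, used as targets in place of one-hot labels. A common way to produce such targets, following Hinton et al.'s original distillation work, is a temperature-scaled softmax over the teacher's logits; the helper below is an illustrative sketch, not code from this repository.

```python
import numpy as np

def soften(logits, temperature=1.0):
    """Turn a teacher network's logits into soft targets.

    A higher temperature flattens the distribution, exposing the
    teacher's relative confidences over similar classes ("dark
    knowledge") that one-hot labels discard.
    """
    z = logits / temperature
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)
```

Training the tree on these distributions instead of hard labels is what accounts for the accuracy gap between the last two rows of the table.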
Below is a quick taste of what's inside. Detailed instructions on how to read these visualizations are in the mnist.ipynb notebook.