This script implements the following gradient descent optimization algorithms and visualizes their performance on the MNIST handwritten digit recognition dataset:
All the details of the algorithms are described in the blog post "An overview of gradient descent optimization algorithms" by Sebastian Ruder.
The MNIST dataset contains 60,000 samples for training and 10,000 for validation. It is naturally divided into 10 classes corresponding to the digits 0 to 9, and the number of samples per class is well balanced. Each digit image is preprocessed and rescaled into a 28×28 grayscale array with pixel values ranging from 0 to 255.
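A minimal sketch of how such images are commonly turned into network inputs, assuming each 28×28 array is flattened into a 784-vector, pixels are scaled to [0, 1] and labels are one-hot encoded. The scaling and encoding steps are assumptions, not something this README confirms, and `preprocess` is a hypothetical helper name. Random arrays with MNIST's shapes stand in for the real data.

```python
import numpy as np

def preprocess(images, labels, num_classes=10):
    """Hypothetical helper: flatten 28x28 images to 784-vectors in [0, 1]
    and one-hot encode integer labels (both steps are assumptions)."""
    x = images.reshape(images.shape[0], -1).astype(np.float64) / 255.0
    y = np.zeros((labels.shape[0], num_classes))
    y[np.arange(labels.shape[0]), labels] = 1.0  # set the true-class column to 1
    return x, y

# Random stand-in data with MNIST's shapes, not the real dataset.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(5, 28, 28)).astype(np.uint8)
labels = np.array([0, 3, 9, 3, 7])
x, y = preprocess(images, labels)
print(x.shape, y.shape)  # (5, 784) (5, 10)
```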
A traditional 3-layer neural network is used, with 28×28 = 784 input units, 25 hidden units, and a softmax output layer with one unit per class.
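The architecture above can be sketched as a NumPy forward pass. The README only fixes the layer sizes and the softmax output; the sigmoid hidden activation, the 0.01 weight-initialisation scale and the cross-entropy cost used here are assumptions.

```python
import numpy as np

def softmax(z):
    # Subtract the row max for numerical stability before exponentiating.
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def init_params(rng, n_in=784, n_hidden=25, n_out=10):
    # Small random weights; the 0.01 scale is an arbitrary assumption.
    return {
        "W1": rng.normal(0.0, 0.01, size=(n_in, n_hidden)),
        "b1": np.zeros(n_hidden),
        "W2": rng.normal(0.0, 0.01, size=(n_hidden, n_out)),
        "b2": np.zeros(n_out),
    }

def forward(params, x):
    h = sigmoid(x @ params["W1"] + params["b1"])     # hidden layer (sigmoid assumed)
    return softmax(h @ params["W2"] + params["b2"])  # class probabilities

def cross_entropy(probs, y_onehot):
    # Mean negative log-likelihood over the batch; clip to avoid log(0).
    return -np.mean(np.sum(y_onehot * np.log(np.clip(probs, 1e-12, 1.0)), axis=1))

rng = np.random.default_rng(0)
params = init_params(rng)
x = rng.random((4, 784))          # 4 fake flattened images
y = np.eye(10)[[0, 1, 2, 3]]      # their one-hot labels
probs = forward(params, x)
cost = cross_entropy(probs, y)
print(probs.shape)  # (4, 10)
```

With near-zero initial weights every class gets a probability close to 0.1, so the initial cost is close to ln 10 ≈ 2.30.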
Here are the training vs. validation accuracies of each algorithm after 30 epochs with a mini-batch size of 100:
95.36% vs 94.06%
97.52% vs 94.88%
97.47% vs 94.33%
96.17% vs 93.95%
94.65% vs 93.84%
96.35% vs 94.51%
96.54% vs 94.12%
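The evaluation protocol (epochs over shuffled mini-batches, one cost value recorded per mini-batch, and training vs. validation accuracy at the end) can be sketched as follows. To stay short and self-contained, this sketch swaps in plain SGD on a softmax-regression model and synthetic Gaussian clusters instead of the 3-layer network and MNIST; reading "100 mini-batch" as `batch_size = 100`, the learning rate and the data generator are all assumptions.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def iterate_minibatches(x, y, batch_size, rng):
    # Reshuffle every epoch, then yield consecutive slices of batch_size samples.
    idx = rng.permutation(len(x))
    for start in range(0, len(x), batch_size):
        batch = idx[start:start + batch_size]
        yield x[batch], y[batch]

def accuracy(probs, labels):
    # Fraction of samples whose most probable class equals the true label.
    return float(np.mean(np.argmax(probs, axis=1) == labels))

# Synthetic stand-in data: 3 Gaussian clusters in 20 dimensions.
rng = np.random.default_rng(0)
n_features, n_classes, n_samples = 20, 3, 1200
centers = rng.normal(0.0, 3.0, size=(n_classes, n_features))
labels = rng.integers(0, n_classes, size=n_samples)
x = centers[labels] + rng.normal(0.0, 1.0, size=(n_samples, n_features))
y = np.eye(n_classes)[labels]  # one-hot targets
x_train, y_train, l_train = x[:1000], y[:1000], labels[:1000]
x_val, l_val = x[1000:], labels[1000:]

W = np.zeros((n_features, n_classes))
b = np.zeros(n_classes)
lr, epochs, batch_size = 0.01, 30, 100
costs = []  # one cross-entropy value per mini-batch
for epoch in range(epochs):
    for xb, yb in iterate_minibatches(x_train, y_train, batch_size, rng):
        p = softmax(xb @ W + b)
        costs.append(-np.mean(np.sum(yb * np.log(p + 1e-12), axis=1)))
        grad = (p - yb) / len(xb)   # d(cost)/d(logits), averaged over the batch
        W -= lr * xb.T @ grad       # plain SGD step; other optimizers change this update
        b -= lr * grad.sum(axis=0)

train_acc = accuracy(softmax(x_train @ W + b), l_train)
val_acc = accuracy(softmax(x_val @ W + b), l_val)
print(f"{train_acc:.2%} vs {val_acc:.2%}")  # training vs. validation, as in the list above
```

Plotting `costs` against the mini-batch index gives the kind of per-mini-batch cost curve the README describes.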
Here is a visualization of how the cost decreases with each mini-batch during the first 10 epochs:
The variants of the gradient descent algorithm can be roughly divided into 2 types: momentum-like SGD and adaptive learning rate SGD.
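As a rough sketch of the two families, the update rules from Ruder's post for plain SGD, momentum, Nesterov accelerated gradient (NAG), Adagrad, RMSprop and Adam can be written as small NumPy step functions and compared on a toy quadratic f(θ) = ½‖θ‖². The learning rates and decay constants are assumptions picked loosely for this toy problem, not the script's settings, and Adadelta is left out for brevity.

```python
import numpy as np

# Each step function takes (theta, grad_fn, state) and returns the new theta;
# `state` is a per-optimizer dict holding accumulators between calls.
# Hyperparameter defaults are assumptions chosen for the toy problem below.

def sgd(theta, grad_fn, state, lr=0.1):
    return theta - lr * grad_fn(theta)

# --- Momentum-like SGD ---
def momentum(theta, grad_fn, state, lr=0.1, gamma=0.9):
    # Velocity is a decaying sum of past gradients.
    v = gamma * state.get("v", 0.0) + lr * grad_fn(theta)
    state["v"] = v
    return theta - v

def nag(theta, grad_fn, state, lr=0.1, gamma=0.9):
    # Nesterov: take the gradient at the look-ahead point theta - gamma * v.
    v_prev = state.get("v", 0.0)
    v = gamma * v_prev + lr * grad_fn(theta - gamma * v_prev)
    state["v"] = v
    return theta - v

# --- Adaptive learning rate SGD ---
def adagrad(theta, grad_fn, state, lr=0.5, eps=1e-8):
    # Per-parameter step shrinks with the running sum of squared gradients.
    g = grad_fn(theta)
    state["G"] = state.get("G", 0.0) + g * g
    return theta - lr * g / np.sqrt(state["G"] + eps)

def rmsprop(theta, grad_fn, state, lr=0.05, rho=0.9, eps=1e-8):
    # Like Adagrad, but with an exponentially decaying average of squared gradients.
    g = grad_fn(theta)
    state["Eg2"] = rho * state.get("Eg2", 0.0) + (1 - rho) * g * g
    return theta - lr * g / np.sqrt(state["Eg2"] + eps)

def adam(theta, grad_fn, state, lr=0.05, b1=0.9, b2=0.999, eps=1e-8):
    # Decaying averages of the gradient (m) and squared gradient (v), bias-corrected.
    g = grad_fn(theta)
    t = state.get("t", 0) + 1
    m = b1 * state.get("m", 0.0) + (1 - b1) * g
    v = b2 * state.get("v", 0.0) + (1 - b2) * g * g
    state.update(t=t, m=m, v=v)
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    return theta - lr * m_hat / (np.sqrt(v_hat) + eps)

def minimize(step, theta0, steps=200):
    # Toy objective f(theta) = 0.5 * ||theta||^2; its gradient is theta itself.
    grad_fn = lambda th: th
    theta, state = np.array(theta0, dtype=float), {}
    for _ in range(steps):
        theta = step(theta, grad_fn, state)
    return 0.5 * float(theta @ theta)

OPTIMIZERS = {"SGD": sgd, "Momentum": momentum, "NAG": nag,
              "Adagrad": adagrad, "RMSprop": rmsprop, "Adam": adam}
final_loss = {name: minimize(step, [3.0, -2.0]) for name, step in OPTIMIZERS.items()}
for name, loss in final_loss.items():
    print(f"{name:8s} final loss = {loss:.2e}")
```

Keeping every optimizer behind the same `(theta, grad_fn, state)` signature is one convenient way to swap update rules inside an otherwise unchanged training loop.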
As mentioned in the blog post, adaptive learning rate SGD is well suited to large-scale sparse optimization problems (e.g., CTR prediction). In this case, however, the data is not sparse, and momentum-like SGD significantly outperforms the others.
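To illustrate why adaptive methods suit sparse problems, the sketch below tracks Adagrad's accumulator for two hypothetical features: one whose gradient is nonzero at every step, and one that is nonzero only every 10th step, as a rare feature in CTR data might be. The rare feature keeps a larger effective learning rate lr / sqrt(G), whereas plain SGD would apply the same lr to both. The gradient stream and lr = 0.1 are made-up values.

```python
import numpy as np

lr, eps, steps = 0.1, 1e-8, 100
G = np.zeros(2)  # Adagrad's per-feature sum of squared gradients
for t in range(steps):
    # Hypothetical gradient stream: feature 0 is dense, feature 1 fires on 1 step in 10.
    g = np.array([1.0, 1.0 if t % 10 == 0 else 0.0])
    G += g * g

effective_lr = lr / np.sqrt(G + eps)      # step size Adagrad would apply to each feature next
print(effective_lr)                       # dense feature: 0.01, rare feature: about 0.0316
print(effective_lr[1] / effective_lr[0])  # about 3.16, i.e. sqrt(10)
```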