CurveBall

Second-order optimiser for deep networks.

This is the accompanying code repository for the paper:

João F. Henriques, Sebastien Ehrhardt, Samuel Albanie and Andrea Vedaldi, "Small steps and giant leaps: Minimal Newton solvers for Deep Learning", arXiv preprint, 2018.

Warning

This code is undergoing refactoring, which may introduce subtle bugs. Also be aware that our implementation of forward-mode automatic differentiation (FMAD) is currently less efficient than the standard forward/back-propagation operations (e.g. those provided by CuDNN). We expect to improve this over time.

Installation

Requirements: MatConvNet and AutoNN (the training code builds on both; see the Training section below).

For speed, the forward-mode automatic differentiation (FMAD) is not implemented in pure MATLAB, but uses a couple of custom CUDA kernels (for the batch-norm and max-pooling switches), which must be compiled. First, run compile.sh from a shell, passing your MATLAB path as an argument. Then compile the rest of the methods by calling compile_fmad from within MATLAB, as sketched below.
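A minimal sketch of these two steps; the MATLAB install path is a placeholder, so substitute your own:

```matlab
% Step 1 (from a shell): compile the custom CUDA kernels, passing your
% MATLAB root as an argument. The path below is only a placeholder:
%
%   ./compile.sh /usr/local/MATLAB/R2018b
%
% Step 2 (inside MATLAB): compile the remaining FMAD methods.
compile_fmad
```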

Training

The main function is training. It supports the models (VGG/AlexNet/ResNet/etc.), datasets (MNIST/CIFAR/ImageNet) and solvers (SGD/Adam/etc.) defined in AutoNN, as well as our new solver, CurveBall. Note that not all models may work, due to ops that are undefined in the FMAD routine.

The first argument is an experiment name (a subdirectory to store results in), followed by name-value pairs. By default, the results are stored in <matconvnet>/data/curveball, and the datasets are downloaded to <matconvnet>/data/<datasetname>. These locations can be overridden, but it may be more practical to symlink the data directory to a desired folder. One important parameter is 'gpu', which defines the index of the GPU to use for training; by default, the first GPU is used.

The full parameter list is documented at the top of the training.m file. An illustrative sketch follows.
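As a rough sketch (the 'solver' parameter name here is an assumption, not verified against training.m; check the parameter list at the top of that file), a CIFAR run with the CurveBall solver on the first GPU might look like:

```matlab
% Hypothetical invocation: 'solver' is an assumed parameter name. The
% experiment name ('cifar-curveball') and the 'gpu' option follow the
% conventions described above; results would be stored under
% <matconvnet>/data/curveball/cifar-curveball.
training('cifar-curveball', 'solver', 'curveball', 'gpu', 1)
```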

Results for a given dataset can be plotted together and compared using plot_results.
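For example (a hedged sketch: the exact signature of plot_results is not shown in this README, so treat the argument as an assumption and check plot_results.m):

```matlab
% Hypothetical call: assumes plot_results accepts the name of a dataset
% (or experiment directory) whose stored results should be compared.
plot_results('cifar')
```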