KGrewal1 / optimisers

A collection of optimisers for use with candle
MIT License
Candle Optimisers

A crate of optimisers for use with candle, the minimalist ML framework.

Optimisers implemented are:

Adaptive methods:

These are all checked against their PyTorch implementations (see pytorch_test.ipynb) and should implement the same functionality, though without some of the input checking.

Additionally, all of the adaptive methods listed, as well as SGD, implement decoupled weight decay as described in Decoupled Weight Decay Regularization, in addition to the standard weight decay as implemented in PyTorch.
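To illustrate the difference between the two kinds of weight decay, here is a minimal, dependency-free sketch of a single Adam-style step on one scalar weight. It is not this crate's API: the `Decay` enum and `adam_first_step` function are hypothetical names, and the hyperparameters are the commonly used Adam defaults, assumed here for illustration.

```rust
/// Which flavour of weight decay to apply (hypothetical type, for illustration only).
#[derive(Clone, Copy)]
enum Decay {
    /// Standard (L2) decay: the decay term is added to the gradient.
    Coupled(f64),
    /// Decoupled decay: the weight is shrunk directly, outside the adaptive scaling.
    Decoupled(f64),
}

/// First Adam step on a single scalar weight, starting from zero moment estimates.
fn adam_first_step(w: f64, grad: f64, lr: f64, decay: Decay) -> f64 {
    // Commonly used Adam defaults, assumed for this sketch.
    let (beta1, beta2, eps) = (0.9, 0.999, 1e-8);
    let (mut w, mut g) = (w, grad);
    match decay {
        // The decay passes through the moment estimates and is rescaled with the gradient.
        Decay::Coupled(wd) => g += wd * w,
        // The decay acts on the weight itself and is not rescaled.
        Decay::Decoupled(wd) => w *= 1.0 - lr * wd,
    }
    let m = (1.0 - beta1) * g;
    let v = (1.0 - beta2) * g * g;
    // Bias correction for the first step.
    let m_hat = m / (1.0 - beta1);
    let v_hat = v / (1.0 - beta2);
    w - lr * m_hat / (v_hat.sqrt() + eps)
}

fn main() {
    let (w, grad, lr, wd) = (1.0, 0.5, 0.1, 0.01);
    let coupled = adam_first_step(w, grad, lr, Decay::Coupled(wd));
    let decoupled = adam_first_step(w, grad, lr, Decay::Decoupled(wd));
    println!("coupled:   {}", coupled);
    println!("decoupled: {}", decoupled);
}
```

On the first step Adam normalises the gradient to roughly unit size, so the coupled decay term is largely absorbed by the normalisation, whereas the decoupled variant shrinks the weight by an extra factor of `1 - lr * wd`.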

Pseudo-second-order methods:

This is not implemented equivalently to the PyTorch version, but is checked on the 2D Rosenbrock function.
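For reference, the 2D Rosenbrock function is a standard optimisation test problem with a narrow curved valley and a global minimum at (1, 1). The sketch below assumes the conventional constants a = 1 and b = 100; the crate's own test may use a different form, and the function names here are illustrative.

```rust
// Conventional Rosenbrock constants, assumed for this sketch.
const A: f64 = 1.0;
const B: f64 = 100.0;

/// 2D Rosenbrock function f(x, y) = (a - x)^2 + b (y - x^2)^2.
fn rosenbrock(x: f64, y: f64) -> f64 {
    (A - x).powi(2) + B * (y - x * x).powi(2)
}

/// Analytic gradient (df/dx, df/dy) of the function above.
fn rosenbrock_grad(x: f64, y: f64) -> (f64, f64) {
    let dx = -2.0 * (A - x) - 4.0 * B * x * (y - x * x);
    let dy = 2.0 * B * (y - x * x);
    (dx, dy)
}

fn main() {
    // The function value and the gradient both vanish at the minimum (1, 1).
    println!("f(1, 1) = {}", rosenbrock(1.0, 1.0));
    println!("grad f(1, 1) = {:?}", rosenbrock_grad(1.0, 1.0));
    // (-1.2, 1) is a commonly used starting point for this problem.
    println!("f(-1.2, 1) = {}", rosenbrock(-1.2, 1.0));
}
```

An optimiser can be sanity-checked on this function by confirming that its iterates approach (1, 1) and that the function value decreases towards zero.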

Examples

There is an MNIST toy program along with a simple example of Adagrad. Whilst the parameters of each method aren't tuned (all are left at their defaults, with a user-supplied learning rate), the following converges quite nicely:

cargo r -r --example mnist mlp --optim r-adam --epochs 2000 --learning-rate 0.025

For even faster training try:

cargo r -r --features cuda --example mnist mlp --optim r-adam --epochs 2000 --learning-rate 0.025

to use the CUDA backend.

Usage

cargo add --git https://github.com/KGrewal1/optimisers.git candle-optimisers

To do

Currently unimplemented from PyTorch:

Notes

For development, to track the state of PyTorch methods, use:

print(optimiser.state)