deeplearning4j / deeplearning4j

Suite of tools for deploying and training deep learning models using the JVM. Highlights include model import for Keras, TensorFlow, and ONNX/PyTorch, a modular and tiny C++ library for running math code, and a Java-based math library on top of the core C++ library. Also includes SameDiff: a PyTorch/TensorFlow-like library for running deep learn...
http://deeplearning4j.konduit.ai
Apache License 2.0
13.6k stars, 3.83k forks

Stochastic weight averaging #5045

Status: Open. tom-adsfund opened this issue 6 years ago

tom-adsfund commented 6 years ago

SWA (https://arxiv.org/pdf/1803.05407.pdf, also mentioned on the DL4J Twitter account) takes a running average of the network's weights during training. SGD tends to end up near the boundary of a wide, flat region (a "ball") of low loss; averaging lets the weights move from that boundary toward the inside of the ball, which gives better test accuracy (see the paper for more).

In theory the implementation is straightforward, because the weights are simply averaged iteratively (a running mean). But to implement it in DL4J, I think we'd need to assign each layer a unique identifier, so that the network could be cloned and the corresponding weights averaged.
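The iterative averaging only needs the current average and a count, not the full history of weights. A minimal sketch in Java, assuming the weights are available as one flat double[] (a stand-in for the network's parameters; RunningWeightAverage is a hypothetical name, not a DL4J class):

```java
import java.util.Arrays;

// Minimal sketch of SWA's iterative (running) average. A plain double[] stands in
// for a network's flattened parameter vector; this is not a DL4J class.
class RunningWeightAverage {
    private double[] wSwa;    // the averaged weights, w_swa
    private int nModels = 0;  // number of weight snapshots averaged so far

    // Folds one snapshot w into the average: w_swa = (w_swa * n + w) / (n + 1).
    void update(double[] w) {
        if (wSwa == null) {
            wSwa = w.clone();  // first snapshot: the average is the snapshot itself
        } else {
            if (w.length != wSwa.length) {
                throw new IllegalArgumentException("parameter count changed");
            }
            for (int i = 0; i < wSwa.length; i++) {
                wSwa[i] = (wSwa[i] * nModels + w[i]) / (nModels + 1);
            }
        }
        nModels++;
    }

    // Returns a copy of the current average (null before the first update).
    double[] average() {
        return wSwa == null ? null : wSwa.clone();
    }

    int count() {
        return nModels;
    }

    public static void main(String[] args) {
        RunningWeightAverage avg = new RunningWeightAverage();
        avg.update(new double[] {1.0, 2.0});
        avg.update(new double[] {3.0, 4.0});
        avg.update(new double[] {5.0, 9.0});
        System.out.println(Arrays.toString(avg.average())); // [3.0, 5.0]
    }
}
```

Each update costs O(N) time for N weights and holds one extra copy of them, and the result equals the plain arithmetic mean of all snapshots seen so far.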

I'm opening this issue to see what people think about that problem, and also to track any progress on implementing this.

tom-adsfund commented 6 years ago

Looking at the source and docs, both ComputationGraph and MultiLayerNetwork index their layers in an array, so this should actually be very easy to implement. An SWA class could take the layer arrays from the network and from a clone of that network in its constructor, and expose an average() method.
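A rough sketch of this twin-network idea, assuming each layer's parameters can be reached as a double[] indexed in the same order as the network's layer array and updated in place by training (TwinNetworkSwa and the double[][] layout are hypothetical stand-ins, not DL4J types):

```java
import java.util.Arrays;

// Sketch of the "twin network" idea: the live per-layer parameter arrays (indexed
// like the layer array of a network) plus a deep copy that holds the running average.
// Hypothetical class; double[][] stands in for per-layer parameters.
class TwinNetworkSwa {
    private final double[][] live;      // live layers' params, assumed updated in place
    private final double[][] averaged;  // deep copy holding the averaged weights
    private int nModels = 1;            // the copy starts as the first snapshot

    TwinNetworkSwa(double[][] liveLayers) {
        this.live = liveLayers;
        this.averaged = new double[liveLayers.length][];
        for (int l = 0; l < liveLayers.length; l++) {
            averaged[l] = liveLayers[l].clone();  // "clone" of each layer's weights
        }
    }

    // Folds the live network's current weights into the average, layer by layer.
    void average() {
        for (int l = 0; l < live.length; l++) {
            for (int i = 0; i < live[l].length; i++) {
                averaged[l][i] = (averaged[l][i] * nModels + live[l][i]) / (nModels + 1);
            }
        }
        nModels++;
    }

    double[][] averagedLayers() {
        return averaged;
    }

    public static void main(String[] args) {
        double[][] layers = { {0.0, 2.0}, {4.0} };  // two "layers"
        TwinNetworkSwa swa = new TwinNetworkSwa(layers);
        layers[0][0] = 2.0; layers[0][1] = 4.0; layers[1][0] = 8.0;  // a training step
        swa.average();
        System.out.println(Arrays.deepToString(swa.averagedLayers())); // [[1.0, 3.0], [6.0]]
    }
}
```

Because the averaged copy keeps the same layer indexing as the live network, the two can be evaluated side by side; the cost is one full extra copy of the weights.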

AlexDBlack commented 6 years ago

OK, I finally got to reading the paper properly.

The cyclical LR schedules would be implemented separately, using our existing ISchedule functionality: https://github.com/deeplearning4j/deeplearning4j/issues/5057
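For reference, the paper's cyclical schedule is cheap to compute per iteration. A hedged sketch (CyclicalSwaSchedule is a hypothetical name; in DL4J this logic would presumably live in an ISchedule implementation per #5057, and this sketch does not assume that interface's exact methods):

```java
// Sketch of the cyclical learning-rate schedule from the SWA paper,
// with 1-based iteration i and cycle length c:
//   t(i)     = (mod(i - 1, c) + 1) / c
//   alpha(i) = (1 - t(i)) * alpha1 + t(i) * alpha2
// Within each cycle the rate decays linearly from just below alpha1 to alpha2,
// then jumps back. Standalone class; it does not implement ND4J's ISchedule.
class CyclicalSwaSchedule {
    private final double alpha1;    // rate at the start of each cycle
    private final double alpha2;    // rate at the end of each cycle
    private final int cycleLength;  // c, in iterations

    CyclicalSwaSchedule(double alpha1, double alpha2, int cycleLength) {
        if (cycleLength < 1) {
            throw new IllegalArgumentException("cycleLength must be >= 1");
        }
        this.alpha1 = alpha1;
        this.alpha2 = alpha2;
        this.cycleLength = cycleLength;
    }

    // Learning rate for 1-based iteration i.
    double valueAt(int i) {
        double t = (Math.floorMod(i - 1, cycleLength) + 1) / (double) cycleLength;
        return (1 - t) * alpha1 + t * alpha2;
    }

    public static void main(String[] args) {
        CyclicalSwaSchedule schedule = new CyclicalSwaSchedule(1.0, 0.0, 4);
        for (int i = 1; i <= 8; i++) {
            System.out.println(i + ": " + schedule.valueAt(i));
        }
    }
}
```

With alpha1 = 1.0, alpha2 = 0.0 and c = 4, iterations 1 to 4 get rates 0.75, 0.5, 0.25 and 0.0, after which the cycle restarts; SWA takes its weight snapshots at the end of each cycle, where the rate is lowest.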

As for the averaging: I would implement this as a TrainingListener. The listener would keep a w_{swa} parameter vector internally, and update it periodically using model.params() and the update equation from Algorithm 1 of the paper. So our space overhead is just O(N) for N parameters.

Then we simply need to tell the listener to set the final network parameters to w_{swa}, and we're done. That could be a single method call (StochasticWeightAveragingListener.setFinalParams(Model), perhaps?); the only thing I don't like about that is that users have to remember to call it manually (but I don't see any other option).
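The listener design described above can be sketched without DL4J dependencies. This is a hedged illustration, not the proposed StochasticWeightAveragingListener: SwaListenerSketch is hypothetical, a double[] stands in for model.params(), and iterationDone(...) only mimics the shape of a training-listener callback. It also starts the average at the first end-of-cycle snapshot, whereas the paper's Algorithm 1 initializes w_{swa} from the starting weights.

```java
import java.util.Arrays;

// Sketch of the listener approach: one extra w_swa vector (O(N) memory), updated
// at the end of every cycle of c iterations (when mod(i, c) == 0), and copied back
// into the model only when setFinalParams is called explicitly.
// Hypothetical class: double[] stands in for model.params().
class SwaListenerSketch {
    private final int cycleLength;  // c
    private double[] wSwa;          // averaged weights
    private int nModels = 0;        // snapshots averaged so far

    SwaListenerSketch(int cycleLength) {
        if (cycleLength < 1) {
            throw new IllegalArgumentException("cycleLength must be >= 1");
        }
        this.cycleLength = cycleLength;
    }

    // Called after each 1-based iteration with the model's flat parameter vector.
    void iterationDone(double[] params, int iteration) {
        if (iteration % cycleLength != 0) {
            return;  // only average at the end of a cycle
        }
        if (wSwa == null) {
            wSwa = params.clone();
        } else {
            for (int j = 0; j < wSwa.length; j++) {
                wSwa[j] = (wSwa[j] * nModels + params[j]) / (nModels + 1);
            }
        }
        nModels++;
    }

    // Overwrites params in place with w_swa; the user must remember to call this.
    void setFinalParams(double[] params) {
        if (wSwa == null) {
            throw new IllegalStateException("no snapshots averaged yet");
        }
        System.arraycopy(wSwa, 0, params, 0, wSwa.length);
    }

    public static void main(String[] args) {
        SwaListenerSketch listener = new SwaListenerSketch(2);
        double[] params = new double[1];
        for (int i = 1; i <= 6; i++) {
            params[0] = i;  // stand-in for an optimizer step
            listener.iterationDone(params, i);
        }
        listener.setFinalParams(params);
        System.out.println(Arrays.toString(params)); // [4.0]
    }
}
```

In the usage example the snapshots are taken at iterations 2, 4 and 6, so the final parameter is their mean, 4.0. Forgetting the final setFinalParams call silently leaves the last SGD iterate in place, which is the drawback noted above.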

tom-adsfund commented 6 years ago

@AlexDBlack So does model.params() give a 'view' of all the params from the Layers?

AlexDBlack commented 6 years ago

MultiLayerNetwork/ComputationGraph.params() is the primary array for parameters (a row vector of shape [1, numParams]); all other parameters (i.e., those for each layer) are views of subsets of this array, reshaped to [nIn, nOut] or whatever shape is required.

So, for performance reasons, we want to work with this single params() array only, instead of handling the parameters on a per-layer basis.

Edit: put another way, if you do params().addi(x), the parameters of all layers will be modified.
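The layout described above (one flat parameter array, with each layer's parameters being views into it) can be illustrated without ND4J. A minimal sketch using java.nio.DoubleBuffer slices as a stand-in for ND4J views; FlatParamsViews, layerView and addiAll are hypothetical names, and the analogy is only about shared storage, not about how ND4J views are implemented:

```java
import java.nio.DoubleBuffer;

// Illustration of "one flat params array, per-layer views" using DoubleBuffer slices
// as a stand-in for ND4J views. Layer sizes are made up, and these 1-D views omit the
// [nIn, nOut] shapes that ND4J views carry.
class FlatParamsViews {
    // Returns a view (not a copy) of flat[offset, offset + length).
    static DoubleBuffer layerView(double[] flat, int offset, int length) {
        return DoubleBuffer.wrap(flat, offset, length).slice();
    }

    // In-place add over the whole flat buffer, analogous to params().addi(x).
    static void addiAll(DoubleBuffer all, double x) {
        for (int i = 0; i < all.capacity(); i++) {
            all.put(i, all.get(i) + x);
        }
    }

    public static void main(String[] args) {
        double[] flat = new double[5];                // all parameters of the network
        DoubleBuffer layer0 = layerView(flat, 0, 3);  // layer 0 owns elements 0..2
        DoubleBuffer layer1 = layerView(flat, 3, 2);  // layer 1 owns elements 3..4

        addiAll(DoubleBuffer.wrap(flat), 1.0);        // one update on the flat array...
        System.out.println(layer0.get(0) + " " + layer1.get(1)); // 1.0 1.0

        layer1.put(0, 7.0);                           // writes through a view reach the flat array
        System.out.println(flat[3]);                  // 7.0
    }
}
```

This shared storage is why an SWA listener can treat the whole network as a single vector: reading or overwriting the flat array reads or overwrites every layer at once.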

tom-adsfund commented 6 years ago

OK, I'll just go with what you say, since I have no idea of the internals!

So yes, that looks easy. This will be such a powerful addition.

The reason I immediately thought of twin networks is that you could run eval on both, because what you really care about is the test loss of the SWA network. But I guess you could still clone the network and set its parameters from the SWAListener.