FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Suggestion: Bounds for stochastic gradient descent loss fluctuations #1000

Closed cems2 closed 4 years ago

cems2 commented 4 years ago

Problem: In stochastic gradient descent the descent will sometimes suddenly go unstable, even with the step size set as low as it can go (limited by Float32 precision). This is not preventable with the current Descent() function: with mini-batches the loss is not guaranteed to decrease monotonically, so Descent cannot be strictly descending and must accept moves that raise the loss value relative to previous iterations.

Note: This problem is often not noticeable, but it frequently manifests when the loss function is exponentially sensitive to the input-data values that vary between minibatches, as is the case for many NeuralODE layers and some exponential linear units.

Suggested solution: StrictDescent(loss_bounds=(lower, upper))::Bool. This would do the following: if, after a call, the loss lies above upper, the parameters are left unchanged and false is returned; otherwise, if the loss is within bounds, it returns true.
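A minimal sketch of what that could look like, built on Flux's existing Descent / params / update! API; bounded_descent_step! and loss_bounds are illustrative names for the proposed StrictDescent behaviour, not existing Flux functions:

```julia
using Flux

# Hypothetical sketch of the proposed behaviour. It takes one ordinary descent
# step, then rolls it back if the resulting loss leaves the allowed band.
function bounded_descent_step!(loss, ps, opt; loss_bounds=(-Inf32, Inf32))
    lower, upper = loss_bounds
    backup = [copy(p) for p in ps]        # snapshot parameters before the step
    gs = gradient(() -> loss(), ps)       # gradients w.r.t. the implicit params
    Flux.Optimise.update!(opt, ps, gs)    # the usual in-place descent update
    l = loss()                            # re-evaluate the loss after the step
    lower <= l <= upper && return true    # in bounds: keep the step
    for (p, b) in zip(ps, backup)         # out of bounds: restore the snapshot
        p .= b
    end
    return false
end
```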

In use, after some period of successful gradient descent the user will have seen some good loss values, as well as observed the loss fluctuations from minibatch to minibatch. The user can cache a copy of those parameters, then bound the gradient descent so it does not wander away from that good loss value by more than the previously observed minibatch fluctuation.

The boolean return alerts the user that this is happening. It lets the user decide whether to ignore this (soft) exception and simply move on to the next minibatch and hope for better results, or to trap the exception and reset the parameters to some earlier state where the best loss on a larger training set (or a test or holdout set) was observed. If they ignore it and proceed, the last minibatch's parameter changes are discarded as though the step never happened.
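A sketch of that workflow, using the bounded_descent_step! from the previous snippet; train_bounded! is another invented name, and model, data, and fluctuation (the observed minibatch-to-minibatch spread) are placeholders:

```julia
function train_bounded!(model, data, opt; fluctuation=0.1f0)
    ps        = Flux.params(model)
    best_ps   = [copy(p) for p in ps]   # last known-good parameters
    best_loss = Inf32
    for (x, y) in data
        batch_loss() = Flux.mse(model(x), y)
        ok = bounded_descent_step!(batch_loss, ps, opt;
                                   loss_bounds=(0f0, best_loss + fluctuation))
        if !ok
            # soft exception: here we reset to the last known-good parameters;
            # ignoring it and moving on to the next minibatch is the other option
            for (p, b) in zip(ps, best_ps)
                p .= b
            end
        elseif (l = batch_loss()) < best_loss
            best_loss = l
            best_ps   = [copy(p) for p in ps]
        end
    end
end

# e.g. train_bounded!(model, data, Descent(1f-4))
```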

I have also included a lower bound for completeness. This may be slightly less useful, but with minibatches it is also possible for the descent to be too aggressive, especially when the batch is small, so the lower bound limits how far the descent can move in one batch.

Note: these bounds are different from the step-size parameter that would also be used here. The step size limits how fast the parameters change in an iteration; here we are regulating how much the loss can change from one iteration to the next.

With exponentially sensitive parameters the step size may need to become so small that the update falls below the machine precision of the parameters! Limiting the loss instead should alleviate that.
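As a concrete illustration of the precision limit (a plain REPL check, nothing Flux-specific):

```julia
# Float32 carries about 7 decimal digits, so a sufficiently small update to a
# parameter is rounded away and the descent step has no effect at all:
julia> eps(Float32)              # spacing between adjacent Float32 values near 1.0
1.1920929f-7

julia> 1.0f0 + 1.0f-8 == 1.0f0   # an update of 1e-8 to a parameter of size 1 is lost
true
```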

This should make the process more stable.

CarloLucibello commented 4 years ago

This type of heuristic could well be implemented on the user side or as an extra package on top of Flux; here we should only implement well-consolidated techniques.