cc: @ablaom
@mohamed82008 Great catch! I agree the implementation is incorrect. I don't suppose you would consider making a PR to resolve it?
I can make a PR, just wanted to get pre-approval first.
Awesome. I promise an expedited review.
PR opened.
In the following line, I believe the penalty should be multiplied by the relative batch size (batch size / dataset size) such that the expectation of the stochastic gradient evaluated at the same parameter values is proportional to the true gradient.
https://github.com/FluxML/MLJFlux.jl/blob/4aae8c25df008aa5980f3031407c361628ddd6b0/src/core.jl#L39
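For concreteness, here is a minimal sketch of the proposed change. All names are illustrative and the loss/penalty are toy choices; this is not the MLJFlux implementation:

```julia
# Illustrative sketch only -- names and functions are hypothetical, not
# MLJFlux internals. `x` is an n×d mini-batch matrix, `y` the length-n
# targets, `w` a length-d parameter vector; the full dataset has N rows.
l2_penalty(w) = sum(abs2, w)                  # example penalty
batch_loss(w, x, y) = sum(abs2, x * w .- y)   # example mini-batch loss

# Unscaled penalty (what the linked line effectively computes):
objective_current(w, x, y) = batch_loss(w, x, y) + l2_penalty(w)

# Proposed: scale the penalty by the relative batch size n/N
objective_proposed(w, x, y, n, N) =
    batch_loss(w, x, y) + (n / N) * l2_penalty(w)
```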
You all probably know this already, but here is the mathematical justification anyway. Let the full-batch loss be L(w) for some parameter values w, and let the mini-batch loss be the conditional random variable l | w. Additionally, let the batch size be n and the full dataset size be N. With mini-batches drawn uniformly from the dataset, the expected value of l | w is:

E(l | w) = n/N * L(w)

Stochastic gradient descent relies on the fact that the expected value of the stochastic objective is proportional to the true full-batch objective. Let the full-batch objective be:

O(w) = L(w) + penalty(w)
Now let's consider the following two mini-batch objectives o | w and their expectations:

o | w = (l | w) + penalty(w)
->  E(o | w) = E(l | w) + penalty(w) = n/N * L(w) + penalty(w)

o | w = (l | w) + n/N * penalty(w)
->  E(o | w) = E(l | w) + n/N * penalty(w) = n/N * L(w) + n/N * penalty(w) = n/N * O(w)

The second mini-batch objective is the one whose expectation is proportional to the full-batch objective O(w). The first one over-penalises the weights by one over the relative batch size, i.e. by a factor of N/n.
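If it helps, here is a small self-contained Monte Carlo sketch that checks the two expectations above numerically. The data and names are toy choices, nothing here comes from the package:

```julia
using Statistics, Random

Random.seed!(0)
N, n = 100, 10                               # dataset size and batch size
w = 0.5
x = randn(N); y = 2 .* x .+ 0.1 .* randn(N)
perloss(i) = (w * x[i] - y[i])^2             # per-example squared error
penalty = w^2                                # example L2 penalty

L_full = sum(perloss, 1:N)                   # full-batch loss L(w)
O_full = L_full + penalty                    # full-batch objective O(w)

# Average the two candidate objectives over many uniformly sampled batches.
batches = [rand(1:N, n) for _ in 1:100_000]
E_o1 = mean(sum(perloss, b) + penalty for b in batches)            # unscaled
E_o2 = mean(sum(perloss, b) + (n / N) * penalty for b in batches)  # scaled

@show E_o1                 # ≈ n/N * L(w) + penalty(w): overshoots
@show E_o2                 # ≈ n/N * O(w)
@show (n / N) * O_full
```

Up to Monte Carlo error, E_o2 matches n/N * O(w), while E_o1 exceeds it by (1 - n/N) * penalty(w).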