I'd like to implement this, but I'm not sure I understood the proposed implementation.
First of all, is the interface supposed to be

```julia
model = ...
opt = OptimiserChain(AccumGrad(K), Adam())
optstate = Optimisers.setup(opt, model)

for (k, batch) in enumerate(dataloader)
    grad = ...
    Optimisers.update!(optstate, model, grad)
end
```
with the following behavior for `update!`:

- if `k % K != 0`, the `grad` is accumulated in the state of `AccumGrad` and nothing else happens (the Adam rule is not called and the model is not updated);
- if `k % K == 0`, the `grad` is accumulated in the state of `AccumGrad`, then the accumulated gradient is fed to the Adam rule, and the output is used to update the model?
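Concretely, the per-step behavior I have in mind would look something like this standalone sketch (plain Julia, with `adam_step` as a stand-in for the real Adam rule):

```julia
# Standalone sketch of the behavior described above (not Optimisers.jl code).
K = 4
x = randn(3)                  # one parameter array
acc = zero(x)                 # AccumGrad's internal state
adam_step(g) = 0.001 .* g     # stand-in for the real Adam rule

for k in 1:2K
    grad = randn(3)           # pretend gradient for this batch
    acc .+= grad              # every call: grad accumulates into the state
    if k % K == 0             # on every K-th call only:
        x .-= adam_step(acc)  # the accumulated gradient goes through "Adam"
        acc .= 0              # and the accumulator is reset
    end
end
```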
I don't quite see how to achieve this behavior with the proposed negative gradient trick. I think that one can instead have a very simple implementation by wrapping `AccumGrad` around an optimizer:

```julia
opt = AccumGrad(Adam(), K)
```

and then overloading `Optimisers.update!`.
I agree that the cleaner implementation here is to create a wrapper ("higher-order") optimizer that internally accumulates the gradients for K calls and then applies the internal rule once. But the function being overloaded should be `apply!`, not `update!`. You can see here for a reference implementation of a similar wrapper.

Don't forget the extra "divide by k" before the update.
> But the function being overloaded should be `apply!`, not `update!`.

But `apply!` should return a `dx`. Creating a zero array is wasteful. Is returning `nothing` supported?
> Don't forget the extra "divide by k" before update.

Is it ok for it to be plain division, or should we reweight by the batch size (of which Optimisers knows nothing)?
Plain division should be fine. I don't think we are able to get the batch size in Optimisers.
> But `apply!` should return a `dx`. Creating a zero array is wasteful. Is returning `nothing` supported?

Not right now, but we can add `subtract(x, ::Nothing)` easily.
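Presumably that would be as simple as a no-op method (a hypothetical sketch, following the `subtract` spelling above):

```julia
# Hypothetical no-op method: a `nothing` gradient leaves x unchanged, so
# apply! can signal "skip this step" without allocating a zero array.
subtract(x, ::Nothing) = x
subtract(x, dx) = x .- dx   # the usual case, for contrast
```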
### Motivation and description

Gradient accumulation is an important technique when the model does not fit in memory with a large batch size.

### Possible Implementation

Suggested by @ToucheSir,