FluxML / Optimisers.jl

Optimisers.jl defines many standard optimisers and utilities for learning loops.
https://fluxml.ai/Optimisers.jl
MIT License

Interface for gradient accumulation #130

Closed: chengchingwen closed this issue 1 year ago

chengchingwen commented 1 year ago

Motivation and description

Gradient accumulation is an important technique when a model does not fit in memory at a large batch size: gradients from several smaller batches are accumulated and then applied as a single update.

Possible Implementation

Suggested by @ToucheSir:

Define an AbstractRule for accumulation that simply negates the gradient. When update! is called, the subtraction at https://github.com/FluxML/Optimisers.jl/blob/master/src/interface.jl#L98 cancels out the negation, so the "parameters" being updated (here, the accumulated gradients) are incremented by the new gradients. To make this work, run setup as usual on the model using this accumulation rule, then use that tree to call update!(accum_tree, accum_grads, new_grads) instead. Once you've finished accumulating, just run update!(opt_state_tree, model, accum_grads) as you would without accumulation.
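The trick relies on update! subtracting whatever the rule returns. As a minimal sketch, in plain Julia and without calling the library, the hypothetical stand-ins `NegateGrad`, `apply!` and `update_leaf!` below only mimic the Optimisers.jl convention of a rule returning `(state, dx)`; they show that subtracting a negated gradient adds it to an accumulation buffer:

```julia
# Hypothetical stand-ins, not Optimisers.jl code: a rule's apply! returns (state, dx),
# and the update step then subtracts dx from the "parameters".

struct NegateGrad end                                # hypothetical accumulation rule

apply!(::NegateGrad, state, x, dx) = (state, -dx)    # just negate the incoming gradient

function update_leaf!(rule, state, x::AbstractArray, dx::AbstractArray)
    state, step = apply!(rule, state, x, dx)
    x .-= step                                       # mimics the subtraction in update!
    return state, x
end

# Accumulate three "gradients" into a zero-initialised buffer.
accum = zeros(3)
for g in ([1.0, 2.0, 3.0], [0.5, 0.5, 0.5], [-1.0, 0.0, 1.0])
    update_leaf!(NegateGrad(), nothing, accum, g)    # accum .-= -g, i.e. accum .+= g
end
accum   # [0.5, 2.5, 4.5]
```

With the real package this would presumably amount to extending `Optimisers.init` and `Optimisers.apply!` for the new rule; that wiring is an assumption here, not something the suggestion spells out.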

CarloLucibello commented 1 year ago

I'd like to implement this, but I'm not sure I understood the proposed implementation.

First of all, is the interface supposed to be

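model = ...                                   # the model to train
opt = OptimiserChain(AccumGrad(K), Adam())    # K = number of batches to accumulate
optstate = Optimisers.setup(opt, model)
for (k, batch) in enumerate(dataloader)
  grad = ...                                  # gradient of the loss on `batch` w.r.t. `model`
  optstate, model = Optimisers.update!(optstate, model, grad)
end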

with the following behavior for update!:

?

I don't quite see how to achieve this behavior with the proposed negative gradient trick.

I think that one can instead have a very simple implementation by wrapping AccumGrad around an optimizer

opt = AccumGrad(Adam(), K)

and then overloading Optimisers.update!.

darsnack commented 1 year ago

I agree that the cleaner implementation here is to create a wrapper ("higher-order") optimizer that internally accumulates the gradients over K calls and then applies the inner rule once. But the function being overloaded should be apply! not update!. See here for a reference implementation of a similar wrapper.
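A minimal sketch of such a wrapper, again in plain Julia under stated assumptions: `PlainDescent` is a hypothetical stand-in for the inner rule, `AccumWrapper` is a hypothetical name for the wrapper, and the local `init`/`apply!` functions only mimic the Optimisers.jl convention of returning `(state, dx)`. On non-update calls it returns `nothing` rather than a zero array, so the caller has to skip those steps:

```julia
# Hypothetical stand-ins; these only mimic the (state, dx) convention of Optimisers.jl.

struct PlainDescent                 # inner rule: step = eta * dx
    eta::Float64
end
init(::PlainDescent, x) = nothing
apply!(o::PlainDescent, state, x, dx) = (state, o.eta .* dx)

struct AccumWrapper{R}              # accumulate K gradients, then call the inner rule once
    inner::R
    K::Int
end
# State: (inner rule state, accumulation buffer, number of gradients seen so far)
init(o::AccumWrapper, x) = (init(o.inner, x), zero(x), 0)

function apply!(o::AccumWrapper, state, x, dx)
    inner_state, acc, n = state
    acc .+= dx
    n += 1
    if n < o.K
        return (inner_state, acc, n), nothing          # no update yet
    end
    # Divide by K so the inner rule sees the mean gradient, then reset the buffer.
    inner_state, upd = apply!(o.inner, inner_state, x, acc ./ o.K)
    fill!(acc, 0)
    return (inner_state, acc, 0), upd
end

# Driver: the parameters only move on every K-th call.
x = [1.0, 1.0]
rule = AccumWrapper(PlainDescent(0.1), 2)
state = init(rule, x)
for g in ([2.0, 0.0], [0.0, 2.0], [4.0, 4.0], [0.0, 0.0])
    global state
    state, upd = apply!(rule, state, x, g)
    upd === nothing || (x .-= upd)
end
x   # ≈ [0.7, 0.7]
```

Keeping the buffer and counter in the rule's state means the training loop itself does not change; only the rule passed to setup does.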

chengchingwen commented 1 year ago

Don't forget the extra "divide by k" before update.

CarloLucibello commented 1 year ago

But the function being overloaded should be apply! not update!

But apply! should return a dx. Creating a zero array is wasteful. Is returning nothing supported?

CarloLucibello commented 1 year ago

Don't forget the extra "divide by k" before update.

Is plain division OK, or should we reweight by the batch size (which Optimisers knows nothing about)?

chengchingwen commented 1 year ago

Plain division should be fine. I don't think we are able to get the batch size in Optimisers.
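Why plain division is enough in the common case, assuming each of the K accumulated batches $b_k$ has $B_k$ examples and the per-batch loss is a mean over its examples (an assumption; Optimisers itself never sees the loss):

```latex
L_k(\theta) = \frac{1}{B_k}\sum_{i \in b_k} \ell_i(\theta), \qquad g_k = \nabla_\theta L_k(\theta)

\text{Plain division:}\quad \bar g = \frac{1}{K}\sum_{k=1}^{K} g_k

\text{Equal sizes } (B_k = B):\quad
\nabla_\theta\,\frac{1}{KB}\sum_{k=1}^{K}\sum_{i \in b_k}\ell_i(\theta)
  = \frac{1}{K}\sum_{k=1}^{K} \nabla_\theta L_k(\theta) = \bar g

\text{General sizes } (N = \textstyle\sum_k B_k):\quad
\nabla_\theta\,\frac{1}{N}\sum_{k=1}^{K}\sum_{i \in b_k}\ell_i(\theta)
  = \frac{1}{N}\sum_{k=1}^{K} B_k\, g_k
```

So plain division reproduces the large-batch gradient exactly when all accumulated batches have the same size; with unequal sizes it weights every batch equally rather than every example, which is the reweighting question raised above.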

darsnack commented 1 year ago

But apply! should return a dx. Creating a zero array is wasteful. Is returning nothing supported?

Not right now, but we can add subtract(x, ::Nothing) easily.
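The dispatch-based fix can be sketched with a hypothetical helper `subtract!` (named after the `subtract` mentioned in the comment, not copied from the library source): a method on `::Nothing` turns a skipped step into a no-op, so a rule never has to allocate a zero array.

```julia
# Hypothetical helper: a rule that returns `nothing` instead of a zero array
# means "leave this parameter unchanged".
subtract!(x::AbstractArray, dx::AbstractArray) = (x .-= dx; x)
subtract!(x::AbstractArray, ::Nothing) = x           # no allocation, no change

w = [1.0, 2.0]
subtract!(w, nothing)        # w unchanged
subtract!(w, [0.5, 0.5])     # w == [0.5, 1.5]
```

Because the choice happens through multiple dispatch on the type of `dx`, the common array path stays exactly as before.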