FluxML / Optimisers.jl

Optimisers.jl defines many standard optimisers and utilities for learning loops.
https://fluxml.ai/Optimisers.jl
MIT License

`reset!(optimiser_state)` #163

Open Vilin97 opened 11 months ago

Vilin97 commented 11 months ago

Motivation and description

In my application I run 25 gradient-descent update! steps in a loop (solving a differential equation). I need the momentum accumulated during one block of 25 steps NOT to carry over to the next block. In other words, the behavior I am looking for is analogous to calling Flux.setup(optimiser, model) before every block. Unfortunately, Flux.setup is type-unstable (https://github.com/FluxML/Optimisers.jl/issues/162). It would be great to have a function reset!(optimiser_state) that resets the momenta. A more stringent requirement would be that

state = Flux.setup(optimiser, model)
# do some training
reset!(state)
state == Flux.setup(optimiser, model)

holds.

Possible Implementation

Below is an implementation for Adam.

function reset!(leaf::Leaf{A, S}) where {A <: Optimisers.Adam, S}
    leaf.state[1] .= 0
    leaf.state[2] .= 0
    leaf.state = (leaf.state[1], leaf.state[2], leaf.rule.beta)
    nothing
end
function reset!(state::NamedTuple{(:layers,), L}) where {L}
    for layer in state.layers
        reset!(layer.weight)
        reset!(layer.bias)
    end
    nothing
end

mcabbott commented 11 months ago

One possible design is this:

reset!(tree) = foreach(reset!, tree)
reset!(ℓ::Leaf) = ℓ.state = reset!(ℓ.rule, ℓ.state)

reset!(::AbstractRule, ::Nothing) = nothing
reset!(rule::AbstractRule, state) = throw(ArgumentError("reset! does not know how to handle this rule."))

Rules then need to opt in by defining a method of the 2-arg reset!, perhaps using some fill!! which also allows for immutable arrays:

reset!(rule::Adam, (mt, vt, βt)) = (fill!!(mt, 0), fill!!(vt, 0), rule.beta)

We can't easily fall back to calling init again for unknown rules, as we don't have the original parameters x here.
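
As a rough illustration, the opt-in design above can be sketched end to end. To keep it runnable without the package, `AbstractRule`, `Adam` and `Leaf` below are simplified stand-ins that only mirror the names used in this thread, and `fill!!` is a hypothetical helper, not an existing function. The sketch also assumes that fields without state appear as `nothing` in the tree, and annotates the Adam state as a `Tuple` to avoid a method ambiguity with the `::Nothing` method.

```julia
# Stand-in types mirroring the names used in this thread. They are assumptions
# made so the sketch runs on its own, not the package's actual definitions.
abstract type AbstractRule end

struct Adam <: AbstractRule
    eta::Float64
    beta::Tuple{Float64,Float64}
end

mutable struct Leaf{R<:AbstractRule,S}
    rule::R
    state::S
end

# Hypothetical fill!!: mutate in place when possible, otherwise build a new array.
fill!!(x::AbstractArray, v) =
    ismutable(x) ? fill!(x, v) : map(_ -> convert(eltype(x), v), x)

# Walk the state tree; each Leaf defers to the rule-specific 2-arg method.
reset!(tree::Union{Tuple,NamedTuple}) = foreach(reset!, tree)
reset!(::Nothing) = nothing   # assumed marker for fields that carry no state
reset!(ℓ::Leaf) = (ℓ.state = reset!(ℓ.rule, ℓ.state); nothing)

# Rules opt in with a 2-arg method; anything else fails loudly.
reset!(::AbstractRule, ::Nothing) = nothing
reset!(rule::AbstractRule, state) =
    throw(ArgumentError("reset! does not know how to handle $(typeof(rule))"))
reset!(rule::Adam, (mt, vt, βt)::Tuple) = (fill!!(mt, 0), fill!!(vt, 0), rule.beta)

# Usage: a state tree part-way through training, then reset.
state = (weight = Leaf(Adam(1e-3, (0.9, 0.999)),
                       ([0.1, -0.2], [0.01, 0.04], (0.5, 0.9))),
         bias = nothing)
reset!(state)
state.weight.state  # ([0.0, 0.0], [0.0, 0.0], (0.9, 0.999))
```

The moment arrays are zeroed in place here, so resetting allocates nothing for ordinary `Array`s; only immutable arrays would take the copying branch of `fill!!`.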

Falling back to zero like this might be OK for built-in rules like Momentum etc, but could be wrong for user-defined rules... probably we shouldn't:

reset!(rule::AbstractRule, state::AbstractArray) = fill!!(state, 0)

ToucheSir commented 11 months ago

We could always make reset! take the parameter tree of xs too, but that may come at the cost of sacrificing type stability.
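
A minimal sketch of that variant, under the same caveats: `AbstractRule`, `Momentum`, `Leaf` and `init` here are simplified stand-ins written for this example (in the package the initialiser would be `Optimisers.init(rule, x)`), and the tree walk is just one possible way to pair state with parameters. Whether such a walk infers well on real models is exactly the open question.

```julia
# Stand-in types and initialiser, assumed for illustration only.
abstract type AbstractRule end

struct Momentum <: AbstractRule
    eta::Float64
    rho::Float64
end

mutable struct Leaf{R<:AbstractRule,S}
    rule::R
    state::S
end

# Momentum keeps one velocity array per parameter array, initially zero.
init(::Momentum, x::AbstractArray) = zero(x)

# With the parameters at hand, any rule can be reset by re-running its initialiser,
# so no per-rule reset! method is needed.
reset!(ℓ::Leaf, x) = (ℓ.state = init(ℓ.rule, x); nothing)
reset!(::Nothing, x) = nothing
reset!(tree::NamedTuple, xs::NamedTuple) =
    foreach(k -> reset!(tree[k], xs[k]), keys(tree))

# Usage: parameters and a matching state tree with stale velocity.
xs = (weight = [1.0, 2.0], bias = nothing)
state = (weight = Leaf(Momentum(0.01, 0.9), [0.3, -0.7]), bias = nothing)
reset!(state, xs)
state.weight.state  # [0.0, 0.0]
```

Unlike the in-place version, this allocates fresh state arrays on every reset, which may matter if resets happen every few steps.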