DrChainsaw / NaiveNASflux.jl

Your local Flux surgeon
MIT License

Memory leak with stateful optimizers #18

Open DrChainsaw opened 5 years ago

DrChainsaw commented 5 years ago

Stateful optimizers (like ADAM) store state per parameter array in a dict. Since parameter arrays are replaced with new instances when mutating, the dict just keeps accumulating new entries while the old ones are never removed.

This also prevents the optimizers from working as they should, since the state for the new arrays starts from scratch.
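A minimal sketch of the leak, using an `IdDict` keyed by the parameter array itself (as Flux's stateful optimizers do); the `apply!` helper and its zero-initialized state are simplified stand-ins, not Flux's actual implementation:

```julia
# State keyed by identity of the parameter array, as in Flux's ADAM.
state = IdDict{Any,Any}()

# Hypothetical optimizer step: fetch (or lazily create) state for a parameter array.
function apply!(state, p)
    return get!(state, p) do
        zero(p)  # moment buffers etc. omitted for brevity
    end
end

w = rand(3, 4)
apply!(state, w)
@assert length(state) == 1

# Mutating the model replaces the parameter array with a new instance,
# e.g. after adding an output neuron. The old entry is never cleaned up:
w = vcat(w, rand(1, 4))
apply!(state, w)
@assert length(state) == 2  # stale entry for the old array -> leak
```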

DrChainsaw commented 5 years ago

Here is a start of an approach (which might not be the best all things considered):

struct StateAlign{T} <: AbstractMutableComp
    state::T
    m::AbstractMutableComp
end
StateAlign(state::T) where T = m -> StateAlign{T}(state, m)

NaiveNASlib.mutate_inputs(m::StateAlign, inputs::AbstractArray{<:Integer,1}...) = mutate(m, inputs, 1:nout(m))
NaiveNASlib.mutate_outputs(m::StateAlign, outputs) = mutate(m, Base.OneTo.(nin(m)), outputs)

function NaiveNASflux.mutate(m::StateAlign; inputs, outputs)
    ops = params(m)  # parameter arrays before mutation
    mutate(m.m, inputs=inputs, outputs=outputs)
    nps = params(m)  # new parameter arrays after mutation

    for (op, np) in zip(ops, nps)
        op in keys(m.state) || continue

        os = m.state[op]
        ns = select_state(os, op, layer(m), inputs, outputs)

        m.state[np] = ns       # rekey the state to the new parameter array...
        delete!(m.state, op)   # ...and drop the stale entry for the old one
    end
end

select_state(s, op, l, ins, outs) = s
select_state(s::Tuple, op, l, ins, outs) = map(si -> select_state(si, op, l, ins, outs), s)
select_state(s::AbstractVector{<:Number}, op::AbstractVector{<:Number}, l, ins, outs) = NaiveNASflux.select(s, 1 => outs)
select_state(s::AbstractArray{<:Number, N}, op::AbstractArray{<:Number, N}, l, ins, outs) where N = NaiveNASflux.select(s, outdim(l) => outs, indim(l) => ins)
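To illustrate the slicing idea behind `select_state`: the goal is to keep only the entries of the state buffer that correspond to surviving neurons. Plain array indexing stands in for `NaiveNASflux.select` in this hypothetical example:

```julia
# A momentum buffer shaped like a Dense weight matrix: 3 outputs x 2 inputs.
momentum = Float64[10 10; 20 20; 30 30]

outs = [1, 3]  # output neurons kept after mutation (neuron 2 was pruned)
ins  = [1, 2]  # all inputs kept

# Select the surviving rows/columns so state stays aligned with the new weights.
new_momentum = momentum[outs, ins]
@assert size(new_momentum) == (2, 2)
@assert new_momentum == [10 10; 30 30]
```

Inserting new neurons would additionally require padding the buffer with fresh (e.g. zero) state at the inserted indices, which `select` handles via negative/missing indices in NaiveNASflux.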

Probably not worth trying to get the above to work considering https://github.com/FluxML/Flux.jl/issues/637

Not handled by the above:

  1. Copying a model, including the optimizer. The state field must point to the state of the new optimizer
  2. Moving a model between CPU and GPU. This creates new instances of all parameters; in other words, the same problem as 1