Open DrChainsaw opened 5 years ago
Here is a starting point for one approach (which may not be the best one, all things considered):
"""
    StateAlign{T} <: AbstractMutableComp

Wrap a mutable component `m` together with a `state` container whose entries are indexed by
the parameter arrays of `m`, so the entries can be moved to the new arrays when `m` is mutated.

NOTE(review): `state` is presumably an identity-keyed dict such as the per-parameter state of a
Flux optimizer (e.g. ADAM) — confirm.
"""
struct StateAlign{T} <: AbstractMutableComp
    state::T                  # per-parameter state, indexed by parameter array
    m::AbstractMutableComp    # wrapped component whose parameters are tracked
end
"""
    StateAlign(state)

Return a one-argument function so that `StateAlign(state)(m)` builds
`StateAlign{typeof(state)}(state, m)`, i.e. wraps `m` together with `state`.
"""
function StateAlign(state::T) where {T}
    wrapwithstate(comp) = StateAlign{T}(state, comp)
    return wrapwithstate
end
"""
    mutate_inputs(m::StateAlign, inputs...)

Select `inputs` for `m` while keeping all of its outputs (`1:nout(m)`), delegating to
`mutate(m::StateAlign; inputs, outputs)`.
"""
function NaiveNASlib.mutate_inputs(m::StateAlign, inputs::AbstractArray{<:Integer,1}...)
    # `mutate(::StateAlign)` only accepts keyword arguments; a positional call would throw a
    # MethodError.
    # NOTE(review): `inputs` is forwarded as the varargs tuple — confirm the wrapped component
    # expects that form rather than a single index vector.
    return NaiveNASflux.mutate(m; inputs=inputs, outputs=1:nout(m))
end
"""
    mutate_outputs(m::StateAlign, outputs)

Select `outputs` for `m` while keeping all of its inputs, delegating to
`mutate(m::StateAlign; inputs, outputs)`.
"""
function NaiveNASlib.mutate_outputs(m::StateAlign, outputs)
    # Keep every input: one `Base.OneTo(n)` range per value in `nin(m)`.
    # `mutate(::StateAlign)` only accepts keyword arguments; a positional call would throw a
    # MethodError.
    return NaiveNASflux.mutate(m; inputs=Base.OneTo.(nin(m)), outputs=outputs)
end
"""
    mutate(m::StateAlign; inputs, outputs)

Mutate the wrapped component `m.m` with the given `inputs`/`outputs` selection. Then move each
entry of `m.state` from the replaced parameter array to its replacement, passing it through
`select_state` so the state stays aligned with the new parameters.
"""
function NaiveNASflux.mutate(m::StateAlign; inputs, outputs)
    # Snapshot the parameter arrays before mutating. The wrapped component replaces them with
    # new instances, so these are the keys under which `m.state` currently stores entries.
    # Parameters are read from `m.m` directly: no functor declaration for `StateAlign` is
    # visible, so `params(m)` may not reach them.
    oldparams = params(m.m)
    mutate(m.m, inputs=inputs, outputs=outputs)
    newparams = params(m.m)
    # NOTE(review): assumes both collections list corresponding parameters in the same order.
    for (oldp, newp) in zip(oldparams, newparams)
        haskey(m.state, oldp) || continue
        # Remove the old entry *before* inserting the new one. Otherwise a parameter array that
        # was kept as the same instance (oldp === newp) would have its fresh entry deleted.
        oldstate = pop!(m.state, oldp)
        m.state[newp] = select_state(oldstate, oldp, layer(m), inputs, outputs)
    end
    return nothing
end
"""
    select_state(s, op, l, ins, outs)

Fallback for state values with no more specific method (e.g. scalars): return `s` unchanged.
"""
function select_state(s, op, l, ins, outs)
    return s
end
"""
    select_state(s::Tuple, op, l, ins, outs)

Apply `select_state` to each element of a tuple-valued state and return the results as a tuple.
NOTE(review): presumably covers optimizers that keep several arrays per parameter (e.g. ADAM's
moment estimates) — confirm.
"""
select_state(s::Tuple, op, l, ins, outs) = map(si -> select_state(si, op, l, ins, outs), s)
"""
    select_state(s::AbstractVector{<:Number}, op::AbstractVector{<:Number}, l, ins, outs)

Select the state of a vector-shaped parameter along its only dimension using the output
indices `outs`; `ins` and `l` are not used.
"""
function select_state(
    s::AbstractVector{<:Number}, op::AbstractVector{<:Number}, l, ins, outs
)
    outselection = 1 => outs
    return NaiveNASflux.select(s, outselection)
end
"""
    select_state(s::AbstractArray{<:Number,N}, op::AbstractArray{<:Number,N}, l, ins, outs)

Select the state of an `N`-dimensional parameter along the output dimension of layer `l`
(using `outs`) and along its input dimension (using `ins`).
"""
function select_state(
    s::AbstractArray{<:Number,N}, op::AbstractArray{<:Number,N}, l, ins, outs
) where {N}
    selections = (outdim(l) => outs, indim(l) => ins)
    return NaiveNASflux.select(s, selections...)
end
It is probably not worth trying to get the above to work, considering https://github.com/FluxML/Flux.jl/issues/637
Not handled by the above:
Stateful optimizers (like ADAM) store state per parameter array in a dict. Since mutation replaces parameter arrays with new instances, the dict just keeps getting new entries.
This also prevents the optimizers from working as they should, since the state accumulated for the old arrays is not used for the new arrays that replace them.