FluxML / Optimisers.jl

Optimisers.jl defines many standard optimisers and utilities for learning loops.
https://fluxml.ai/Optimisers.jl
MIT License

Add `trainables_nt` #175

Open CarloLucibello opened 4 months ago

CarloLucibello commented 4 months ago

This is a proposal for an alternative to destructure which doesn't completely flatten the parameters but instead returns a nested named tuple. The associated reconstructor can be used on ComponentArrays as well.
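
A minimal sketch of the intended usage (the shapes shown are illustrative, based on the examples further down in this thread):

using Optimisers

m = (collect(1:3.0), collect(4:6.0))
ps, re = trainables_nt(m)
# ps is a nested NamedTuple of the trainable arrays, with positional children
# relabelled as _1, _2, ...:  (_1 = [1.0, 2.0, 3.0], _2 = [4.0, 5.0, 6.0])
m2 = re(ps)   # reconstruct the original structure from the NamedTuple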

darsnack commented 4 months ago

Setting differentiability aside, is fmapstructure not sufficient because of how vectors are handled (e.g. the layers in Chain)?

CarloLucibello commented 4 months ago

Exactly. And we need a return value made only of nested NamedTuples in order to be compatible with ComponentArrays.
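
For concreteness, a small sketch of the mismatch (assuming Functors.jl's fmapstructure and ComponentArrays' NamedTuple constructor):

using Functors, ComponentArrays

# fmapstructure keeps Tuple/Vector children positional, so e.g. the layers of a
# Chain come back as a Tuple rather than a NamedTuple:
fmapstructure(identity, ([1.0, 2.0], [3.0, 4.0]))    # returns a Tuple

# ComponentVector needs named components, which is why trainables_nt relabels
# positional children as _1, _2, ...:
ComponentVector((_1 = [1.0, 2.0], _2 = [3.0, 4.0]))  # flat data [1, 2, 3, 4] with named blocks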

darsnack commented 4 months ago

What about replacing destructure with this code plus the ComponentArrays construction, as opposed to adding this as a separate function? It would move a lot of the tricky stuff to ComponentArrays.

@mcabbott

mcabbott commented 4 months ago

Wait, there are two big differences from fmapstructure / Flux.state:

ComponentArrays has no notion of shared parameters. That's a large part of what makes everything touching Functors tricky. (In fact the replacement of a vector with a NamedTuple opens the door to weirdness here, before you get to ComponentArrays, as you replace a mutable thing with an immutable one. Probably not in a way that matters for Flux models.)

Example with this:

julia> sh = [1f0, 2f0];

julia> ps, re = trainables_nt((sh, sh, [3,4.]))
((_1 = Float32[1.0, 2.0], _2 = Float32[1.0, 2.0], _3 = [3.0, 4.0]), Optimisers.RestructureFromNT{Tuple{Vector{Float32}, Vector{Float32}, Vector{Float64}}}((Float32[1.0, 2.0], Float32[1.0, 2.0], [3.0, 4.0])))

julia> ps._1 === ps._2
true

julia> v = ComponentVector(ps);

julia> getfield(v, :data) |> println
[1.0, 2.0, 1.0, 2.0, 3.0, 4.0]

julia> v[3] = 99;

julia> re(v)  # sharing is broken
([1.0, 2.0], [99.0, 2.0], [3.0, 4.0])

And unrelated to sharing:

julia> re(v)[1] |> eltype  # accidental promotion is back
Float64

julia> re(v)[1]   # no copy on reconstruction, but will view(::CuArray) work everywhere?
2-element view(::Vector{Float64}, 1:2) with eltype Float64:
 1.0
 2.0

cf destructure:

julia> v2, re2 = destructure((sh, sh, [3,4.]))
([1.0, 2.0, 3.0, 4.0], Restructure(Tuple, ..., 4))

julia> v2[2] = 999;

julia> re2(v2)
(Float32[1.0, 999.0], Float32[1.0, 999.0], [3.0, 4.0])

When I last looked, ComponentArrays also made more whole copies in the gradient.

More broadly, what's this for? Why do we care about ComponentArrays?

CarloLucibello commented 4 months ago

> More broadly, what's this for? Why do we care about ComponentArrays?

I would like to have something in the `v, re = destructure(model)` style, but where reconstruction is copyless and which is also compatible with ComponentArrays. This seems quite needed, see https://github.com/FluxML/Flux.jl/issues/2413#issuecomment-2033361707. I think we can provide it and see whether it gets used.
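
Concretely, the target workflow would be something like this (a sketch reusing the examples above; "copyless" is the goal here, not a claim about current behaviour):

using Optimisers, ComponentArrays

model = (collect(1:3.0), collect(4:6.0))
ps, re = trainables_nt(model)

v = ComponentVector(ps)   # flat AbstractVector holding the parameter values, usable by generic optimisers/solvers
v .*= 2                   # update the parameters through the flat view
model2 = re(v)            # rebuild the structured model, ideally without copying the parameter arrays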

CarloLucibello commented 4 months ago

I need help with the rrule of the reconstructor. It works for named tuples but not for component arrays:

using Zygote, Optimisers, ComponentArrays, Test
m = (collect(1:3.0), collect(4:6.0))
ps, re = trainables_nt(m)
Zygote.refresh()
gps = gradient(x -> re(x)[1][2], ps)[1]
@test gps == (_1 = [0.0, 1.0, 0.0], _2 = nothing)  # ok

v = ComponentVector(ps)
gv = gradient(x -> re(x)[1][2], v)[1] # this is `nothing`!!!!

The relevant rrule is:

function ChainRulesCore.rrule(::typeof(restructure_from_nt), x, ps)
    model = restructure_from_nt(x, ps)
    proj_ps = ProjectTo(ps)

    function restructure_from_nt_back(Δmodel_raw)
        Δmodel = unthunk(Δmodel_raw)
        walk = RestructureFromNamedTupleBackWalk()
        function exclude(x)
            @show "exclude" x isnumeric(x)
            # i += 1
            # return i > 1
            return isnumeric(x)
        end
        Δps = fmap(ps, Δmodel; exclude, walk, cache=nothing) do p, Δ
                    @show "fmap" Δ p

                    return Δ
                end
        Δpst = Tangent{typeof(Δps)}(; Δps...)
        @show "rrule" Δmodel x ps Δps Δpst         #here  Δp = (_1 = [0.0, 1.0, 0.0], _2 = ChainRulesCore.ZeroTangent())
        @show  typeof(Δmodel) typeof(ps) typeof(Δps)
        return (NoTangent(), NoTangent(), Δps)
        # return (NoTangent(), NoTangent(), proj_ps(Δpst))
    end
    return model, restructure_from_nt_back
end

struct RestructureFromNamedTupleBackWalk <: AbstractWalk end

function (::RestructureFromNamedTupleBackWalk)(recurse, ps, Δmodel)
    @show 1 typeof(Δmodel) typeof(ps)
    Δm = make_named_tuple(Δmodel)
    @show 2 typeof(Δm) ps Δm
    Δm === nothing && return nothing
    Δm === ZeroTangent() && return ZeroTangent()
    y = mapvalue(recurse, ps, Δm)
    @show 3 typeof(Δmodel) typeof(Δm) typeof(y)
    return y
end

Why do I get a `nothing` gradient? Am I doing something wrong with the projection?
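
One direction that might be worth probing (only a guess, not a verified fix): when ps is a ComponentVector, the pullback hands back a bare NamedTuple tangent while the primal is array-shaped, and the proj_ps = ProjectTo(ps) that is computed is never applied (the projected return is commented out). Below is a sketch, purely as a debugging experiment, of packing the NamedTuple tangent into a ComponentVector with the primal's axes, materialising the zero tangents:

# hypothetical experiment inside restructure_from_nt_back (debugging only):
# pack the NamedTuple tangent into a ComponentVector matching the primal
if ps isa ComponentVector
    Δv = zero(ps)                        # zero ComponentVector with the primal's axes
    for k in propertynames(Δps)
        dk = getproperty(Δps, k)
        dk isa ChainRulesCore.AbstractZero || (getproperty(Δv, k) .= dk)   # skip ZeroTangent blocks
    end
    return (NoTangent(), NoTangent(), Δv)
end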