FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/
Other
4.47k stars 602 forks source link

Hessian vector products with moderately complex models #1813

Open colinxs opened 4 years ago

colinxs commented 4 years ago

I'm attempting to compute Hessian-vector products (HVPs) for use with RL algorithms like Natural Policy Gradient or TRPO, but have been entirely unsuccessful.

Following https://github.com/FluxML/Zygote.jl/issues/115, https://github.com/JuliaDiffEq/SparseDiffTools.jl, and elsewhere I was able to compute HVPs for simple models parameterized by a single Array, but the following appears to have issues inferring the type of Dual.

Any help would be greatly appreciated! :)

# Zygote v0.4.1, Flux v0.10.0, ForwardDiff v0.10.7, DiffRules 0.1.0, ZygoteRules 0.2.0

# Julia Version 1.3.0
# Commit 46ce4d7933 (2019-11-26 06:09 UTC)
# Platform Info:
#   OS: Linux (x86_64-pc-linux-gnu)
#   CPU: Intel(R) Core(TM) i9-7960X CPU @ 2.80GHz
#   WORD_SIZE: 64
#   LIBM: libopenlibm
#   LLVM: libLLVM-6.0.1 (ORCJIT, skylake)

using Flux, ForwardDiff, Zygote
using LinearAlgebra

# A Gaussian policy with diagonal covariance
"""
    DiagGaussianPolicy(meanNN, logstd)

Gaussian policy with a diagonal covariance. `meanNN` maps features to the action
mean. `logstd` holds one log standard deviation per action dimension.
"""
struct DiagGaussianPolicy{Net,LogStd<:AbstractVector}
    meanNN::Net
    logstd::LogStd
end

# Register both fields with Flux so that `params`, `fmap` and `paramtype` can traverse them.
Flux.@functor DiagGaussianPolicy

# Calling the policy evaluates only the mean network; `logstd` is not touched here.
function (pol::DiagGaussianPolicy)(x)
    return pol.meanNN(x)
end

# log(pi_theta(a | s))
"""
    loglikelihood(P, feature, action)

Return the log-density of `action` under the diagonal Gaussian policy `P`
evaluated at `feature`, accumulated one action dimension at a time.
"""
function loglikelihood(
    P::DiagGaussianPolicy, feature::AbstractVector, action::AbstractVector
)
    μ = P(feature)
    logσ = P.logstd
    # Normalising constant: log(2π)/2 for every dimension of `logstd`.
    total = -length(logσ) * log(2pi) / 2
    for i in 1:length(action)
        # Scalar `^2` and the same subtraction order as the vectorised formula's terms.
        total -= ((μ[i] - action[i]) / exp(logσ[i]))^2 / 2
        total -= logσ[i]
    end
    return total
end

"""
    flatgrad(f, ps)

Take `Zygote.gradient` of the zero-argument closure `f` with respect to the
parameters `ps`. Concatenate every parameter's gradient, in the iteration order of
`ps`, into a single flat vector.

A parameter with a `nothing` gradient (one that `f` does not use) contributes
zeros. This keeps the result at one entry per parameter element, so it stays
aligned with a flat vector such as the `v` of a Hessian-vector product.
"""
function flatgrad(f, ps)
    gs = Zygote.gradient(f, ps)
    parts = [gs[p] === nothing ? zero(vec(p)) : vec(gs[p]) for p in ps]
    # `reduce(vcat, parts)` avoids splatting a variable-length collection into `vcat`.
    return reduce(vcat, parts)
end

# NOTE(review): these three methods are type piracy. They redefine Base functions
# for Zygote's `Params` for the whole session, replacing whatever `length` Zygote
# itself defines for `Params`. The intent is presumably to make
# `Zygote.forward_jacobian` see `ps` as one flat vector of all parameter elements.
# A dedicated flat wrapper type would be safer.
#
# The element count is computed from `ps.params` (the underlying collection of
# parameter arrays) instead of the previous hard-coded 228. That constant only
# matched Dense(4, 32) -> Dense(32, 2) plus a 2-element logstd. Iterating
# `ps.params` rather than `ps` avoids any chance of re-entering these overrides.
Base.length(ps::Params) = mapreduce(length, +, ps.params; init=0)
Base.size(ps::Params) = (length(ps),)
# Common element type of all parameter arrays (Float32 for a Float32 model).
Base.eltype(ps::Params) = mapreduce(eltype, promote_type, ps.params; init=Union{})

"""
    hessian_vector_product(f, ps, v)

Compute a forward-over-reverse Hessian-vector product. `Zygote.forward_jacobian`
is applied to the scalar `dot(flatgrad(f, θ), v)`, whose derivative with respect
to θ is H·v.

- `f`: zero-argument closure returning the scalar objective.
- `ps`: the parameters to differentiate with respect to.
- `v`: flat vector with one entry per parameter element.

Returns the second element of the tuple produced by `Zygote.forward_jacobian`.

NOTE(review): with Zygote v0.4.1 this still fails inside `forward_jacobian`'s
seeding of `Params` (ForwardDiff refuses to build a `Dual` from an
`Array{Float32,2}`). This function is therefore not yet functional as written.
"""
function hessian_vector_product(f, ps, v)
    grad(θ) = flatgrad(f, θ)
    # No `::Vector{Float32}` assertions here. Under forward mode the gradient holds
    # `ForwardDiff.Dual`s, and `dot(g, v)` is a scalar, so the old asserts could only throw.
    gradvec(θ) = dot(grad(θ), v)
    return Zygote.forward_jacobian(gradvec, ps)[2]
end

"""
    test()

Build a small Float32 policy (4 -> 32 -> 2 mean network, zero log-stds) and random
inputs. Then compute one Hessian-vector product of the log-likelihood with respect
to all policy parameters.
"""
function test()
    net = Flux.Chain(Dense(4, 32), Dense(32, 2))
    policy = Flux.paramtype(Float32, DiagGaussianPolicy(net, zeros(2)))
    ps = Flux.params(policy)
    v = rand(Float32, sum(length, ps))
    feat = rand(Float32, 4)
    act = rand(Float32, 2)
    objective() = loglikelihood(policy, feat, act)
    return hessian_vector_product(objective, ps, v)
end

Calling test() yields:

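ERROR: …an_dual.  [message truncated in the original post; it is raised by `ForwardDiff.throw_cannot_dual` when a `Dual` is constructed from an `Array{Float32,2}` (frames [1]–[2] below)]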
Stacktrace:
 [1] throw_cannot_dual(::Type) at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:36
 [2] ForwardDiff.Dual{Nothing,Any,12}(::Array{Float32,2}, ::ForwardDiff.Partials{12,Any}) at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:18
 [3] Dual at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:55 [inlined]
 [4] Dual at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:62 [inlined]
 [5] Dual at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:68 [inlined]
 [6] (::Zygote.var"#1565#1567"{12,Int64})(::Array{Float32,2}, ::Int64) at /home/colinxs/.julia/packages/Zygote/8dVxG/src/lib/forward.jl:8
 [7] (::Base.var"#3#4"{Zygote.var"#1565#1567"{12,Int64}})(::Tuple{Array{Float32,2},Int64}) at ./generator.jl:36
 [8] iterate at ./generator.jl:47 [inlined]
 [9] collect(::Base.Generator{Base.Iterators.Zip{Tuple{Params,UnitRange{Int64}}},Base.var"#3#4"{Zygote.var"#1565#1567"{12,Int64}}}) at ./array.jl:622
 [10] map at ./abstractarray.jl:2155 [inlined]
 [11] seed at /home/colinxs/.julia/packages/Zygote/8dVxG/src/lib/forward.jl:7 [inlined] (repeats 2 times)
 [12] forward_jacobian(::var"#340#342"{var"#339#341"{var"#343#344"{DiagGaussianPolicy{Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}},Array{Float32,1}},Array{Float32,1},Array{Float32,1}}},Array{Float32,1}}, ::Params, ::Val{12}) at /home/colinxs/.julia/packages/Zygote/8dVxG/src/lib/forward.jl:23
 [13] forward_jacobian(::Function, ::Params) at /home/colinxs/.julia/packages/Zygote/8dVxG/src/lib/forward.jl:40
 [14] hessian_vector_product(::Function, ::Params, ::Array{Float32,1}) at /home/colinxs/workspace/dev/SharedExperiments/lyceum/hvp.jl:52
 [15] test() at /home/colinxs/workspace/dev/SharedExperiments/lyceum/hvp.jl:64
 [16] top-level scope at REPL[41]:1
colinxs commented 4 years ago

So, with some more fiddling, and after realizing that ForwardDiff only works with AbstractArray inputs, I was able to get the above working (not yet checked for correctness). I got around the problem with a combination of RecursiveArrayTools and Flux.fmap. I also had to rewrite the loglikelihood to avoid the `.^ 2` broadcast, which appears to be related to https://github.com/FluxML/Zygote.jl/issues/405.

While I haven't checked the result for correctness, the solution itself isn't too ugly :).

I was also able to make this work without the custom seeding, by explicitly calculating the vector product instead, but that solution is messier and slower, so I won't bother posting it.

I'll circle back once I've cleaned things up a bit and verified that the result is correct, but until then, if anyone has suggestions, do let me know! This is a fairly critical capability for anyone doing research in ML, RL, etc.

using Flux, ForwardDiff, Zygote, RecursiveArrayTools, Random, LinearAlgebra
using Zygote: Params, Grads
using MacroTools: @forward

# A Gaussian policy with diagonal covariance
"""
    DiagGaussianPolicy(meanNN, logstd)

Gaussian policy with a diagonal covariance. `meanNN` produces the action mean
from the features. `logstd` stores one log standard deviation per action dimension.
"""
struct DiagGaussianPolicy{Net,LogStd<:AbstractVector}
    meanNN::Net
    logstd::LogStd
end

# Expose both fields to Flux (`params`, `fmap`, `paramtype`).
Flux.@functor DiagGaussianPolicy

# Evaluating the policy returns the mean action only.
function (pol::DiagGaussianPolicy)(x)
    return pol.meanNN(x)
end

# log(pi_theta(a | s))
"""
    loglikelihood(P, feature, action)

Return the log-density of `action` under the diagonal Gaussian policy `P`
evaluated at `feature`, computed with whole-vector operations.
"""
function loglikelihood(
    P::DiagGaussianPolicy, feature::AbstractVector, action::AbstractVector
)
    logσ = P.logstd
    # Residuals scaled by the standard deviation. They are squared as `z .* z`
    # because the `z .^ 2` form was observed to break here (possibly related to
    # https://github.com/FluxML/Zygote.jl/issues/405).
    z = (P(feature) .- action) ./ exp.(logσ)
    sqz = z .* z
    return -sum(sqz) / 2 - sum(logσ) - length(logσ) * log(2pi) / 2
end

# Gather the gradients in `gs` for each parameter in `ps`, in iteration order,
# skipping parameters whose gradient is `nothing`. Wrap them in an `ArrayPartition`.
function flatgrad(gs::Grads, ps::Params)
    present = Iterators.filter(!isnothing, (gs[p] for p in ps))
    return ArrayPartition(present...)
end

"""
    flat_hessian_vector_product(feat, act, policy, vs::ArrayPartition)

Compute a forward-over-reverse Hessian-vector product of
`loglikelihood(policy, feat, act)` with respect to the parameters of `policy`.

`vs` must hold one tangent block per array in `Flux.params(policy)`, in that
order. Each parameter array is replaced by `ForwardDiff.Dual{Nothing}` numbers
whose single partial is taken from its tangent block. A Zygote gradient is then
taken through this dual-valued copy of the policy. The function returns the first
partial of every gradient entry, wrapped like `flatgrad`'s output.

Throws `DimensionMismatch` if `vs` does not have exactly one block per parameter array.
"""
function flat_hessian_vector_product(feat, act, policy, vs::ArrayPartition)
    ps = Flux.params(policy)
    nparams = length(ps.params)
    length(vs.x) == nparams || throw(DimensionMismatch(
        "vs has $(length(vs.x)) blocks but the policy has $nparams parameter arrays"))

    # Match tangents to parameter arrays by identity rather than by counting `fmap`
    # visits. Correctness then does not depend on `fmap` and `params` traversing the
    # model in the same order, and no reassigned counter is captured in a closure.
    tangents = IdDict{Any,Any}()
    for (p, t) in zip(ps, vs.x)
        tangents[p] = t
    end

    dualpol = Flux.fmap(policy) do leaf
        haskey(tangents, leaf) ? ForwardDiff.Dual{Nothing}.(leaf, tangents[leaf]) : leaf
    end
    dualps = Flux.params(dualpol)

    gs = gradient(() -> loglikelihood(dualpol, feat, act), dualps)
    return ForwardDiff.partials.(flatgrad(gs, dualps), 1)
end

"""
    test_flathvp(T::DataType=Float32)

Smoke test and timing for `flat_hessian_vector_product`. It uses a policy with a
`dobs -> 32 -> 32 -> dact` mean network and element type `T`. Returns the HVP;
`@time` prints the elapsed time, and the first call includes compilation.
"""
function test_flathvp(T::DataType=Float32)
    Random.seed!(1)

    dobs, dact = 4, 2
    net = Chain(Dense(dobs, 32), Dense(32, 32), Dense(32, dact))
    policy = Flux.paramtype(T, DiagGaussianPolicy(net, zeros(dact)))

    # Tangent blocks use the same element type `T` as the parameters. Plain
    # `rand(dims...)` would produce Float64 tangents and mix precisions.
    v = ArrayPartition((rand(T, size(p)) for p in Flux.params(policy))...)
    # Inputs sized from `dobs`/`dact` so they stay consistent with the network.
    feat = rand(T, dobs)
    act = rand(T, dact)

    return @time flat_hessian_vector_product(feat, act, policy, v)
end
YichengDWu commented 2 years ago

Can you try using Lux and ComponentArrays?