baggepinnen / FluxOptTools.jl

Use Optim to train Flux models and visualize loss landscapes
MIT License

mismatch `veclength` between `Zygote.gradient` and `Flux.params` #9

Closed Red-Portal closed 2 years ago

Red-Portal commented 2 years ago

Hi, it seems `veclength` has a bug where it returns different values for the gradient and for `Flux.params`, which triggers an internal assertion. Here's a reproducible example.

    using Bijectors
    using Distributions # provides MvNormal and logpdf
    using Flux
    using FluxOptTools
    using Optim
    using StatsBase
    using Zygote

    data      = randn(2, 10)
    base_dist = MvNormal(zeros(2), ones(2))
    # compose four planar layers into a normalizing flow
    layers = reduce(∘, [PlanarLayer(2) for i in 1:4])
    flow   = transformed(base_dist, layers)
    # negative log-likelihood of the data under the flow
    loss  = () -> -sum(logpdf.(Ref(flow), eachcol(data)))
    Zygote.refresh()
    pars  = Flux.params(flow)
    grads = Zygote.gradient(loss, pars)

    println(veclength(pars))
    println(veclength(grads))

    @assert veclength(pars) == veclength(grads)
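
For reference, each PlanarLayer(2) above has parameter arrays of lengths 2, 2 and 1, so Flux.params(flow) should collect 12 arrays and 4 * (2 + 2 + 1) = 20 scalar parameters in total. A quick count (not part of the failing assertion, just a sanity check):

    # expected: 12 tracked parameter arrays, 20 scalar parameters in total
    println(length(pars))        # number of tracked arrays
    println(sum(length, pars))   # total number of scalars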
Red-Portal commented 2 years ago

It seems like the gradient of `flow` contains a non-gradient entry: besides the parameter arrays, `grads.grads` also holds a `RefValue` of the flow itself, whose value is `(x = nothing,)` rather than a gradient array.

IdDict{Any, Any}([0.643964081203799, -0.33972483245547835] => [-1541.1739416233916, 1191.157541019066], [0.14540928956058152, -0.08136481418860046] => [-3716.620948764811, -1141.2559573811584], [1.075215841523743] => [-10.879948609605961], [0.8546879143699161, -1.3575339819843644] => [350.3771058739833, 477.5608533609593], [0.044090629105113646, -0.6092739111255705] => [-4521.422906453032, 50.39511843818664], [0.7106450315871697, -1.2411071500031408] => [694.9667780305496, 193.06446353385192], [2.123618880999641] => [100.67083715804323], [-0.2502735230749665, -0.9787259534387902] => [-2904.2085561798467, -1326.8258198410397], [0.4789010235348985, 0.4631552734028355] => [15.338401162918993, -3689.760938565437], [-0.27106510950721024] => [158.51721073158237], [0.3718628458807646, 0.26452126966080663] => [-3217.2481365601184, 322.2996951793478], [-0.753452180737133] => [10.02347485306564], Base.RefValue{MultivariateTransformed{DiagNormal, Composed{NTuple{4, PlanarLayer{Vector{Float64}, Vector{Float64}}}, 1}}}(MultivariateTransformed{DiagNormal, Composed{NTuple{4, PlanarLayer{Vector{Float64}, Vector{Float64}}}, 1}}(
dist: DiagNormal(
dim: 2
μ: [0.0, 0.0]
Σ: [1.0 0.0; 0.0 1.0]
)

transform: Composed{NTuple{4, PlanarLayer{Vector{Float64}, Vector{Float64}}}, 1}((PlanarLayer{Vector{Float64}, Vector{Float64}}([0.044090629105113646, -0.6092739111255705], [0.7106450315871697, -1.2411071500031408], [2.123618880999641]), PlanarLayer{Vector{Float64}, Vector{Float64}}([0.643964081203799, -0.33972483245547835], [-0.2502735230749665, -0.9787259534387902], [-0.27106510950721024]), PlanarLayer{Vector{Float64}, Vector{Float64}}([0.8546879143699161, -1.3575339819843644], [0.14540928956058152, -0.08136481418860046], [1.075215841523743]), PlanarLayer{Vector{Float64}, Vector{Float64}}([0.4789010235348985, 0.4631552734028355], [0.3718628458807646, 0.26452126966080663], [-0.753452180737133])))
)
) => Base.RefValue{Any}((x = nothing,)))
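
For what it's worth, restricting the count to the tracked parameters skips that spurious entry (the numbers are specific to this example):

    # grads.grads has one extra non-array key (the RefValue of the flow)
    println(length(grads.grads))     # 13 entries
    println(length(grads.params))    # 12 parameter arrays
    # summing over the tracked parameters matches veclength(pars)
    println(sum(length, (grads[p] for p in grads.params)))   # 20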

A quick fix would be to set

veclength(grads::Zygote.Grads) = sum(length, grads.params)

instead of the current

veclength(grads::Zygote.Grads) = sum(typeof(g[1]) !== GlobalRef ? length(g[1]) : 0 for g in grads.grads)

but I'm not sure whether this would be a robust solution.
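
One thing counting grads.params doesn't address on its own: if a tracked parameter ends up with no gradient, any code that later flattens the gradients into a vector of length veclength(grads) still has to handle the missing entry. A rough, untested sketch of what that could look like (flatgrad is just an illustrative name, not something in the package):

    # Flatten a Zygote.Grads into a vector of length sum(length, grads.params),
    # writing zeros for parameters that received no gradient.
    function flatgrad(grads::Zygote.Grads)
        v = zeros(sum(length, grads.params))
        i = 1
        for p in grads.params
            g = get(grads.grads, p, nothing)  # nothing if p got no gradient
            n = length(p)
            v[i:i+n-1] .= g === nothing ? 0 : vec(g)
            i += n
        end
        return v
    end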