baggepinnen / FluxOptTools.jl

Use Optim to train Flux models and visualize loss landscapes
MIT License

AssertionError: `length(v) == veclength(grads)` #10

Closed by pakk-minidose 2 years ago

pakk-minidose commented 2 years ago

Hello, I have encountered the following issue. When running the code below

using Optim, Flux, Zygote, FluxOptTools

m = Chain(Dense(100,10), Dense(10,1))  # simple two-layer model
X = rand(100,1000)                     # 1000 samples, 100 features each
y = rand(1000)'                        # 1×1000 row vector of targets

λ = 1e-6                               # L2 regularization strength

# MSE loss with an L2 penalty over all model parameters
function loss()
    mse_loss = Flux.Losses.mse(m(X),y)
    ps = params(m)
    sum_weights = sum(sum(abs2,p) for p in ps)
    return mse_loss + λ*sum_weights
end

Zygote.refresh()
pars = Flux.params(m)
lossfun, gradfun, fg!, p0 = optfuns(loss, pars)

res = Optim.optimize(Optim.only_fg!(fg!), p0, BFGS(), Optim.Options(iterations=10))

the last line produces `AssertionError: length(v) == veclength(grads)`.

I have tracked the issue down to some non-parameter entries appearing in the gradients:

julia> grad = Zygote.gradient(loss,pars)
Grads(...)

julia> grad.grads
IdDict{Any, Any} with 10 entries:
  :(Main.y)                                                                    => [0.00140496 0.00109027 … 0.000889871 0.00305764]
  IdDict{Any, Nothing}(Chain(Dense(100, 10), Dense(10, 1))=>nothing, (Chain(D… => IdDict{Any, Any}()
  :(Main.X)                                                                    => [8.04866e-5 6.24584e-5 … 5.09783e-5 0.000175164; 2.74849e-5 2.13286e-5 … 1.74083e-5 5.98158e-5; … ; -8.38364e-5 -6.50579e-5 … -5.31e-5 -0.000182454; …
  :(Main.λ)                                                                    => 19.4517
  Buffer{Any, Vector{Any}}(Any[Float32[-0.11995 0.168377 … -0.205164 -0.00029… => Any[]
  Float32[0.657628 -0.368154 … -0.144815 0.356743]                             => [-0.795976 -0.294902 … -0.406905 1.27087]
  Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]                    => Float32[-1.28669, 0.720318, -0.181089, -0.832897, 0.95751, -1.02594, -0.55508, -0.778771, 0.28334, -0.697992]
  Float32[-0.11995 0.168377 … -0.205164 -0.000299127; -0.019391 -0.063409 … 0… => [-0.632358 -0.650336 … -0.619511 -0.664975; 0.354007 0.364072 … 0.346815 0.372267; … ; 0.139251 0.14321 … 0.136421 0.146433; -0.343035 -0.352789 … -0…
  Float32[0.0]                                                                 => Float32[-1.95657]
  IdDict{Any, Nothing}(Float32[0.657628 -0.368154 … -0.144815 0.356743]=>noth… => IdDict{Any, Any}()

There are two IdDicts and one Buffer among the keys; they are not of GlobalRef type, yet they do not correspond to the model parameters either. These extra entries trigger the assertion at https://github.com/baggepinnen/FluxOptTools.jl/blob/2a1d2210057efce7b380308a4ad2c4a11b04c174/src/FluxOptTools.jl#L17, because the flattened gradient no longer has the same total length as the parameter vector.

In this case the gradient values for these extra keys happen to be empty, but emptiness may not be a reliable general rule for detecting this sort of entry.
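A likely source of these extra entries is that `params(m)` is called inside the differentiated closure, so Zygote also tracks the internal IdDicts and Buffer it uses to build the `Params` object. As a possible workaround (a sketch, not verified against this package), the parameter collection can be built once, outside the loss:

```julia
using Flux, Zygote

m = Chain(Dense(100,10), Dense(10,1))
X = rand(100,1000)
y = rand(1000)'
λ = 1e-6

ps = Flux.params(m)  # built once, outside the differentiated code

function loss()
    mse_loss = Flux.Losses.mse(m(X), y)
    sum_weights = sum(sum(abs2, p) for p in ps)  # uses the captured ps
    return mse_loss + λ*sum_weights
end
```

With this version, `Zygote.gradient(loss, ps)` should no longer need to differentiate through the construction of `Params` itself.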

The issue seems closely related to https://github.com/baggepinnen/FluxOptTools.jl/issues/9. I think we need a function that cherry-picks the keys of the grad.grads dictionary corresponding to the model parameters, which would solve both issues.
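A minimal sketch of such a filter (the helper name `filter_grads` is hypothetical; it assumes Zygote's `Grads` struct with its `grads::IdDict` and `params` fields, which is the layout shown in the dump above):

```julia
using Zygote

# Hypothetical helper: keep only the entries of grads.grads whose keys
# are actual tracked parameters; drop GlobalRefs, IdDicts, Buffers, etc.
function filter_grads(grads::Zygote.Grads, pars::Zygote.Params)
    kept = IdDict{Any,Any}()
    for p in pars
        haskey(grads.grads, p) && (kept[p] = grads.grads[p])
    end
    return Zygote.Grads(kept, pars)
end

# Sketched usage: filtered = filter_grads(Zygote.gradient(loss, pars), pars)
```

Iterating over `pars` rather than over `grads.grads` keeps the result in parameter order, which is what the vectorization in FluxOptTools presumably relies on.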

Thank you for any help.

baggepinnen commented 2 years ago

Hello and thank you for your issue!

I have unfortunately not used Flux or Zygote for quite a while and will likely have a hard time finding time to fix issues with this package at the moment. If you manage to solve it, feel free to submit a PR and I'll merge it; otherwise it might take a while.

Best regards, Fredrik

pakk-minidose commented 2 years ago

Hello, thank you for your reply. I will try to find some time to fix the issue.

Best regards, Dominik