FluxML / Optimisers.jl

Optimisers.jl defines many standard optimisers and utilities for learning loops.
https://fluxml.ai/Optimisers.jl
MIT License

"Optimisers.jl does not at present handle tied weights, sorry." #97

Closed · gdalle closed this issue 1 year ago

gdalle commented 2 years ago

Hi, and thanks for the hard work on Flux! Do we have an idea what is causing this error?

mcabbott commented 2 years ago

One way is #42, which you can try via ] add https://github.com/mcabbott/Optimisers.jl#duplicated -- would be interesting to see whether it works well in the wild.

Brian had an idea for a less complicated implementation, so #42 may or may not actually get merged.

darsnack commented 2 years ago

Do we have an idea what is causing this error?

In case this was a question about what the error means: you have a model in which two parameters share the same underlying array (e.g. auto-encoders with tied encoder/decoder weights). We currently don't have a reliable way of accumulating the gradients in all of these cases, so we error instead.
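
To make that concrete, here is a minimal sketch (not from the thread) of the kind of model the check rejects, using Flux's Dense(W) constructor that takes an explicit weight matrix. On the Optimisers.jl versions discussed here, setting up the optimiser state over such a model fails roughly like this:

julia> using Flux, Optimisers

julia> shared = rand(Float32, 3, 3);

julia> model = Chain(Dense(shared), Dense(shared));  # one weight array, two layers

julia> Optimisers.setup(Optimisers.Adam(), model)
ERROR: ArgumentError: Optimisers.jl does not at present handle tied weights, sorry.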

gdalle commented 2 years ago

Thanks for the answers! I actually understood the error message, but since tied weights obviously work in Flux.jl itself, I was simply wondering why they didn't work with this library. In my case, the error was triggered by a ReshapeLayer from Lux.jl, which is pretty annoying for such a simple block. But as always, the solution is far more complicated to implement than the problem is to state.

ToucheSir commented 2 years ago

ReshapeLayer should not be causing an error like this, since it reshapes an output on the fly (as opposed to reshaping a parameter up front, which is what this check is supposed to catch). Do you have an MWE for that?
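
For context (a sketch, not from the thread): reshape in Julia returns an array backed by the same memory as its argument, so a model that kept both the flat and the reshaped version of a parameter as separate leaves would effectively contain tied weights:

julia> W = rand(9);

julia> R = reshape(W, 3, 3);   # new shape, same underlying memory

julia> R[1, 1] = 0.0; W[1]     # mutating R is visible through W
0.0

Reshaping an activation, by contrast, never introduces this kind of parameter aliasing.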

gdalle commented 2 years ago

Here's an MWE on Julia 1.7.3. It calls the solvers through Optimization.jl because that's what I was doing when I discovered the bug.

pkg> st
      Status `~/.../Project.toml`
  [b0b7db55] ComponentArrays v0.12.2
  [f6369f11] ForwardDiff v0.10.30
  [b2108857] Lux v0.4.7
  [7f7a1694] Optimization v3.7.1
  [253f991c] OptimizationFlux v0.1.0
  [42dfb2eb] OptimizationOptimisers v0.1.0
  [e88e6eb3] Zygote v0.6.41
  [9a3f8284] Random

julia> using ComponentArrays, ForwardDiff, Lux, Optimization, Random, Zygote

julia> import OptimizationFlux, OptimizationOptimisers

julia> rng = Random.default_rng(); Random.seed!(rng, 0);

julia> model = Chain(Dense(1, 9), ReshapeLayer((3, 3)), sum);

julia> ps, st = Lux.setup(rng, model);

julia> ps = ComponentArray(ps);

julia> x = ones(1, 1);

julia> f = OptimizationFunction((p, _) -> model(x, p, st), Optimization.AutoForwardDiff());

julia> prob = OptimizationProblem(f, ps, (;));

julia> solve(prob, OptimizationFlux.Adam(), maxiters=10);

julia> solve(prob, OptimizationOptimisers.Adam(), maxiters=10);
ERROR: ArgumentError: Optimisers.jl does not at present handle tied weights, sorry.

I just realized that simply updating Lux.jl to v0.4.8 fixes the issue, so it may be a bug on the Lux.jl side of things.
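
Concretely, the upgrade amounts to something like this (a sketch, not the exact session):

pkg> add Lux@0.4.8

julia> solve(prob, OptimizationOptimisers.Adam(), maxiters=10);  # after re-running the setup above, no ArgumentError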

ToucheSir commented 2 years ago

It may be https://github.com/avik-pal/Lux.jl/pull/74; I remember something similar from the issue which spawned that PR.

ToucheSir commented 2 years ago

And see #100 for the aforementioned alternate spin on handling tied weights.