One way is #42, which you can try via `] add https://github.com/mcabbott/Optimisers.jl#duplicated`; it would be interesting to see whether it works well in the wild.
Brian had an idea for a less complicated implementation, so #42 may or may not actually get merged.
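For anyone who wants to try it, here is a hedged sketch of what a test in the wild might look like, assuming the #duplicated branch accumulates gradients for shared arrays instead of erroring (the model, sizes, and loss here are made up for the example):

```julia
using Flux, Optimisers, Zygote

# A tied model: the same Dense layer, and hence the same parameter
# arrays, appears twice in the Chain.
shared = Dense(2, 2)
model = Chain(shared, shared)

# On released versions this throws the tied-weights ArgumentError;
# on the branch it should succeed.
state = Optimisers.setup(Optimisers.Adam(0.01), model)

# One update step; the branch is expected to accumulate the gradient
# contributions from both occurrences of `shared`.
grads = Zygote.gradient(m -> sum(m(ones(Float32, 2))), model)[1]
state, model = Optimisers.update(state, model, grads)
```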
Do we have an idea what is causing this error?
In case this was a question about what the error means: you have a model where two parameters share the same underlying array (e.g. an auto-encoder with tied weights). We currently don't have a reliable way of accumulating the gradients in all of these cases, so we error instead.
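Concretely, a minimal sketch of the kind of sharing meant here (the layer sizes are made up for the example):

```julia
using Flux

# An auto-encoder with tied weights: the decoder reuses the encoder's
# weight matrix (transposed), so two layers share one underlying array.
W = randn(Float32, 4, 8)
encoder = Dense(W, zeros(Float32, 4))   # maps 8 -> 4
decoder = Dense(W', zeros(Float32, 8))  # maps 4 -> 8; W' aliases W
autoencoder = Chain(encoder, decoder)

# Both layers' gradients must be accumulated into the single array W;
# it is this accumulation that Optimisers.jl refuses to attempt here.
```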
Thanks for the answers! I actually understood the error message, but since tied weights obviously work in Flux.jl itself, I was simply wondering why they didn't work with this library.
In my case, the error was generated by a `ReshapeLayer` from Lux.jl, which is pretty annoying for such a simple block. But as always, the solution is way more complicated to implement than the problem is to state.
`ReshapeLayer` should not be causing an error like this, since it reshapes an output on the fly (as opposed to a parameter offline, which is what this check is supposed to catch). Do you have an MWE for that?
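To illustrate the point, a quick sketch (the reshape size is arbitrary): `ReshapeLayer` carries no trainable parameters of its own, so by itself it has nothing for the tied-weight check to flag.

```julia
using Lux, Random

# ReshapeLayer holds no trainable arrays, so on its own there is
# nothing that tied-weight detection could complain about.
rng = Random.default_rng()
ps, st = Lux.setup(rng, ReshapeLayer((3, 3)))
@show ps  # prints ps = NamedTuple(): empty, no parameters
```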
Here's an MWE on Julia 1.7.3. It calls the solvers through Optimization.jl because that's what I was doing when I discovered the bug.
```julia-repl
pkg> st
Status `~/.../Project.toml`
  [b0b7db55] ComponentArrays v0.12.2
  [f6369f11] ForwardDiff v0.10.30
  [b2108857] Lux v0.4.7
  [7f7a1694] Optimization v3.7.1
  [253f991c] OptimizationFlux v0.1.0
  [42dfb2eb] OptimizationOptimisers v0.1.0
  [e88e6eb3] Zygote v0.6.41
  [9a3f8284] Random

julia> using ComponentArrays, ForwardDiff, Lux, Optimization, Random, Zygote

julia> import OptimizationFlux, OptimizationOptimisers

julia> rng = Random.default_rng(); Random.seed!(rng, 0);

julia> model = Chain(Dense(1, 9), ReshapeLayer((3, 3)), sum);

julia> ps, st = Lux.setup(rng, model);

julia> ps = ComponentArray(ps);

julia> x = ones(1, 1);

julia> f = OptimizationFunction((p, _) -> model(x, p, st), Optimization.AutoForwardDiff());

julia> prob = OptimizationProblem(f, ps, (;));

julia> solve(prob, OptimizationFlux.Adam(), maxiters=10);

julia> solve(prob, OptimizationOptimisers.Adam(), maxiters=10);
ERROR: ArgumentError: Optimisers.jl does not at present handle tied weights, sorry.
```
I just realized that simply updating Lux.jl to v0.4.8 fixes the issue, so it may be a bug on the Lux.jl side of things.
It may be https://github.com/avik-pal/Lux.jl/pull/74; I remember something similar from the issue which spawned that PR.
And see #100 for the aforementioned alternate spin on handling tied weights.