FluxML / Optimisers.jl

Optimisers.jl defines many standard optimisers and utilities for learning loops.
https://fluxml.ai/Optimisers.jl
MIT License

Transparent handling of tied weights #100

Closed ToucheSir closed 2 years ago

ToucheSir commented 2 years ago

This makes Leaf a mutable type so that tied weights are represented by the same leaf instance.

Although only mutable array types are automatically detected as tied, one can also tie immutable parameters by manually creating shared Leafs.

The test suite is practically the same as #42, with some slight modifications since there is no equivalent to Tied in this PR.
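As a minimal sketch of the behaviour described above (not code from this PR), the snippet below ties two fields of a model to the same array and checks that `Optimisers.setup` hands back one shared `Leaf` for both. The model layout and the field names `enc` and `dec` are hypothetical, and the result assumes an Optimisers.jl version that includes this change.

```julia
using Optimisers

# One mutable array used in two places: the weights are tied by identity (===).
w = [1.0, 2.0]
model = (enc = w, dec = w)   # hypothetical field names, for illustration only

state = Optimisers.setup(Momentum(), model)

# With a mutable Leaf and identity-based detection, both entries of the
# state tree should be the very same Leaf instance.
tied = state.enc === state.dec
println(tied)
```

Parameters which are equal in value but not `===` (for example `copy(w)`) would not be detected as tied this way.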

ToucheSir commented 2 years ago

Doctests appear to be picking up changes on master that aren't present on this branch. Is that expected? I can't tweak the test because it doesn't exist here!

mcabbott commented 2 years ago

Here's an evil case for shared parameters:

using Functors, Optimisers

mutable struct MutTwo; x; y; end
Functors.@functor MutTwo

tmp = MutTwo([1.0], [2.0])
model = (a=tmp, b=tmp, c=MutTwo(tmp.x, tmp.y))
state = Optimisers.setup(Momentum(), model)

model.a === model.b
model.a !== model.c  # fields are identified, but struct is not

state.a.x === state.b.x
state.a === state.b
state.a === state.c  # unavoidable, but means we can't use Leaf ID alone?

mgrad = (a=(x=[1.], y=[10.]), b=(x=[100], y=[1000]), c=(x=[1/3], y=[1/30]))
state2, model2 = Optimisers.update(state, model, mgrad)
model2.a === model2.b
model2.a !== model2.c 

The state of all 3 components is (x=Leaf(...), y=Leaf(...)). A cache which is IdDict{Leaf,Any} can't identify the two structs. A cache which also stores higher levels of the state tree will instead identify all three structs.
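The reason the three state entries coincide can be sketched in plain Julia, without Optimisers.jl. `DemoLeaf` and `DemoTwo` below are hypothetical stand-ins for `Leaf` and `MutTwo`. The sketch assumes that the state for each struct is an immutable NamedTuple of leaves, as shown above.

```julia
# Hypothetical stand-in for Leaf: mutable, so instances are compared by identity.
mutable struct DemoLeaf
    state
end

# Hypothetical stand-in for MutTwo.
mutable struct DemoTwo
    x
    y
end

lx, ly = DemoLeaf([0.0]), DemoLeaf([0.0])

# Two NamedTuples built separately from the same leaves are egal, because
# immutable values are compared field by field with ===.
sa = (x = lx, y = ly)
sc = (x = lx, y = ly)
same_state = sa === sc

# Two mutable structs holding the same fields are still distinct objects.
ma = DemoTwo(lx, ly)
mc = DemoTwo(lx, ly)
same_struct = ma === mc

# An identity-keyed cache therefore sees one key for sa and sc.
cache = IdDict{Any,Any}()
cache[sa] = :seen
hit = haskey(cache, sc)

println((same_state, same_struct, hit))
```

So a cache keyed on state subtrees cannot tell `a` from `c`, while a cache keyed only on leaves says nothing about which containing structs are shared.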

One answer here is to store tuples (x, Leaf(...)). Identifying the Leafs can then be used as a trick to tie StaticArray parameters, but it cannot tie Array parameters (which aren't already tied by ===).

ToucheSir commented 2 years ago

I don't think we ever guaranteed model.a !== model.c => state.a !== state.c. model2.a !== model2.c after an update seems more like a bug than an intrinsic limitation? If x and y are === for all 3 components, then this should just work. I had to read the comments in #106 for context. We weren't preserving identity at higher levels before, so I think that is orthogonal to the issue of tied leaves. It would be nice if we could, though, and I see that has been done there.