FluxML / Optimisers.jl

Optimisers.jl defines many standard optimisers and utilities for learning loops.
https://fluxml.ai/Optimisers.jl
MIT License

Utility for walking a tree (e.g. gradients) w.r.t. a model #143

Open darsnack opened 1 year ago

darsnack commented 1 year ago

Motivation and description

Using trainable, we can walk a model and only apply a function to trainable parameters. But the gradient from Zygote is a named tuple without this information.

Normally, for optimizers this is fine, because our function is applied at every leaf, so we only need a single pass over the model. But it is fairly common to walk the entire tree of gradients first to compute something (e.g. a global norm term). In this case, we need a pass over the gradient outside of the update context.
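For concreteness, a minimal sketch of the gap (the layer, loss, and values here are made up purely for illustration): the gradient NamedTuple carries no record of which fields are trainable, so a naive walk over it also counts entries that update! would ignore.

using Functors, Optimisers, Zygote

struct Scale                     # toy layer with a non-trainable field
    w::Vector{Float64}
    frozen::Vector{Float64}
end
Functors.@functor Scale
Optimisers.trainable(s::Scale) = (; w = s.w)   # only w should be optimised

m = Scale([1.0, 2.0], [3.0])
g = Zygote.gradient(m -> sum(m.w) + sum(m.frozen), m)[1]
# g == (w = [1.0, 1.0], frozen = [1.0]) -- no trainability information here

total = Ref(0.0)                 # naive global norm over the whole gradient
fmap(g) do dx
    dx isa AbstractArray && (total[] += sum(abs2, dx))
    dx
end
sqrt(total[])                    # wrongly includes the frozen field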

Possible Implementation

We can include a maptrainable(f, model, [gradient]) function (or one with a better name) that maps f over the trainable parameters of model.
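A rough sketch of one possible shape, just to make the proposal concrete (the name, signature, and recursion details are all open here, and shared/tied parameters are ignored):

using Functors, Optimisers

# Hypothetical: recurse through model, descending only into the children that
# Optimisers.trainable reports, and call f on each leaf together with the
# matching gradient entry (which may be nothing).
function maptrainable(f, model, grad)
    Functors.isleaf(model) && return f(model, grad)
    for (key, child) in pairs(Optimisers.trainable(model))
        maptrainable(f, child, grad === nothing ? nothing : grad[key])
    end
    return nothing   # foreach-style traversal; no tree is rebuilt
end

# With m, g as in the toy example above, a global norm over trainable gradients:
total = Ref(0.0)
maptrainable(m, g) do x, dx
    dx isa AbstractArray && (total[] += sum(abs2, dx))
end
sqrt(total[])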

darsnack commented 1 year ago

Ideally, I think the implementation would underlie update (i.e. update is maptrainable with f specialized to call apply).

ToucheSir commented 1 year ago

Related: https://github.com/FluxML/Optimisers.jl/pull/57. We have proposals for map and reduce, but does it make sense to try for a mapreduce?

darsnack commented 1 year ago

Agreed, with the ability to pass in additional trees for f to be called on, as described above.

mcabbott commented 1 year ago

If another tree like gradient is passed, then f is applied to the leaves of gradient (i.e. approximately fmap(TrainableWalk(f), gradient, model) using the last argument to filter the walk).

I think the most obvious t_mapreduce(f, r, model, grads) would always call f(x, dx), but take trainability from the model. The present fmap(f, xs, ys) always calls f(x,y):

julia> fmap(println, (a=1, b=2), (a="!", b="?"))
1!
2?
(a = nothing, b = nothing)

julia> sh = [1.0]; fmap(println, (a=sh, b=sh), (a="!", b="?"))
[1.0]!
(a = nothing, b = nothing)

The tricky bit, as usual, will be shared parameters. Here fmap simply ignores the y belonging to a shared x. This fmap(f, xs, ys) is a half-baked feature; I think update! was the original target, but it's not actually right for that.

The walk done by Optimisers.update! instead adds up the distinct dx belonging to a shared x before calling apply!. I wonder how often that would be correct; e.g. for the gradient norm example it probably would be. To write update! (ignoring its return) you would need t_mapreduce(f, Returns(nothing), model, grads, state_tree), where we add dx but not state?

julia> Optimisers.setup(Momentum(), (a=sh, b=sh))
(a = Leaf(Momentum{Float32}(0.01, 0.9), [0.0]), b = Leaf(Momentum{Float32}(0.01, 0.9), [0.0]))

julia> ans.a === ans.b
true

All of this always feels like we have slightly the wrong abstractions.

ericphanson commented 1 year ago

For the norm use-case, another thing that would be handy is if I could destructure the gradient to flatten it, but only keep the trainable params as governed by the model. Then I can just take a norm directly on the flat vector.

Or maybe a more composable thing would be if I could walk the model & gradient simultaneously, and map non-trainable gradients to nothing, returning an updated gradient that only has non-nothing entries for trainable params. Then I could do whatever I wanted with that (walk it again with fmap, flatten it with destructure, etc.).
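A rough sketch of that second idea, under the assumption that tied/shared arrays can be ignored (the name and recursion details are hypothetical):

using Functors, Optimisers

# Hypothetical: return a copy of grad with nothing wherever the corresponding
# part of model is not listed by Optimisers.trainable.
function mask_nontrainable(model, grad)
    (Functors.isleaf(model) || grad === nothing) && return grad
    children, _ = Functors.functor(model)
    keep = keys(Optimisers.trainable(model))
    vals = map(keys(children)) do k
        k in keep ? mask_nontrainable(children[k], grad[k]) : nothing
    end
    return children isa Tuple ? Tuple(vals) : NamedTuple{keys(children)}(Tuple(vals))
end

The result is still an ordinary gradient tree, so it can then be walked again with fmap, flattened, or fed into a norm calculation that should only see trainable entries.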

ToucheSir commented 1 year ago

A simpler version of this came up in conversation over RL models on Slack today. The current incantation for updating one model's parameters based on the moving average of another model's is:

for (t, p) in zip(Flux.params(target), Flux.params(policy))
    t .= (1 - tau) .* t .+ tau .* p
end

To which I proposed:

Functors.fmap(m_target, m_policy; walk = Optimisers._Trainable_biwalk()) do t, p
  t .= (1 - tau) .* t .+ tau .* p
end

It should take no time to package up the latter as a mapparams function on our side. The questions are where it should live (Flux or Optimisers) and what it should be called (e.g. maptrainable instead)?
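For reference, the packaging itself would be roughly a one-line wrapper around the incantation above (name and home still to be decided; _Trainable_biwalk is an internal of Optimisers.jl):

# Hypothetical wrapper, whatever it ends up being called:
mapparams(f, m, ms...) = Functors.fmap(f, m, ms...; walk = Optimisers._Trainable_biwalk())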