FluxML / Optimisers.jl

Optimisers.jl defines many standard optimisers and utilities for learning loops.
https://fluxml.ai/Optimisers.jl
MIT License

WeightDecay for L1 norm #159

Closed: mcabbott closed this 7 months ago

mcabbott commented 1 year ago

As I learned at https://github.com/FluxML/MLJFlux.jl/issues/221#issuecomment-1707604760, the gradient of the L1 norm is even simpler than that of the L2 norm, so it too can obviously be implemented as an optimisation rule.
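(As an aside, an illustrative sketch of the identity being relied on here, not code from the PR: the gradient of the L2 penalty is proportional to x itself, which is exactly what WeightDecay already adds to dx, while the gradient of the L1 penalty is proportional to sign.(x), which is just as cheap.)

```julia
# Illustrative only: the terms each decay rule adds to the gradient dx.
x = [1.0, -2.0, 3.0]
λ = 0.42

grad_l2 = λ .* x          # gradient of λ/2 * sum(abs2, x), the existing WeightDecay term
grad_l1 = λ .* sign.(x)   # gradient of λ * sum(abs, x), the L1 term this PR adds
```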

This quick PR adds it to the same WeightDecay struct. Below is a check that this does what you expect.

```julia
using Flux: Flux, Dense, gradient, state
using Optimisers
using Optimisers: setup, update

input = [1, 2]
model = Dense([1 -2; 3 -4.0])

grads = Flux.gradient(model) do m
    result = m(input)
    sum(result)
end

# Check L2 norm via WeightDecay (nothing new!)
pen_l2(x::AbstractArray) = sum(abs2, x)/2

grads_L2 = Flux.gradient(model) do m
    result = m(input)
    penalty = sum(pen_l2, Flux.params(m))
    sum(result) + 0.42 * penalty
end

update(setup(Descent(0.1), model), model, grads_L2[1])[2] |> Flux.state
update(setup(OptimiserChain(WeightDecay(0.42), Descent(0.1)), model), model, grads[1])[2] |> Flux.state

# Do exactly the same thing for L1 (needs this PR)
pen_l1(x::AbstractArray) = sum(abs, x)

grads_L1 = Flux.gradient(model) do m
    result = m(input)
    penalty = sum(pen_l1, Flux.params(m))
    sum(result) + 0.42 * penalty
end

update(setup(Descent(0.1), model), model, grads_L1[1])[2] |> Flux.state
update(setup(OptimiserChain(WeightDecay(0.0, 0.42), Descent(0.1)), model), model, grads[1])[2] |> Flux.state

# Both give (weight = [0.858 -2.158; 2.858 -4.158], bias = [-0.1, -0.1], σ = ())
```


darsnack commented 1 year ago

An alternative API is to add SignedDecay (or something) if we find WeightDecay(0.0, 0.004) too weird.

ToucheSir commented 1 year ago

I thought about that too, but this seems more straightforward if one wants to combine L1 and L2. We don't currently have a Parallel-esque rule which feeds the same gradient into two different rules, though now that I say it, such a composite rule could be a nice addition.

mcabbott commented 1 year ago

Yes, I wondered about an independent rule, but then thought precisely that: you may want a bit of L1 and a bit of L2. And also, perhaps, that if you know about this trick for L2, then this proximity may help you discover the similar trick for L1.

I gave it the next unused Greek letter. It's sort-of neat that each different rule you may wish to chain uses a different field name, so that adjust!(..., zeta=0.1) etc. never modifies two unrelated things.

darsnack commented 1 year ago

For what it's worth, I'm okay with a single rule. But just to push the other side a bit more: you don't need a Parallel-esque construct for these rules to compose. OptimiserChain(WeightDecay(0.004), SignedDecay(0.004), Descent(0.1)) works just fine (since each decay depends on x, not dx).
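(A hedged sketch of that composition, using the SignDecay name the PR eventually settles on at f70aa9c in place of the hypothetical SignedDecay; it assumes a version of Optimisers.jl that includes this PR.)

```julia
using Optimisers

# Each decay rule only adds a term computed from x; neither term depends on
# the incoming gradient, so the order of the two decays does not matter.
opt = OptimiserChain(
    WeightDecay(0.004),   # adds 0.004 * x       (L2 penalty gradient)
    SignDecay(0.004),     # adds 0.004 * sign(x) (L1 penalty gradient)
    Descent(0.1),         # then scales by the learning rate
)

state = Optimisers.setup(opt, [1.0, -2.0, 3.0])
```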

ToucheSir commented 1 year ago

Ah you're right, I got my wires crossed there.

FWIW, the AdamW paper uses λ for the weight decay term, which PyTorch borrows for its optimizer documentation but does not use in any API.

darsnack commented 1 year ago

Another option is to have SignedDecay(zeta) = WeightDecay(0, zeta). I'm okay with all options, just throwing things out for consideration.

ablaom commented 1 year ago

Thanks for considering this contribution @mcabbott.

Another convention, adopted in elastic net and elsewhere in statistics, is to have an overall lambda parameter and an L1/L2 mixture parameter alpha. This is what we do in MLJFlux.

https://github.com/FluxML/MLJFlux.jl/blob/b449d80d1d5606298bae0ded1992ee35c5c099c0/src/penalizers.jl#L11

But I don't have a strong opinion.
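(For concreteness, a rough sketch of what that parameterisation means for the gradient contribution; the `elastic_decay` helper is purely illustrative, and whether the L2 part carries a factor of 1 or 1/2 is exactly the convention question raised further down.)

```julia
# Sketch only, not this PR's API: elastic-net style parameterisation with one
# overall strength `lambda` and a mixing fraction `alpha` between L1 and L2.
elastic_decay(x; lambda, alpha) = lambda .* (alpha .* sign.(x) .+ (1 - alpha) .* x)

elastic_decay([1.0, -2.0, 3.0]; lambda = 0.1, alpha = 0.5)
```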

mcabbott commented 1 year ago

Ah that is a nice idea.

It sounds like lambda is more standard. I don't know where we got gamma; possibly I just invented something other than Flux's .wd:

https://github.com/FluxML/Flux.jl/blob/95737ffc9aa989f31d5fecd9a887a9c25f4fd865/src/optimise/optimisers.jl#L690-L692

It only matters because of adjust!, but I guess we can add a deprecation.

ablaom commented 1 year ago

Yes, but I have also seen the roles of lambda and alpha reversed :-(

mcabbott commented 1 year ago

I wish I was surprised...

Now changed to lambda and alpha. This seems fairly natural to have as one struct, not two.

Not easily accessible from Flux, but shouldn't break anything:

julia> Flux.setup(Flux.WeightDecay(0.1), [1,2.0]) |> dump
Optimisers.Leaf{WeightDecay, Nothing}
  rule: WeightDecay
    lambda: Float64 0.1
    alpha: Float64 0.0
  state: Nothing nothing
  frozen: Bool false
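(A hedged usage sketch: adjust! matches keywords against rule field names, so tweaking one coefficient should leave everything else alone. This assumes the lambda field shown in the dump above and a recent Optimisers.jl.)

```julia
using Optimisers

# Sketch: the keyword is matched against the rule's field names, so this
# changes only WeightDecay's lambda and touches nothing else in the tree.
st = Optimisers.setup(WeightDecay(0.1), [1.0, 2.0])
Optimisers.adjust!(st, lambda = 0.2)
```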
ablaom commented 7 months ago

@mcabbott Do you have some time to push this along? The project to update MLJFlux to use explicit parameters is waiting on this.

mcabbott commented 7 months ago

I had a go locally & will try to find the branch

mcabbott commented 7 months ago

Ok, dbcea29 pushes what I had locally, way back when. It leaves WeightDecay alone and makes a new struct for the combined L1 and L2 story. I called this NormReg, although perhaps there's a better name.

Is this a good design? We could instead have a new struct which does only L1. And then (if we want to support a mixture) have some function which returns a chain of L1 and L2, using existing structs. Maybe that would be better.

Edit: And f70aa9c changes to a separate SignDecay struct for L1 alone. No function for a combination. Maybe that's the minimal thing.

Maybe they should not have the same field name lambda; any ideas for what might be better?
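(For orientation, a minimal sketch of what a standalone L1 rule could look like under Optimisers.jl's custom-rule interface; the name, field name, and broadcasting details are guesses, not the committed f70aa9c code.)

```julia
using Optimisers
import Optimisers: apply!, init

# Sketch of a SignDecay-style rule: adds lambda * sign(x) to the gradient,
# i.e. the gradient of the L1 penalty lambda * sum(abs, x).
struct SignDecaySketch <: Optimisers.AbstractRule
    lambda::Float64
end

init(::SignDecaySketch, x::AbstractArray) = nothing   # stateless rule

function apply!(o::SignDecaySketch, state, x::AbstractArray, dx)
    return state, dx .+ o.lambda .* sign.(x)
end

# Used like any other rule, e.g. chained before Descent:
st = Optimisers.setup(OptimiserChain(SignDecaySketch(0.42), Descent(0.1)), [1.0, -2.0])
```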

ToucheSir commented 7 months ago

PyTorch wasn't very helpful as inspiration, but optax uses the term "decay rate" in their implementation of weight decay. A little verbose but pretty clear.

Alternatively, the sklearn regression models call this L1/L2 coefficient alpha. The ElasticNet page specifically refers to it as a "penalty (term)", which is another idea for a plain English word.

ablaom commented 7 months ago

The separate SignDecay option, as currently implemented here, would suit me fine. This way I can confidently use the two decays without looking up documentation to sort out the notation and the convention about the factor of 1 or 1/2. (In elastic net I have seen the roles of alpha and lambda reversed in some implementations.)

mcabbott commented 7 months ago

Another argument against having a mixture parameter: In most practical use, knowing that λ = 1e-3 is a useful amount of L2 for your problem does not imply that this is the right amount of L1... you are going to have to search. In which case just changing the mixture / angle parameter isn't really better than changing some other κ instead.

Last commits change the name of the L1 penalty coefficient to "kappa", because it's next door and not used in this package yet (hence adjust(st, kappa=0.1) will hit exactly one thing).

ToucheSir commented 7 months ago

If you'll allow me to bikeshed names one more time: I feel like we should not be pulling out greek characters that have not been used in the literature before, even if they are represented as English words instead of the original symbols. Are there no descriptive terms we can use instead of kappa (and maybe lambda too)?

Otherwise LGTM.

ablaom commented 7 months ago

I think they could have the same name. They're both regularization parameters in separate structs. I think lambda (or its unicode equivalent) is pretty standard for a generic reg. param.

In the same vein, eta is used for the learning rate in all the variations of Optimisers' gradient descent.

mcabbott commented 7 months ago

Good point about eta being used everywhere, maybe just re-using lambda is best.

Maybe this is done? CI on julia > 1.6 might be fixed later by https://github.com/FluxML/Optimisers.jl/pull/166

mcabbott commented 7 months ago

If either of you clicks approve I can merge this, and then rebase #160