**Closed** — ablaom closed this issue 1 month ago
Relatedly, it would be nice if the MLJFlux models listed here https://github.com/FluxML/model-zoo#examples-elsewhere could be updated to use latest Flux, and avoid implicit gradients.
Examples of similar upgrades: https://github.com/FluxML/model-zoo/issues?q=is%3Aclosed+label%3Aupdate+explicit
In the end, Flux 0.14 did not drop support for implicit gradients, but 0.15 should.
@pat-alt Would you have any time and interest in addressing this issue?
That actually syncs well with some of my other outstanding issues, and I think I'll have to address this very same thing in CounterfactualExplanations.jl soon. So yes, please feel free to assign this one to me and I'll look at it in the coming weeks 👍
I have added a draft for this, with very minor changes, in #230:
```julia
function train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y)
    opt_state = Flux.setup(optimiser, chain)
    loss = model.loss
    n_batches = length(y)
    training_loss = zero(Float32)
    parameters = Flux.params(chain)
    for i in 1:n_batches
        batch_loss, gs = Flux.withgradient(chain) do m
            yhat = m(X[i])
            pen = penalty(parameters) / n_batches
            loss(yhat, y[i]) + pen
        end
        training_loss += batch_loss
        Flux.update!(opt_state, chain, gs[1])
    end
    return training_loss / n_batches
end
```
Currently, the following test fails:
```
[ Info: regularization has an effect:
[ Info: acceleration = CPU1{Nothing}(nothing)
regularization has an effect (typename(CPU1)): Test Failed at /Users/patrickaltmeyer/code/MLJFlux.jl/test/integration.jl:25
  Expression: !(loss2 ≈ loss3)
   Evaluated: !(0.8354643267207931 ≈ 0.8354643267207931)
```
I'm not quite sure what's happening. @mcabbott, can you spot anything obviously wrong with this?
That's because the regularization term is still using implicit params. Something like https://github.com/FluxML/Flux.jl/issues/2040#issuecomment-1535535892 will be needed for explicit params.
`parameters = Flux.params(chain)` outside the gradient context will only work in the implicit style: changing the explicit local `m` will not change `pen`. (Edit: as ToucheSir says, while I was typing!)
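The capture problem being described can be illustrated without Flux at all. In this plain-Julia sketch (illustrative only, not code from the thread), the "penalty" term is computed from a value captured outside the closure, so it never responds to the closure's argument, which is exactly why the gradient of the penalty with respect to `m` comes out as zero:

```julia
# A toy "model": a single parameter vector.
w = [1.0, 2.0]

# The penalty closes over a snapshot of the parameters, not over the
# closure's argument (this stands in for Flux.params(chain) taken outside).
captured = copy(w)
objective(m) = sum(m) + sum(abs2, captured)

# Perturbing the argument changes the first term only; the "penalty"
# term is constant, so it contributes nothing to the gradient in m.
base = objective(w)            # sum([1,2]) + (1 + 4) = 8
perturbed = objective(w .+ 1.0)  # sum([2,3]) + (1 + 4) = 10
```

The difference `perturbed - base` is exactly the change in `sum(m)`; the captured penalty term is inert.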
What is `penalty`? For L2 it will be better to use `WeightDecay`, like this: http://fluxml.ai/Flux.jl/stable/training/training/#Regularisation
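For context, the `WeightDecay` suggestion works because adding `λ .* w` to the gradient is exactly the contribution that differentiating an L2 penalty `(λ/2) * sum(abs2, w)` would make. A plain-Julia sketch of that equivalence (illustrative; this is not Flux's actual implementation), checked by finite differences:

```julia
λ = 0.1
w = [0.5, -2.0, 3.0]

# The gradient of the L2 penalty (λ/2) * sum(abs2, w) is λ .* w,
# which is precisely the term a weight-decay rule adds to the gradient.
penalty_grad = λ .* w

# Finite-difference check on the first coordinate:
l2(v) = (λ / 2) * sum(abs2, v)
h = 1e-6
fd = (l2(w .+ [h, 0, 0]) - l2(w)) / h
```

So folding the decay into the optimiser step reproduces the same update as adding the L2 term to the loss, without any `params` call inside the gradient.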
Thanks both!
> What is `penalty`? For L2 it will be better to use `WeightDecay`, like this: http://fluxml.ai/Flux.jl/stable/training/training/#Regularisation
Currently, penalty functions are explicitly defined callable objects in MLJFlux (see here). I saw the note on `WeightDecay` in the Flux docs and was wondering if it's worth changing that.

In any case, I can't really get either of the approaches you suggest to work in this particular case, so we may indeed want to rethink the implementation of the penalty functions, for example by using `WeightDecay` instead. It will require a little extra work, but should be doable. @ablaom, what do you think?
> I can't really get either of the approaches you suggest to work in this particular case
Can you elaborate? I'm not sure I understand why/how they wouldn't work.
Sure!
Moving the `params` call inside, as follows,
```julia
function train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y)
    opt_state = Flux.setup(optimiser, chain)
    loss = model.loss
    n_batches = length(y)
    training_loss = zero(Float32)
    for i in 1:n_batches
        batch_loss, gs = Flux.withgradient(chain) do m
            yhat = m(X[i])
            pen = penalty(Flux.params(m)) / n_batches
            loss(yhat, y[i]) + pen
        end
        training_loss += batch_loss
        Flux.update!(opt_state, chain, gs[1])
    end
    return training_loss / n_batches
end
```
the tests just seem to get stuck at some point. I may try committing this anyway, but at least locally on my machine things hang.
Alternatively, using the approach in https://github.com/FluxML/Flux.jl/issues/2040#issuecomment-1535535892 as follows
```julia
function train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y)
    opt_state = Flux.setup(optimiser, chain)
    loss = model.loss
    n_batches = length(y)
    training_loss = zero(Float32)
    for i in 1:n_batches
        batch_loss, gs = Flux.withgradient(chain) do m
            yhat = m(X[i])
            l = loss(yhat, y[i])
            reg = Functors.fmap(penalty, m; exclude=Flux.trainable)
            l + reg / n_batches
        end
        training_loss += batch_loss
        Flux.update!(opt_state, chain, gs[1])
    end
    return training_loss / n_batches
end
```
I get the following error:
```
[ Info: acceleration = CPU1{Nothing}(nothing)
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(5 => 15)  # 90 parameters
│   summary(x) = "5×20 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/n3cOc/src/layers/stateless.jl:60
fit! and dropout (typename(CPU1)): Error During Test at /Users/patrickaltmeyer/code/MLJFlux.jl/test/test_utils.jl:38
  Got exception outside of a @test
  TypeError: non-boolean (NamedTuple{(:layers,), Tuple{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dropout{Float64, Colon, Random.TaskLocalRNG}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}) used in boolean context
```
Perhaps it has to do with the fact that the penalizers aren't `Functors`?
Yeah, I wouldn't try the first version you have there; I was referring to the second one, or @mcabbott's suggestion about moving things to the optimization step.
> I get the following error: ...
Pretty sure that's due to a typo in the original example code snippet. See https://github.com/FluxML/Flux.jl/issues/2040#issuecomment-1660776696
Hmm, in that case I get the following error: `MethodError: no method matching Dense(::Float32, ::Float32, ::typeof(identity))`. Any ideas?
Thanks @pat-alt for this work!
> In any case, I can't really get either of the approaches you suggest to work in this particular case, so we may indeed want to rethink the implementation of the penalty functions, for example by using `WeightDecay` instead. Will require a little extra work, but should be doable. @ablaom what do you think?
`WeightDecay` only provides a mechanism for L2 regularisation, but the current implementation provides for a combination of both L1 regularisation (good for feature selection) and L2 regularisation. It seems a pity to drop support for a feature to accommodate the new explicit syntax.
I don't know what the source of your current issue is.
@pat-alt I don't think your use of `Functors.fmap` is valid here. The `penalty` function takes a tuple of matrices, as returned by `Flux.params(chain)`, and returns a single aggregate number.
Your first suggestion (with `params`) actually works, but is 3600 times slower than the implicit-style code on the `dev` branch, when tested on a small model / dataset.
@ToucheSir To implement mixed L1/L2 penalties (not just L2 ones) I don't really see how to avoid the `params` in the `withgradient` block. (And this is, after all, a suggestion in the Flux documentation: see the second code block here.) Am I to conclude that explicit-Zygote style AD is just no good on this problem?
> To implement mixed L1/L2 penalties (not just L2 ones) I don't really see how to avoid the `params` in the `withgradient` block. (And this is after all a suggestion in the Flux documentation - second code block here). Am I to conclude that explicit-Zygote style AD is just no good on this problem?
It's arguably better, but it requires some helper functionality that isn't currently nicely packaged up in a library. https://github.com/FluxML/Optimisers.jl/pull/57 is one example of how to do this and how we're thinking about packaging it up going forwards, but the problem with general solutions is that they take time. For this work, you may be better served by implementing a similar but more constrained version on top of Functors.jl and Optimisers.jl which only includes as much as MLJFlux needs for regularization. If you do, feel free to ping me for input.
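A constrained version of the helper being described might look like the following plain-Julia sketch (hypothetical: the real `Optimisers.total` also handles shared parameters and trainability, which this ignores, and it would still need to be Zygote-differentiable). It recursively walks a nested model-like structure and sums a per-array penalty over every array it finds:

```julia
# Sum `pen` over every numeric array in a nested structure of
# NamedTuples/Tuples, roughly mimicking what Optimisers.total does.
mytotal(pen, x::AbstractArray{<:Number}) = pen(x)
mytotal(pen, x::Union{Tuple,NamedTuple}) = sum(mytotal(pen, y) for y in x; init = 0.0)
mytotal(pen, x) = 0.0   # leaves with no parameters contribute nothing

# Example: a chain-like nested structure of weights and biases.
model = (layers = ((W = ones(2, 3), b = zeros(2)),
                   (W = ones(1, 2), b = zeros(1))),)

pen(A) = 0.01 * sum(abs2, A) + 0.02 * sum(abs, A)
total_penalty = mytotal(pen, model)
```

The point is that the penalty is computed from the (explicit) model structure itself, not from a `Params` snapshot taken outside the gradient context.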
@ToucheSir Thanks for the prompt response and offer of help.
So, with the apparatus you describe (Functors.jl, etc.), what code replaces the following to avoid the `params` call, working for a generic Flux model, `chain`, and so that differentiating `chain -> penalty` is free of issues?
```julia
# function to return penalty on an array:
f(A) = 0.01*sum(abs2, A) + 0.02*sum(abs, A)

f(ones(2,3))
# 0.18

chain = Chain(Dense(3=>5), Dense(5=>1, relu))
penalty = sum(f.(Flux.params(chain)))
```
Or, if you prefer: how should the regularisation example in the Flux documentation be re-written (without the weight-decay trick, which does not work for the L1 penalty)?
```julia
f(A) = ...
penalty = mytotal(f, chain)
```

Where `mytotal` is a simplified form or direct copy of `Optimisers.total`, as I mentioned earlier.
> ...(without the weight-decay trick, which does not work for L1 penalty)?
Side note, but I remembered looking into this a few months back and coming across https://stackoverflow.com/questions/42704283/l1-l2-regularization-in-pytorch/66630301#66630301, which suggests that L1 could be implemented using a similar trick. Whether that would be compatible with MLJFlux's API I'm not sure, but we could consider adding it to Optimisers.jl.
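The trick referenced there amounts to adding `λ .* sign.(w)` to the gradient, since that is the (sub)gradient of an L1 penalty `λ * sum(abs, w)`, in direct analogy with how `WeightDecay` folds in L2. A plain-Julia sketch of the idea (illustrative only; whether and how this is packaged in Optimisers.jl should be checked against its docs), again verified by finite differences away from the kink at zero:

```julia
λ = 0.05
w = [0.5, -2.0, 0.0, 3.0]

# (Sub)gradient of the L1 penalty λ * sum(abs, w) is λ .* sign.(w),
# so L1 can be folded into the optimiser step just like L2 weight decay.
l1_grad = λ .* sign.(w)

# Finite-difference check on the first coordinate (w[1] ≠ 0):
l1(v) = λ * sum(abs, v)
h = 1e-6
fd = (l1(w .+ [h, 0, 0, 0]) - l1(w)) / h
```

At `w[i] == 0` the subgradient is taken as zero, which is the usual convention for this trick.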
Thanks for the help @ToucheSir. Unfortunately, `Optimisers.total` is not working for me. I've tried some variations on that approach, but without any luck.
I suggest we wait on the WeightDecay extension referenced above and switch to that approach, which is likely more performant anyhow.
It seems the style used here is being deprecated and won't work with Flux 0.14: https://github.com/FluxML/MLJFlux.jl/blob/452c09d7dc13914f4057c661448bc310c9362d3d/src/core.jl#L37
**edit** After discussion below, I suggest we wait on the WeightDecay extension referenced in the discussion, and refactor to use an optimiser-based solution to weight regularisation, which will avoid the current limitations of explicit differentiation outlined in the discussion. Note that this will likely mean the reported `training_loss` must change, as it will no longer include the weight penalty. So this will be breaking.
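An optimiser-based refactor could, in principle, still support mixed L1/L2 penalties. This plain-Julia sketch (hypothetical names, not MLJFlux or Optimisers.jl API) shows both penalties folded into the gradient at the update step instead of being added to the loss, which is the design change proposed above:

```julia
# Toy optimiser-style regularisation: fold L2 and L1 decay terms into
# the gradient before the descent step, instead of adding them to the loss.
struct Decay
    λ2::Float64   # L2 coefficient
    λ1::Float64   # L1 coefficient
end

# L2 contributes λ2 .* w, L1 contributes λ1 .* sign.(w).
apply(d::Decay, w, g) = g .+ d.λ2 .* w .+ d.λ1 .* sign.(w)

η = 0.1                         # learning rate
d = Decay(0.01, 0.02)
w = [1.0, -2.0]                 # parameters
g = [0.5, 0.5]                  # gradient of the *unpenalised* loss

w_new = w .- η .* apply(d, w, g)
```

Note that with this scheme the loss reported during training is the unpenalised one, which is exactly why the `training_loss` change mentioned above would be breaking.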