This PR is above average in complexity. This means a review is particularly important, but it's also going to be more work than usual. @pat-alt Do you have any interest and time to review over the next 3 weeks, say?
My apologies in advance for slightly messy commits. I temporarily lost access to a GPU for local testing and was shooting blind for a while.
@pat-alt Sorry to re-ping, but I'm not sure who else to ask for a review here. @tiemvanderdeure Would you consider reviewing?
If possible, I'm hoping for a merge in the next 3 weeks. Even a cursory look would be much appreciated!
Hi! I'll try to have a look as soon as I can (probably on the weekend or next week).
If it is helpful I can also give reviewing this PR a go. Probably won't have time the next few days but early next week should be feasible.
Let's see if @pat-alt is able to find some time.
Thanks @pat-alt for your review. Much appreciated. I've made a few tweaks in response to the comments.
Thanks @ablaom, will have another look today.
@Rockdeldiablo for reference. See in particular the redefinition of `fit!` and my comments above.
@pat-alt How are we doing? Happy with the changes?
Sorry, yes, I missed the thumbs up :) Thanks for clarifying!
1bd58dd adds deprecations for the `fit!` and `train!` methods. I think this is useful for developers who have used the old API in their own packages to extend MLJFlux, as we have done here, for example. Not sure who else has done something like this (and I'm also unsure if this is the intended way to extend MLJFlux), but in any case adding these deprecations should help.
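For anyone migrating, here is a minimal sketch of the deprecation pattern; the actual definitions live in 1bd58dd, and the signatures are elided here with splats:

```julia
# Sketch only; see 1bd58dd for the actual definitions.
# The old, technically private, method names forward to their
# non-mutating replacements with a deprecation warning:
function fit!(args...; kwargs...)
    Base.depwarn("`MLJFlux.fit!` is deprecated; use `MLJFlux.train` instead.", :fit!)
    return train(args...; kwargs...)
end

function train!(args...; kwargs...)
    Base.depwarn("`MLJFlux.train!` is deprecated; use `MLJFlux.train_epoch` instead.", :train!)
    return train_epoch(args...; kwargs...)
end
```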
Thanks @pat-alt for your review. 🙏🏾
This PR combines a number of changes, which for technical reasons could not be easily split up. The most important change, anticipating a Flux 0.15 breakage, is the switch to explicit differentiation; so this PR replaces #230. Shout out to @pat-alt for outlining a solution there.
Closes #221.
To do:
- [x] Replace implicit-style parameter updates with explicit-style parameter updates, in line with planned Zygote/Flux deprecations.
- [x] Refactor code to use optimisers from Optimisers.jl, with the `setup`/`update` pattern in place of the `update!` pattern (see the first sketch after this list). Also, rename the private methods `train!` -> `train_epoch` and `fit!` -> `train` to reflect the new non-mutating behaviour. This possibly breaks some "custom" models that have chosen to overload these technically private methods.
- [x] (RNG changes.) Change the default value of the model field `rng` from `Random.GLOBAL_RNG` to `Random.default_rng()`. Change the seeded RNG, obtained by specifying an integer value for `rng`, from `MersenneTwister` to `Xoshiro` (see the RNG sketch below).
- [x] Update the `Short` builder so that the `rng` argument of `build(::Short, rng, ...)` is passed on to the `Dropout` layer, as these layers now support this on a GPU, at least for `rng=Random.default_rng()` (also illustrated in the RNG sketch).
- [x] Change the implementation of L1/L2 regularization from explicit loss penalization to weight/sign decay, internally chained with the user-specified optimiser (see the decay sketch below). Breaking: the losses reported in the history will no longer be penalized, because the penalty is not explicitly computed.
- [ ] Update documentation to reflect the use of Optimisers.jl optimisers in place of Flux.jl native optimisers, and the changes to the `rng` defaults. Waiting on #252.
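For context, here is a minimal, self-contained sketch of the explicit-differentiation workflow with the Optimisers.jl `setup`/`update` pattern. This is plain Flux with a made-up model and data, not MLJFlux internals:

```julia
using Flux, Optimisers

model = Chain(Dense(4 => 8, relu), Dense(8 => 1))
x, y = rand(Float32, 4, 32), rand(Float32, 1, 32)

# Initialise optimiser state once, separately from the model:
state = Optimisers.setup(Optimisers.Adam(0.01), model)

# Explicit style: differentiate with respect to the model itself,
# instead of an implicit `Flux.params(model)` collection:
grads = Flux.gradient(m -> Flux.mse(m(x), y), model)[1]

# Non-mutating update returns fresh state and a fresh model:
state, model = Optimisers.update(state, model, grads)
```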
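On the RNG items, a short illustration of the two changes (values are illustrative):

```julia
using Flux, Random

# Integer `rng` values now seed a Xoshiro instead of a MersenneTwister:
rng_old = MersenneTwister(123)   # pre-PR behaviour
rng_new = Xoshiro(123)           # post-PR behaviour

# Dropout layers built by `Short` now receive the builder's `rng`,
# which is GPU-compatible at least when `rng == Random.default_rng()`:
layer = Dropout(0.5; rng=Random.default_rng())
```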
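And on the regularization item: assuming MLJFlux's usual convention that `lambda` sets the total penalty strength and `alpha` the L1/L2 mix, the decay-based equivalent chains `SignDecay` and `WeightDecay` in front of the user's optimiser. The exact scaling used in the PR may differ; this is a sketch:

```julia
using Optimisers

lambda, alpha, eta = 0.01, 0.4, 0.001  # illustrative values

# SignDecay adds `lambda*alpha*sign(p)` to each gradient (L1 decay);
# WeightDecay adds `lambda*(1 - alpha)*p` (L2 decay). Chaining applies
# both decays before the Adam step:
opt = OptimiserChain(
    SignDecay(lambda * alpha),
    WeightDecay(lambda * (1 - alpha)),
    Adam(eta),
)
```

The resulting `opt` is then used with the same `setup`/`update` pattern shown in the first sketch.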