Closed: pat-alt closed this 1 year ago
Since the methods are currently indicated as "private", I'm not sure this approach is aligned with the design you have in mind.
I have no plans to develop these private methods along some different lines, so if this makes sense for your use case, then I support it!
It's very encouraging to see the abstractions I introduced in `mlj_model_interface.jl` actually get used for a new model.
I'm curious: skimming your application to energy models, you have two `build` methods. Where does the `build` dispatched on a "model" get used?
I think all that's needed are some doc strings.
Thanks for spending so much time getting into the weeds of this package.
Awesome, thanks (also for adding me as a collaborator)!
As for the below:

> I'm curious: skimming your application to energy models, you have two `build` methods. Where does the `build` dispatched on a "model" get used?
I'm overloading the `build` method here. Instead of only building the chain (as for the `NeuralNetworkClassifier`), in this case I also use the `build` function to instantiate the `JointEnergyModel`, which is the base type for JEMs in the package.
There's probably a better/cleaner way to do this, but I needed a quick solution for a paper I'm working on at the moment. Will look at improving the JEM package over the summer.
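For illustration, a minimal sketch of what this overload might look like. All names mirror the description above, but the field names and the `JointEnergyModel` constructor signature here are hypothetical, not the actual `JointEnergyModels.jl` API:

```julia
using MLJFlux

# Hypothetical sketch: overload MLJFlux.build for the custom classifier so it
# returns a JointEnergyModel wrapping the Flux chain, rather than the bare
# chain a NeuralNetworkClassifier would return. Illustrative names only.
function MLJFlux.build(model::JointEnergyClassifier, rng, shape)
    # build the underlying chain exactly as the default classifier would
    chain = MLJFlux.build(model.builder, rng, shape...)
    # wrap it in the energy-based model type used by the package
    return JointEnergyModel(chain, model.sampler)
end
```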
Have updated the docstrings now 😃
Patch coverage: 100.00% and project coverage change: +0.04% :tada:
Comparison is base (452c09d) 92.73% compared to head (b0d70cf) 92.78%.
Suggestion
This adds a minor change to two important private methods: `train!` and `fit!`. Dispatching these on `model::MLJFlux.MLJFluxModel` would allow users/developers to implement custom training loops without changing the default behaviour for default models (`NeuralNetworkRegressor`, `NeuralNetworkClassifier`, ...).
Example
I'm using this branch in the development version of a new package I'm working on: JointEnergyModels.jl. That package is tailored to `Flux.jl`, but the minimal changes suggested here have made it possible to make it compatible with `MLJFlux.jl` without much hassle:

- a `JointEnergyClassifier` constructor that subtypes `MLJFlux.MLJFluxProbabilistic` and mostly inherits methods and fields from the `NeuralNetworkClassifier` through composition
- a `fit!` method for `JointEnergyClassifier`

Details can be found in `src/mlj_flux.jl`. Similarly, this could help with adding support for adversarial training, for example.
Alternatives?
Since the methods are currently indicated as "private", I'm not sure this approach is aligned with the design you have in mind. If you'd rather not go down this route, do you have any suggestions for alternatives?
PR Checklist