FluxML / MLJFlux.jl

Wrapping deep learning models from the package Flux.jl for use in the MLJ.jl toolbox
http://fluxml.ai/MLJFlux.jl/
MIT License

Dispatch fit! and train! on model for greater extensibility #222

Closed: pat-alt closed this pull request 1 year ago

pat-alt commented 1 year ago

Suggestion

This PR makes a minor change to two important private methods, train! and fit!. Dispatching them on model::MLJFlux.MLJFluxModel would let users and developers implement custom training loops without changing the behaviour of the default models (NeuralNetworkRegressor, NeuralNetworkClassifier, ...).
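To illustrate the idea, here is a minimal sketch in Python, with `functools.singledispatch` standing in for Julia's multiple dispatch. All names (`MLJFluxModel`, `CustomEnergyModel`, `train`, `squared_error`) are hypothetical placeholders, not MLJFlux's actual types or signatures. A generic `train` method serves every model, and a new model type registers its own loop without touching the default.

```python
from functools import singledispatch


class MLJFluxModel:
    """Hypothetical stand-in for the abstract supertype MLJFlux.MLJFluxModel."""


class NeuralNetworkRegressor(MLJFluxModel):
    """A built-in model: should keep the default training loop."""


class CustomEnergyModel(MLJFluxModel):
    """Hypothetical third-party model that wants its own training loop."""

    def __init__(self, penalty=0.1):
        self.penalty = penalty


def squared_error(batch):
    # A batch is a list of (prediction, target) pairs.
    return sum((yhat - y) ** 2 for yhat, y in batch)


@singledispatch
def train(model, batches):
    """Default loop: total squared error over all batches."""
    return sum(squared_error(b) for b in batches)


@train.register
def _(model: CustomEnergyModel, batches):
    # Override registered only for CustomEnergyModel: adds a per-batch
    # penalty. Built-in models still use the default method above.
    return sum(squared_error(b) + model.penalty for b in batches)
```

With this setup, `train(NeuralNetworkRegressor(), batches)` keeps the default behaviour, while `train(CustomEnergyModel(), batches)` picks up the override. Julia's dispatch is richer (it can consider every argument's type), but dispatching on the model argument is the case this change needs.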

Example

I'm using this branch in the development version of a new package I'm working on, JointEnergyModels.jl. That package is tailored to Flux.jl, but the minimal changes suggested here have made it possible to add compatibility with MLJFlux.jl without much hassle.

Details can be found in src/mlj_flux.jl.

Similarly, this could help with adding support for other custom training schemes, such as adversarial training.

Alternatives?

Since the methods are currently indicated as "private", I'm not sure this approach is aligned with the design you have in mind. If you'd rather not go down this route, do you have any suggestions for alternatives?


ablaom commented 1 year ago

> Since the methods are currently indicated as "private", I'm not sure this approach is aligned with the design you have in mind.

I have no plans to make these private methods public along different lines, so if this makes sense for your use case, I support it!

It's very encouraging to see those abstractions I introduced in mlj_model_interface.jl actually get used for a new model.

I'm curious: skimming your application to energy models, I see you have two build methods. Where does the build dispatched on a "model" get used?

I think all that's needed are some docstrings.

Thanks for spending so much time getting into the weeds of this package.

pat-alt commented 1 year ago

Awesome, thanks (also for adding me as a collaborator)!

As for your question:

> I'm curious, skimming your application to energy models, you have two build methods. Where does the build dispatched on a "model" get used?

I'm overloading the build method that is dispatched on the model. Instead of only building the chain (as for NeuralNetworkClassifier), in this case I also use build to instantiate the JointEnergyModel, which is the base type for JEMs in the package.

There's probably a better/cleaner way to do this, but I needed a quick solution for a paper I'm working on at the moment. I will look at improving the JEM package over the summer.
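The two levels of build described above can be sketched the same way: Python again, with `functools.singledispatch` standing in for Julia dispatch. `JEMClassifier`, `EnergyWrapper`, and `make_chain` are hypothetical, not the actual JointEnergyModels.jl or MLJFlux types. The default model-level method returns only the chain, while an overload for the custom model wraps that chain in a richer object that the custom training loop can then use.

```python
from functools import singledispatch


class MLJFluxModel:
    """Hypothetical stand-in for the abstract supertype MLJFlux.MLJFluxModel."""


class NeuralNetworkClassifier(MLJFluxModel):
    """A built-in model: build should return just the chain."""


class JEMClassifier(MLJFluxModel):
    """Hypothetical MLJ-facing model type for a joint energy model."""


class EnergyWrapper:
    """Hypothetical stand-in for a JointEnergyModel-style base type:
    holds the chain plus extra state a custom training loop might need."""

    def __init__(self, chain, sampler_steps=10):
        self.chain = chain
        self.sampler_steps = sampler_steps


def make_chain(n_in, n_out):
    # Placeholder for the network that a builder would produce.
    return ("chain", n_in, n_out)


@singledispatch
def build(model, n_in, n_out):
    # Default: return only the chain.
    return make_chain(n_in, n_out)


@build.register
def _(model: JEMClassifier, n_in, n_out):
    # Custom: build the same chain, then wrap it in the richer object.
    return EnergyWrapper(make_chain(n_in, n_out))
```

Because the overload only wraps the default chain, both model types share the same network architecture; only what build hands to the training loop differs.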

pat-alt commented 1 year ago

Have updated the docstrings now 😃

codecov-commenter commented 1 year ago

Codecov Report

Patch coverage: 100.00%; project coverage change: +0.04%.

Comparison is base (452c09d) 92.73% compared to head (b0d70cf) 92.78%.


Additional details and impacted files

```diff
@@            Coverage Diff             @@
##              dev     #222      +/-   ##
==========================================
+ Coverage   92.73%   92.78%   +0.04%
==========================================
  Files          11       11
  Lines         303      305       +2
==========================================
+ Hits          281      283       +2
  Misses         22       22
```

| Impacted Files | Coverage Δ |
|---|---|
| src/core.jl | `94.87% <100.00%> (+0.13%)` |
| src/mlj_model_interface.jl | `94.20% <100.00%> (ø)` |
