JuliaTrustworthyAI / LaplaceRedux.jl

Effortless Bayesian Deep Learning through Laplace Approximation for Flux.jl neural networks.
https://www.taija.org/LaplaceRedux.jl/
MIT License

refactor mljflux interface #92

Closed pasq-cat closed 2 months ago

pasq-cat commented 4 months ago

I have tried to use the new @mlj_model macro. There are some choices I am not sure are right. First, the @mlj_model macro allows defining constraints directly in the struct, so there is no need for a hand-written clean! function, but if you prefer the older method I can go back.
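
A minimal sketch of the constraint syntax, assuming a hypothetical model type `HypotheticalLaplaceRegressor` with illustrative fields (not the actual LaplaceRedux struct). With `@mlj_model`, each field is written as `name::Type = default::(constraint)`, where `_` stands for the field's value; the macro generates the keyword constructor and the `clean!` checks, so a value that violates its constraint triggers a warning and is reset to the default.

```julia
using MLJModelInterface
const MMI = MLJModelInterface

# Hypothetical model type: field names are illustrative only, not LaplaceRedux's API.
MMI.@mlj_model mutable struct HypotheticalLaplaceRegressor <: MMI.Probabilistic
    # `default::(constraint)`; `_` is replaced by the field value when checking
    subset_of_weights::Symbol = (:all)::(_ in (:all, :last_layer))
    σ::Float64 = 1.0::(_ > 0)
    epochs::Int = 10::(_ > 0)
end

m = HypotheticalLaplaceRegressor(σ = 0.5)     # satisfies every constraint
bad = HypotheticalLaplaceRegressor(σ = -1.0)  # violates σ > 0: warns and resets σ to 1.0
```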

Second, the MLJModelInterface quick-start guide says that model structs should contain only hyperparameters, so I removed the builder field, and now the chain has to be passed directly to fit!. This also seems logical, since a user may want to experiment with different Flux models without having to redefine the hyperparameters tied to the Laplace wrapper. However, I am not sure that adding an argument to the fit! function is the right choice, since none of the example models do it.

Third, I removed the two shape and build functions, since they are at most a generic utility for the user and not necessarily connected to the LaplaceRedux package.

Fourth, I have some doubts about the output of the predict function in the LaplaceClassification case. In the regression case I took the mean and the variance provided by Laplace and used them to output a Gaussian distribution with Distributions.jl, but in the classification case the MLJ model interface says the output has to be a UnivariateFinite object, and the example provided points to a broken link. So I left the pseudo-probabilities of the classes as the output.
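
For the classification output: MLJ's probabilistic classifiers are expected to return `UnivariateFinite` distributions, which are provided by CategoricalDistributions.jl (and re-exported by MLJBase). A minimal sketch of both outputs, assuming hypothetical predictive means, variances and class probabilities in place of what the Laplace approximation would actually return. The regression case maps each mean and variance to a `Normal` (which takes a standard deviation, not a variance); the classification case wraps an n × k probability matrix, one row per observation, into a vector of `UnivariateFinite` distributions.

```julia
using Distributions                               # Normal, mean, std, pdf
using CategoricalDistributions: UnivariateFinite

# Regression: hypothetical per-observation predictive means and variances
μ = [0.1, 1.2]
σ² = [0.04, 0.25]
yhat_reg = Normal.(μ, sqrt.(σ²))   # one Gaussian per observation

# Classification: hypothetical n × k matrix of class probabilities (rows sum to 1)
classes = ["a", "b", "c"]
probs = [0.5 0.25 0.25;
         0.125 0.375 0.5]
# pool=missing: the labels are raw strings, not values from an existing categorical pool
yhat_clf = UnivariateFinite(classes, probs, pool = missing)
```

Each element of `yhat_clf` is a `UnivariateFinite`, so `pdf(yhat_clf[1], "a")` gives the probability of class "a" for the first observation.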

It works (at least on my PC...), but I am not sure whether it respects what MLJ expects, nor why these automatic checks complain so much. Is it because I didn't add the Project.toml and Manifest.toml files?

pasq-cat commented 4 months ago

@pat-alt I will not work on this any further without some guidance, because I am confused and tired. I tried to change it back to how it was before (adding the required fields as described here) https://github.com/FluxML/MLJFlux.jl/tree/01ad08ebc664f16d9509171866685c14d7bd6e99 but it doesn't work.

pat-alt commented 4 months ago

OK! It looks like the package fails to precompile, so it's hard to tell if anything in the code itself is wrong. I would suggest the following:

Then tackle the tasks here, as discussed.

Also, try to remember to apply the linter from time to time:

using JuliaFormatter
# Format every .jl file under the current directory, in place
JuliaFormatter.format(".")

pat-alt commented 3 months ago

@Rockdeldiablo there were some small bugs in the final test file. I've also updated various projects (., docs, test) and applied the linter

using JuliaFormatter
# Format every .jl file under the current directory, in place
JuliaFormatter.format(".")

codecov[bot] commented 3 months ago

Codecov Report

Attention: Patch coverage is 89.61039% with 8 lines in your changes missing coverage. Please review.

Project coverage is 94.95%. Comparing base (8d1a154) to head (8523eed). Report is 11 commits behind head on main.

| Files | Patch % | Lines |
|---|---|---|
| src/mlj_flux.jl | 88.40% | 8 Missing :warning: |
```diff
@@            Coverage Diff             @@
##             main      #92      +/-   ##
==========================================
- Coverage   95.34%   94.95%   -0.40%
==========================================
  Files          18       21       +3
  Lines         602      575      -27
==========================================
- Hits          574      546      -28
- Misses         28       29       +1
```

pasq-cat commented 3 months ago

I was waiting to apply the formatter until everything worked. But how did you solve the issue with MLJBase.predict during the test? It kept giving me errors. Also, I had to pass rng as a second argument to build, otherwise MLJBase kept complaining. The issue is that rng is already in the Laplace struct, so it doesn't make much sense to pass it as an independent argument, but MLJBase seems to require it.
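
For context on the rng point: in MLJFlux's builder interface, `MLJFlux.build(builder, rng, n_in, n_out)` receives the random number generator as its second argument, so the builder can seed weight initialisation reproducibly. The sketch below assumes a hypothetical `HypotheticalMLP` builder (not LaplaceRedux's own code) and shows one way that rng can be used.

```julia
using Flux, MLJFlux, Random

# Hypothetical builder type, for illustration only.
mutable struct HypotheticalMLP <: MLJFlux.Builder
    hidden::Int
end

# MLJFlux passes the rng in; here it seeds the weight initialiser.
function MLJFlux.build(b::HypotheticalMLP, rng, n_in, n_out)
    init = Flux.glorot_uniform(rng)
    return Chain(
        Dense(n_in => b.hidden, relu; init = init),
        Dense(b.hidden => n_out; init = init),
    )
end

chain = MLJFlux.build(HypotheticalMLP(16), Random.MersenneTwister(123), 4, 1)
```

Building twice with the same seed gives identical initial weights, which is why the rng lives in the `build` signature rather than only in the model struct.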

pat-alt commented 3 months ago

I didn't have any issues with that, but before even running I saw the warning about conflicting function imports related to MLJ (LaplaceRedux and MLJ both define predict). So I avoided importing the entire namespace of MLJBase (I removed MLJ as a dep).
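
A minimal, self-contained sketch of the name clash and the workaround described above. The modules `FakeLaplace` and `FakeMLJ` are hypothetical stand-ins for LaplaceRedux and MLJ/MLJBase, each exporting its own `predict`. Bringing both in with `using` would make an unqualified `predict` ambiguous; `import` brings in only the module name, so the MLJ-side function is called with an explicit qualifier.

```julia
# Hypothetical stand-ins: both modules export a function called `predict`.
module FakeLaplace
export predict
predict(x) = "laplace: $x"
end

module FakeMLJ
export predict
predict(x) = "mlj: $x"
end

using .FakeLaplace   # brings `predict` into scope unqualified
import .FakeMLJ      # brings in only the name `FakeMLJ`, not its exports

a = predict(1)           # unambiguous: resolves to FakeLaplace.predict
b = FakeMLJ.predict(1)   # MLJ-side function, called with a qualifier
```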

pat-alt commented 3 months ago

@MojiFarmanbar to see

pat-alt commented 3 months ago

I think this is actually pretty close now 👍🏽 Collecting a few observations below as I work on the code:

Edit: more specifically, it appears that calling update after increasing epochs by 3 reruns the model for epochs + 3 epochs, as opposed to just 3 more epochs. This is because MLJModelInterface.is_same_except(...) returns false (I'm investigating why exactly).

julia> MLJBase.update(model, 2, chain, cache, X, y)
[ Info: Loss is 12.27
[ Info: Loss is 0.4794
[ Info: Loss is 0.4555
[ Info: Loss is 0.4437
[ Info: Loss is 2.295
[ Info: Loss is 0.455
[ Info: Loss is 0.4554
[ Info: Loss is 0.4562
[ Info: Loss is 170900.0
[ Info: Loss is 457.6
[ Info: Loss is 43.38
[ Info: Loss is 1.24
[ Info: Loss is 0.9924
[ Info: From train
Chain(Dense(4 => 16, relu), Dense(16 => 8, relu), Dense(8 => 1))
[ Info: From fitresult
Chain(Dense(4 => 16, relu), Dense(16 => 8, relu), Dense(8 => 1))
[ Info: From fitresult
(Chain(Dense(4 => 16, relu), Dense(16 => 8, relu), Dense(8 => 1)), LaplaceRegression(builder = MLP(hidden = (16, 8), …), …))
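
The expected warm-restart behaviour can be sketched as follows. This is a minimal illustration, assuming a hypothetical `ToyModel` and a hypothetical helper `epochs_to_run`; it is not the code in this PR. It uses `MLJModelInterface.is_same_except(m1, m2, :epochs)`, which returns `true` when two models differ at most in the listed fields, to choose between training only the extra epochs and retraining from scratch. If `is_same_except` returns `false` because of some other field, the sketch falls back to a full retrain, which is the `epochs + 3` symptom seen in the log.

```julia
using MLJModelInterface
const MMI = MLJModelInterface

# Hypothetical toy model standing in for the real Laplace model type.
MMI.@mlj_model mutable struct ToyModel <: MMI.Deterministic
    epochs::Int = 10::(_ > 0)
    lambda::Float64 = 0.0::(_ >= 0)
end

# Hypothetical helper: how many epochs should `update` run?
function epochs_to_run(old_model, new_model)
    if MMI.is_same_except(old_model, new_model, :epochs) &&
       new_model.epochs >= old_model.epochs
        return new_model.epochs - old_model.epochs  # warm restart: only the extra epochs
    else
        return new_model.epochs                     # cold restart: retrain from scratch
    end
end

old_m = ToyModel(epochs = 10)
new_m = ToyModel(epochs = 13)
n_warm = epochs_to_run(old_m, new_m)                                # only epochs changed
n_cold = epochs_to_run(old_m, ToyModel(epochs = 13, lambda = 0.1))  # another field changed
```
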
pat-alt commented 3 months ago

@Rockdeldiablo let me take over here for a moment; I think it's just a few more minor issues that I can hopefully fix. That gives you an opportunity to finish the remaining tasks on #97 😃

pat-alt commented 3 months ago

@Rockdeldiablo I've added the following changes now to make things work without having to overload the update method or change it in MLJFlux:

Now that tests are passing, there are a few more things to do (possibly in a new issue + PR) if you like.

For now, feel free to focus on the other PR; just ping me and @MojiFarmanbar when you come back to this one. I need to move on to other things for now.

pat-alt commented 3 months ago

Let's just move the pending tasks above into new issues and then merge this one.

pasq-cat commented 2 months ago

Ahh, I just saw this message. How do I move the pending tasks into new issues?