probml / pyprobml

Python code for the book "Probabilistic Machine Learning" by Kevin Murphy
MIT License

Create: bnn_hierarchical_flax.ipynb #1008

Closed gerdm closed 2 years ago

gerdm commented 2 years ago

Description

Figures

[two result figures attached to the PR; images not shown]

review-notebook-app[bot] commented 2 years ago

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


gerdm commented 2 years ago

If I'm not mistaken, the second check seems to fail at linreg_hierarchical_numpyro.ipynb. Did I change that notebook somehow? Or is the problem the notebook I'm trying to push? @patel-zeel

murphyk commented 2 years ago

I moved all those linreg_hierarchical variants into book2/15, but they are all deprecated. This is unrelated to your PR.

murphyk commented 2 years ago

Hi @gerdm. I think it might be best if you just add your flax code to the end of https://github.com/probml/pyprobml/blob/master/notebooks/book2/17/bnn_hierarchical.ipynb, so we can show the user two different ways of achieving the same goal: vanilla JAX and then Flax. This will reduce the amount of duplicated code.

Please also compute your train and test accuracy. In the hierarchical non flax version we get

Train accuracy = 91.00%
Test accuracy = 89.56%

You should get something that is basically identical, assuming you initialize your sampler from the same initial state, etc.
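For reference, the accuracy numbers above can be computed by averaging the sigmoid of the posterior-predictive logits over MCMC samples and thresholding at 0.5. This is a minimal NumPy sketch; the array shapes and the toy data are illustrative, not taken from the notebook.

```python
import numpy as np

def accuracy(logit_samples, y):
    """Classification accuracy from posterior-predictive logit samples.

    logit_samples: (num_samples, num_points) Bernoulli logits,
                   one row per posterior sample.
    y: (num_points,) array of 0/1 labels.
    """
    p = 1.0 / (1.0 + np.exp(-logit_samples))  # sigmoid per sample
    p_mean = p.mean(axis=0)                   # average over posterior samples
    y_hat = (p_mean > 0.5).astype(int)        # threshold the predictive mean
    return float((y_hat == y).mean())

# Toy check: points with strongly positive/negative mean logits.
rng = np.random.default_rng(0)
logits = rng.normal(size=(100, 8)) + np.array([3, 3, 3, 3, -3, -3, -3, -3])
y = np.array([1, 1, 1, 1, 0, 0, 0, 0])
print(f"Accuracy = {accuracy(logits, y):.2%}")
```

Comparing train and test accuracy computed this way across the two implementations is a quick sanity check that the models agree.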

In addition to adding code to the end of the current notebook, I suggest you also add your way of plotting the raw data (with larger dot sizes, etc.) at the top, since I think it will make a better figure. In addition to plotting a 4x4 grid of tasks, please also show a 2x2 subset, which can be used to make a smaller figure for the book. Please use latexify and pml_utils.savefig as explained at https://github.com/probml/pyprobml/tree/master/notebooks#ii-detailed-guidelines.
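The 4x4 grid and its 2x2 subset can share one plotting helper, parameterized by grid shape. This is an illustrative sketch with synthetic data; the `tasks` structure, dot size, and colormap are assumptions, and the pyprobml `latexify`/`savefig` helpers described in the linked guidelines are omitted here.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for scripted runs
import matplotlib.pyplot as plt
import numpy as np

def plot_task_grid(tasks, nrows, ncols, dot_size=40):
    """Scatter-plot the first nrows*ncols tasks in a grid.

    `tasks` is assumed to be a list of (X, y) pairs with X of shape
    (n, 2) and y an array of 0/1 labels; names are illustrative.
    """
    fig, axes = plt.subplots(nrows, ncols,
                             figsize=(3 * ncols, 3 * nrows),
                             sharex=True, sharey=True)
    for ax, (X, y) in zip(np.ravel(axes), tasks):
        ax.scatter(X[:, 0], X[:, 1], c=y, s=dot_size,
                   cmap="coolwarm", edgecolor="k")
    return fig

# Synthetic stand-in data: 16 tasks of 20 labeled 2-D points each.
rng = np.random.default_rng(0)
tasks = [(rng.normal(size=(20, 2)), rng.integers(0, 2, size=20))
         for _ in range(16)]
fig_full = plot_task_grid(tasks, 4, 4)   # 4x4 overview of all tasks
fig_small = plot_task_grid(tasks, 2, 2)  # 2x2 subset for the book figure
```

The same helper produces both figures, so the book-sized 2x2 version stays visually consistent with the full grid.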

murphyk commented 2 years ago

Hi @gerdm . Your test accuracy is 86.67%, whereas before it was ~89%. Why is it so much lower? Is your flax model identical to the previous model? It seems the architecture is the same, but maybe the init is different? Are you using the same optimizer? Ideally you could check both predict functions give the same results when applied to the same input, at least at initialization. And then if you use the same RNG your results should be identical.
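The parity check suggested above (same parameters, same input, same output at initialization) can be sketched in plain NumPy. Here a functional forward pass stands in for the vanilla-JAX version and a module-style class stands in for the Flax version; both are illustrative, not the notebook's actual code, and the key point is that they share one set of initial parameters.

```python
import numpy as np

def init_params(seed, sizes):
    """Shared initialization so both implementations start identically."""
    rng = np.random.default_rng(seed)
    return [(rng.normal(size=(m, n)) * np.sqrt(2.0 / m), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]

def predict_vanilla(params, x):
    """Plain functional MLP forward pass (stand-in for the jax version)."""
    for W, b in params[:-1]:
        x = np.maximum(x @ W + b, 0.0)  # ReLU hidden layers
    W, b = params[-1]
    return x @ W + b                    # linear output layer

class MLP:
    """Module-style wrapper (stand-in for the flax version)."""
    def __init__(self, params):
        self.params = params

    def __call__(self, x):
        h = x
        for W, b in self.params[:-1]:
            h = np.maximum(h @ W + b, 0.0)
        W, b = self.params[-1]
        return h @ W + b

params = init_params(seed=0, sizes=[2, 16, 16, 1])
x = np.random.default_rng(1).normal(size=(5, 2))
same = np.allclose(predict_vanilla(params, x), MLP(params)(x))
```

If the two predict functions disagree even with identical parameters, the architectures differ; if they agree at initialization but diverge after training, the mismatch is in the optimizer, the prior, or the RNG handling.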

gerdm commented 2 years ago

Hi @murphyk,

I agree. I'll change the PR into a draft while I debug this.