gerdm / bayes

Neat Bayesian machine learning examples

New window_adaptation syntax #2

Closed: forgi86 closed this issue 1 month ago

forgi86 commented 1 month ago

Hello, I am new to blackjax and recently came across your nice examples. However, I had to change a few lines of code to make them work with recent versions of the jax/blackjax ecosystem, in particular the adaptation algorithm. For instance, in bayesian-neural-network.ipynb, I had to change the lines below the definition of the "potential" to:

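# Tune the NUTS step size and inverse mass matrix with window adaptation
warmup = blackjax.window_adaptation(blackjax.nuts, potential)
(state, parameters), _ = warmup.run(key_warmup, params)

# Build the NUTS step function from the tuned parameters, then draw samples
kernel = blackjax.nuts(potential, **parameters).step
states = inference_loop(key_samples, kernel, state, num_steps)
sampled_params = states.position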
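The snippet calls an inference_loop helper that is not shown here. Below is a minimal sketch of such a helper. It assumes the helper follows the common BlackJAX-docs pattern of scanning a kernel(rng_key, state) -> (state, info) step function with jax.lax.scan. To keep the example runnable without BlackJAX, log_density and rw_kernel are hypothetical stand-ins (a random-walk Metropolis step on a standard normal target), not the notebook's NUTS kernel.

```python
import jax
import jax.numpy as jnp


def inference_loop(rng_key, kernel, initial_state, num_samples):
    """Run `kernel` for `num_samples` steps and stack every visited state."""

    @jax.jit
    def one_step(state, key):
        state, _ = kernel(key, state)  # discard per-step diagnostics
        return state, state  # (carry for next step, output collected by scan)

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states


# Hypothetical stand-in target: standard normal log-density (up to a constant).
def log_density(x):
    return -0.5 * jnp.sum(x**2)


# Hypothetical stand-in kernel: random-walk Metropolis with a fixed step size.
# It follows the same (rng_key, state) -> (state, info) convention as the
# snippet's `kernel`; here the state is simply the current position array.
def rw_kernel(rng_key, x, step_size=0.5):
    key_prop, key_acc = jax.random.split(rng_key)
    proposal = x + step_size * jax.random.normal(key_prop, x.shape)
    log_ratio = log_density(proposal) - log_density(x)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
    new_x = jnp.where(accept, proposal, x)
    return new_x, accept


states = inference_loop(jax.random.PRNGKey(0), rw_kernel, jnp.zeros(2), 5_000)
print(states.shape)  # (5000, 2)
```

Because lax.scan stacks the per-step outputs along a new leading axis, a BlackJAX state (a NamedTuple) comes back with every field batched. That is why the snippet can read states.position. In this toy version the state is a bare array, so there is no .position field.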
gerdm commented 1 month ago

Hi @forgi86,

Could you please open a PR with these changes?

Thanks!

forgi86 commented 1 month ago

Hello @gerdm,

I opened a PR, but it only fixes the bayesian-neural-network example. I was also trying to fix the bnn-hierarchical-flax example (the most relevant one for my current work), but there seems to be another bug there.

I can't evaluate the potential. Running

potential(params_all)

raises ValueError: Arity mismatch between trees
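An arity mismatch typically means that two pytrees combined leaf by leaf do not share the same structure. For example, this can happen with a parameter tree and a tree of prior scales passed together to jax.tree_util.tree_map. The sketch below shows one way to locate the difference before evaluating the potential. It compares tree definitions first and then leaf shapes. The describe_mismatch helper and the toy trees are hypothetical, not code from the notebook.

```python
import jax
import jax.numpy as jnp


def describe_mismatch(tree_a, tree_b):
    """Hypothetical helper: say how two pytrees differ, if at all."""
    def_a = jax.tree_util.tree_structure(tree_a)
    def_b = jax.tree_util.tree_structure(tree_b)
    if def_a != def_b:
        return f"structure differs:\n  {def_a}\n  {def_b}"
    # Same structure: compare leaf shapes position by position.
    shapes_a = [jnp.shape(x) for x in jax.tree_util.tree_leaves(tree_a)]
    shapes_b = [jnp.shape(x) for x in jax.tree_util.tree_leaves(tree_b)]
    diffs = [(i, a, b) for i, (a, b) in enumerate(zip(shapes_a, shapes_b)) if a != b]
    if diffs:
        return f"leaf shapes differ (index, a, b): {diffs}"
    return "trees match"


# Toy trees: `wrapped` adds an extra outer list level, while `batched` adds a
# leading axis of size 1 to every leaf.
params = {"dense": {"kernel": jnp.ones((3, 4)), "bias": jnp.zeros(4)}}
wrapped = [params]
batched = jax.tree_util.tree_map(lambda x: x[None], params)

print(describe_mismatch(params, params))   # trees match
print(describe_mismatch(params, wrapped))  # structure differs: ...
print(describe_mismatch(params, batched))  # leaf shapes differ ...
```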

Unfortunately I don't have time to look into it at the moment...

forgi86 commented 1 month ago

OK, I also fixed the bug in the hbnn potential. The variable params_sigma_tree in build_sigma_tree had an extra leading singleton dimension (perhaps because of a change in linen's pytree structure?).
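The sketch below assumes the extra dimension is a leading axis of size 1 on every leaf of params_sigma_tree. Under that assumption, the axis can be dropped leaf by leaf with jax.tree_util.tree_map. The actual change is in forgi86's fork. The drop_leading_singleton helper and sigma_tree are hypothetical stand-ins, not the output of the notebook's build_sigma_tree.

```python
import jax
import jax.numpy as jnp


def drop_leading_singleton(tree):
    """Hypothetical helper: remove a leading size-1 axis from every leaf."""
    # jnp.squeeze(..., axis=0) raises if axis 0 is not of size 1,
    # so a leaf with an unexpected shape fails loudly instead of silently.
    return jax.tree_util.tree_map(lambda x: jnp.squeeze(x, axis=0), tree)


# Hypothetical sigma tree whose leaves carry an unwanted leading axis of size 1.
sigma_tree = {
    "Dense_0": {"kernel": jnp.full((1, 3, 4), 0.1), "bias": jnp.full((1, 4), 0.1)},
}
fixed = drop_leading_singleton(sigma_tree)
# The kernel leaf now has shape (3, 4) and the bias leaf has shape (4,).
```

The pytree structure itself is unchanged by this operation. Only the leaf shapes change, so this targets shape mismatches rather than structural (nesting) ones.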

All fixed in my fork (https://github.com/forgi86/bayes). I also made a few changes to remove deprecation warnings.

gerdm commented 1 month ago

Thanks for your contribution @forgi86!

I'll take a look at this tonight.

gerdm commented 1 month ago

Closed by #3