sheinkmana closed this issue 1 month ago.
Hi @sheinkmana, I guess the keys in prior={"bias":Cauchy(), "kernel": dist.Normal()} do not match the parameter names in your neural network.
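Why this fails silently: my reading of the numpyro source (an assumption, not confirmed in this thread) is that random_flax_module joins nested flax parameter names with ".", so a custom Module with Dense sublayers has names like "Dense_0.kernel" and "Dense_0.bias". A dict prior is looked up by those full names, and any parameter whose name is not a key is left as an ordinary, non-random parameter instead of raising an error. That would leave MCMC with few or no latent sites from the network, which fits the fast run and the unchanged predictive. For a single flax.linen.Dense passed directly, the top-level names are just "bias" and "kernel", which is presumably why documentation-style keys work there. The stdlib-only sketch below lists prior keys that match nothing and parameters that get no prior. The helper names and the parameter tree are hypothetical; a real tree could come from module.init(jax.random.PRNGKey(0), x)["params"].

```python
from __future__ import annotations

from collections.abc import Mapping
from typing import Any, Iterator, Set, Tuple


def flatten_param_names(params: Mapping[str, Any], prefix: str = "") -> Iterator[str]:
    """Yield a dotted name for every leaf of a nested parameter tree."""
    for name, item in params.items():
        full_name = f"{prefix}.{name}" if prefix else name
        if isinstance(item, Mapping):
            # Sub-module (e.g. "Dense_0"): recurse, extending the dotted prefix.
            yield from flatten_param_names(item, full_name)
        else:
            # Leaf parameter (e.g. "kernel"): emit its full dotted name.
            yield full_name


def check_prior_keys(
    params: Mapping[str, Any], prior: Mapping[str, Any]
) -> Tuple[Set[str], Set[str]]:
    """Return (prior keys that match no parameter, parameters with no prior)."""
    names = set(flatten_param_names(params))
    unused_keys = set(prior) - names
    uncovered_params = names - set(prior)
    return unused_keys, uncovered_params


# Hypothetical parameter tree of a two-layer MLP; shape tuples stand in for arrays.
params = {
    "Dense_0": {"kernel": (3, 8), "bias": (8,)},
    "Dense_1": {"kernel": (8, 1), "bias": (1,)},
}

unused, uncovered = check_prior_keys(params, {"bias": "Cauchy()", "kernel": "Normal()"})
print(sorted(unused))     # ['bias', 'kernel']
print(sorted(uncovered))  # ['Dense_0.bias', 'Dense_0.kernel', 'Dense_1.bias', 'Dense_1.kernel']
```

Every key in the short dict is unused and every parameter is uncovered, which is the mismatch described above.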
Thanks a lot, @fehiepsi! That's super helpful. (Not going to lie, it took me an embarrassingly long time to realize it.)
Hey,
I am trying to implement a feed-forward neural network using a flax module and ran into a problem. SVI works fine, but MCMC (I tried NUTS and HMC) doesn't work whenever I specify priors other than a single one shared by all parameters. For example, setting
prior=dist.Normal()
works, but
prior={"bias":Normal(), "kernel": dist.Normal()}
doesn't. I don't get any errors; inference just completes in a matter of seconds, and the posterior predictive distribution is no different from the prior predictive.
My flax part:
The model with priors from the documentation (this and more complicated versions with scaling work fine for SVI):
Would you be able to help?
Huge thanks in advance!
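If the intent is one prior for every bias and another for every kernel, one option is to expand the short dict into full dotted names before passing it as prior. This is a stdlib-only sketch that assumes numpyro's "."-joined naming (e.g. "Dense_1.kernel"). expand_leaf_priors and the parameter tree are hypothetical, and the strings stand in for distribution objects such as dist.Cauchy() and dist.Normal(). Raising KeyError for an unknown leaf name is a deliberate choice, so a typo fails loudly instead of leaving a parameter without a prior.

```python
from __future__ import annotations

from collections.abc import Mapping
from typing import Any, Dict


def expand_leaf_priors(
    params: Mapping[str, Any], leaf_priors: Mapping[str, Any], prefix: str = ""
) -> Dict[str, Any]:
    """Map every dotted parameter name to the prior registered for its leaf name.

    Raises KeyError for a leaf name missing from leaf_priors, so a parameter
    can't silently end up without a prior.
    """
    expanded: Dict[str, Any] = {}
    for name, item in params.items():
        full_name = f"{prefix}.{name}" if prefix else name
        if isinstance(item, Mapping):
            # Sub-module: recurse and merge its expanded entries.
            expanded.update(expand_leaf_priors(item, leaf_priors, full_name))
        else:
            # Leaf: look up the prior by its short name ("bias", "kernel", ...).
            expanded[full_name] = leaf_priors[name]
    return expanded


# Hypothetical two-layer parameter tree; strings stand in for distributions.
params = {
    "Dense_0": {"kernel": (3, 8), "bias": (8,)},
    "Dense_1": {"kernel": (8, 1), "bias": (1,)},
}
prior = expand_leaf_priors(params, {"bias": "Cauchy()", "kernel": "Normal()"})
for key, value in prior.items():
    print(key, "->", value)
# Dense_0.kernel -> Normal()
# Dense_0.bias -> Cauchy()
# Dense_1.kernel -> Normal()
# Dense_1.bias -> Cauchy()
```

With a real parameter tree and real distribution objects, the expanded dict would be what gets passed as prior to random_flax_module, assuming the naming convention above holds for the installed numpyro version.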