Closed dirknbr closed 10 months ago
Hey! Thanks for opening this.
I think the problem here is that you are using the TensorFlow
backend of TFP rather than the JAX one. Changing
import tensorflow_probability as tfp
tfd = tfp.distributions
to
import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions
should make it work. It just so happens that TFP doesn't quite broadcast everything correctly in this case, so you'll also have to change the last line of your model to
yield tfd.Normal(mu[..., None], sigma, name='observed')
The sigma gets broadcast automatically, but you could also pad its dimensions there!
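In case it helps to see what mu[..., None] does to the shapes, here is a minimal sketch in plain NumPy rather than TFP. The shapes (4 batch members, 10 observations) and the names padded_mu, observed, and z are made up for illustration and are not taken from your model.

```python
import numpy as np

# Hypothetical shapes, chosen only for illustration.
mu = np.zeros(4)          # shape (4,): one location per batch member
sigma = np.float64(1.0)   # scalar scale, broadcasts against anything
observed = np.zeros(10)   # shape (10,): ten data points

# (4,) and (10,) do not broadcast against each other, so `observed - mu`
# would raise an error. Appending a trailing axis to mu gives (4, 1),
# which broadcasts against (10,) to (4, 10).
padded_mu = mu[..., None]
z = (observed - padded_mu) / sigma

print(padded_mu.shape)  # (4, 1)
print(z.shape)          # (4, 10)
```

The same trailing-axis trick is what the changed last line of the model relies on: each batch member's location is lined up against every observation.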
thank you, that works
I get the error below when I run this simple TFP NUTS instance
error