jax-ml / bayeux

State-of-the-art inference for your Bayesian models.
https://jax-ml.github.io/bayeux/
Apache License 2.0

<class 'tensorflow.python.framework.ops.EagerTensor'> is not a valid JAX type #22

Closed: dirknbr closed 10 months ago

dirknbr commented 10 months ago

I get the error below when I run this simple TFP NUTS example:

import jax
import jax.numpy as jnp
import tensorflow_probability as tfp
import bayeux as bx

tfd = tfp.distributions

x = jnp.array([1.1, 2, 3])

@tfd.JointDistributionCoroutineAutoBatched
def model():
  mu = yield tfd.Normal(0, 1, name='mu')
  sigma = yield tfd.Gamma(1, 1, name='sigma')
  yield tfd.Normal(mu, sigma, name='observed')

bx_model = bx.Model.from_tfp(model.experimental_pin(observed=x))
idata = bx_model.mcmc.blackjax_nuts(seed=jax.random.key(0))

Error:

TypeError: Value <tf.Tensor: shape=(8, 3), dtype=float32, numpy=
array([[-... ]], dtype=float32)> with type <class 'tensorflow.python.framework.ops.EagerTensor'> is not a valid JAX type
ColCarroll commented 10 months ago

Hey! Thanks for opening this -- I think the problem here is that you are using the TensorFlow backend of TFP. Changing

import tensorflow_probability as tfp
tfd = tfp.distributions

to

import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions

should make it work. It just so happens that TFP doesn't quite broadcast everything correctly in this case, so you'll also have to change the last line of your model to

  yield tfd.Normal(mu[..., None], sigma, name='observed')

The sigma gets broadcast automatically, but you could also pad its dimensions there!
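For reference, putting both changes together, a complete script along these lines should work (a sketch, assuming bayeux, JAX, and the JAX substrate of TFP are installed):

import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp  # JAX substrate, not the TF backend
import bayeux as bx

tfd = tfp.distributions

x = jnp.array([1.1, 2, 3])

@tfd.JointDistributionCoroutineAutoBatched
def model():
  mu = yield tfd.Normal(0, 1, name='mu')
  sigma = yield tfd.Gamma(1, 1, name='sigma')
  # Pad mu so the observed Normal broadcasts against the 3-element data.
  yield tfd.Normal(mu[..., None], sigma, name='observed')

bx_model = bx.Model.from_tfp(model.experimental_pin(observed=x))
idata = bx_model.mcmc.blackjax_nuts(seed=jax.random.key(0))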

dirknbr commented 10 months ago

Thank you, that works.