jax-ml / bayeux

State-of-the-art inference for your Bayesian models.
https://jax-ml.github.io/bayeux/
Apache License 2.0

TensorFlow Probability MCMC and VI methods do not work with Bambi models #33

Closed. GStechschulte closed this issue 8 months ago.

GStechschulte commented 9 months ago

TensorFlow Probability (TFP) samplers fail when attempting to sample from a Bambi model:

import bambi as bmb
import bayeux as bx
import jax

data = bmb.load_data("ANES")
clinton_data = data.loc[data["vote"].isin(["clinton", "trump"]), :]

model = bmb.Model("vote['clinton'] ~ party_id + party_id:age", clinton_data, family="bernoulli")
model.build()

bx_model = bx.Model.from_pymc(model.backend.model)
bx_model.mcmc.tfp_hmc(seed=jax.random.key(0))

This raises:

TypeError: float() argument must be a string or a real number, not 'ShapedArray'

The same TypeError occurs with every TFP MCMC algorithm.

When I try the TFP VI method, the following error is raised:

bx_model.vi.tfp_factored_surrogate_posterior(seed=jax.random.key(0))
ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())

I haven't looked into why these errors happen yet; I just wanted to bring this to your attention. Since the Bambi backend model, model.backend.model, is a PyMC model, these errors may also occur with plain PyMC models.

ColCarroll commented 9 months ago

Oy, this is still the problem with new-style JAX keys from #29. It should work if you change the seed to jax.random.PRNGKey(0), or if you install tfp-nightly instead of tensorflow_probability.
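For readers hitting the same errors, a small JAX-only sketch (not code from bayeux or TFP) of the difference between the two key styles may help. It assumes a recent JAX release with the default threefry PRNG; the helper try_vmap is hypothetical and only imitates code that maps over axis 0 of a seed. A key from jax.random.key is a scalar with a special key dtype, while jax.random.PRNGKey returns a uint32 array of shape (2,). That is consistent with the rank-0 vmap error reported above, though this thread does not show the exact code path inside TFP.

```python
import jax
import jax.numpy as jnp

# Legacy ("raw") key: a uint32 array of shape (2,) under the default threefry PRNG.
legacy_key = jax.random.PRNGKey(0)

# New-style ("typed") key: a scalar array whose dtype is a special key dtype.
typed_key = jax.random.key(0)


def try_vmap(seed):
    """Hypothetical helper: vmap an identity function over axis 0 of `seed`.

    Returns None on success, or the ValueError message if the seed has rank 0.
    """
    try:
        jax.vmap(lambda k: k)(seed)
        return None
    except ValueError as err:
        return str(err)


print("legacy:", legacy_key.shape, legacy_key.dtype)
print("typed: ", typed_key.shape, typed_key.dtype)

# Mapping over axis 0 works for the (2,)-shaped legacy key...
print("vmap over legacy key:", try_vmap(legacy_key))
# ...but fails for the rank-0 typed key, similar to the VI error in this issue.
print("vmap over typed key: ", try_vmap(typed_key))

# jax.random.key_data recovers the raw uint32 data from a typed key;
# under the default PRNG it matches PRNGKey(0) exactly.
raw_from_typed = jax.random.key_data(typed_key)
print("key_data matches PRNGKey(0):", bool(jnp.array_equal(raw_from_typed, legacy_key)))
```

If you already hold a typed key, jax.random.key_data converts it to the legacy form. Whether bayeux's TFP wrappers accept that form in place of PRNGKey(0) is an assumption based on the two arrays being equal here, not something verified in this thread.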

GStechschulte commented 9 months ago

Ahhhhh. Should've looked in the closed issues. Many thanks!! It works now.