jax-ml / bayeux

State-of-the-art inference for your Bayesian models.
https://jax-ml.github.io/bayeux/
Apache License 2.0

Save PyMC `Deterministic`s to idata #21

Open · bwengals opened 9 months ago

bwengals commented 9 months ago

Hey, I've been playing around with this a bit from PyMC, so glad this exists now! Unfortunately I'm not getting `Deterministic`s recorded in `idata.posterior`. Happy to attempt a PR if you point me to where to start?

ColCarroll commented 9 months ago

Thanks, @bwengals!

There is no concept of Deterministic here, which is sort of funny and unsatisfying -- bayeux just takes initial parameters and a differentiable log density.

I'm not sure of a pleasant way of storing deterministics, either! If you have a function from the parameters to your deterministics, it is probably better for the user to apply it themselves.
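For instance, a user-side version could look like the following minimal sketch (the `deterministics` function and the stand-in draws are made up for illustration; the nested `vmap` maps over chains and then draws):

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(0, 1, 100)

def deterministics(params):
    # map one draw of the free parameters to the derived quantity
    return {"mu": params["alpha"] + params["beta"] * x}

# stand-in for posterior draws, one (chains, draws) array per parameter,
# roughly the pytree shape you get back with return_pytree=True
posterior = {
    "alpha": jnp.ones((4, 500)),
    "beta": 2 * jnp.ones((4, 500)),
}

mu = jax.vmap(jax.vmap(deterministics))(posterior)["mu"]
print(mu.shape)  # (4, 500, 100)
```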

I think these lines do the deterministics in PyMC. I'd be happy to include a snippet in the documentation -- either as a standalone page, or with the bayeux and pymc demo.

LMK what you think!

bwengals commented 9 months ago

Ah, of course:

> bayeux just takes initial parameters and a differentiable log density.

So it makes perfect sense that there's no mechanism to store deterministics. Maybe a snippet in the pymc demo is the easiest? Whatever you think makes sense! Thanks for pointing me to that spot in the code, I'll try it out on my end too.

juanitorduz commented 9 months ago

It would be great to have this example/feature in the documentation 🙏

juanitorduz commented 7 months ago

A possibility (which is not very elegant) is to sample the deterministics via `pm.sample_posterior_predictive`, as described in "Out of model predictions with PyMC":

```python
import arviz as az
import bayeux as bx
import jax
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm

# simulate data from a linear model: y = 1 + 2x + noise
x = np.linspace(0, 1, 100)
y = 1 + 2 * x + np.random.normal(0, 0.1, 100)

with pm.Model(coords={"n": range(x.size)}) as model:
    alpha = pm.Normal("alpha", 0, 1)
    beta = pm.Normal("beta", 0, 1)
    sigma = pm.HalfNormal("sigma", 1)
    mu = pm.Deterministic("mu", alpha + beta * x, dims=("n",))
    pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, dims=("n",))

# sample with bayeux; "mu" does not end up in idata.posterior
bx_model = bx.Model.from_pymc(model)
idata = bx_model.mcmc.blackjax_nuts(seed=jax.random.key(0))

# recover the deterministic by sampling it as a posterior predictive variable
with model:
    idata.extend(pm.sample_posterior_predictive(idata, var_names=["mu", "y_obs"]))

fig, ax = plt.subplots()
ax.plot(x, y, "o", label="data")
az.plot_hdi(
    x,
    idata.posterior_predictive["y_obs"],
    fill_kwargs={"alpha": 0.5, "label": "y_obs"},
    ax=ax,
)
az.plot_hdi(
    x,
    idata.posterior_predictive["mu"],
    fill_kwargs={"alpha": 0.8, "label": "mu"},
    ax=ax,
)
ax.legend()
ax.set(title="linear model", xlabel="x", ylabel="y")
```

[Figure: data scatter with shaded HDI bands for `mu` and `y_obs`]

Note that `idata.posterior_predictive` carries the model coordinates (here, `chain`, `draw`, and `n`).


I have not managed to do it via the functions in https://github.com/pymc-devs/pymc/blob/main/pymc/sampling/jax.py.

ColCarroll commented 7 months ago

That's interesting -- I'd like to continue with the stateless API, but I could imagine a workflow like

```python
bx_model = bx.Model.from_pymc(model)
pytree = bx_model.mcmc.blackjax_nuts(seed=jax.random.key(0), return_pytree=True)

idata = bx.postprocess_pymc(pytree, model)
```

where `bx.postprocess_pymc` starts from https://github.com/pymc-devs/pymc/blob/main/pymc/sampling/jax.py#L581.

what do you think?
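For what it's worth, here is a rough sketch of what such a helper might do, leaning on PyMC's `get_jaxified_graph`; the `postprocess_pymc` name, the pytree keying, and the sample layout are all assumptions here, not an existing API:

```python
import jax
from pymc.sampling.jax import get_jaxified_graph

def postprocess_pymc(pytree, model):
    """Hypothetical helper, not an existing bayeux function.

    Assumes `pytree` maps each free variable's value-var name to a
    (chains, draws, ...) array; a real implementation would have to
    match bayeux's actual sample layout and transforms.
    """
    # compile the graph from value variables to all unobserved quantities,
    # deterministics included
    jax_fn = get_jaxified_graph(
        inputs=model.value_vars, outputs=model.unobserved_value_vars
    )
    draws = [pytree[v.name] for v in model.value_vars]
    # vmap over chains, then over draws, as pymc's own jax samplers do
    outputs = jax.vmap(jax.vmap(jax_fn))(*draws)
    names = [v.name for v in model.unobserved_value_vars]
    return dict(zip(names, outputs))
```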

juanitorduz commented 7 months ago

Yes! That would be fantastic! I tried this path but realized I needed `raw_mcmc_samples`, so I thought it would be better handled by bayeux internally 👍

sherna90 commented 2 months ago

> bayeux just takes initial parameters and a differentiable log density.

Hi, I'm trying to implement a quasi-stateful RNG to be used in an ELBO, as in https://jax.readthedocs.io/en/latest/notebooks/vmapped_log_probs.html. Any ideas on how to accomplish this in bayeux?
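Not a bayeux-specific answer, but one way to keep such an estimator pure, sketched in the style of that notebook (the `elbo` function, the toy `log_joint`, and the `fold_in` keying are all illustrative assumptions), is to derive per-evaluation keys from a single base key:

```python
import jax
import jax.numpy as jnp

# toy log joint: a standard normal over the latent z (stand-in for a model)
def log_joint(z):
    return -0.5 * jnp.sum(z**2)

def elbo(key, mean, log_std, num_samples=32):
    # reparameterized draws from the variational Gaussian q(z)
    eps = jax.random.normal(key, (num_samples,) + mean.shape)
    z = mean + jnp.exp(log_std) * eps
    # Monte Carlo estimate of E_q[log p(z)] plus the entropy of q
    entropy = jnp.sum(log_std + 0.5 * jnp.log(2 * jnp.pi * jnp.e))
    return jnp.mean(jax.vmap(log_joint)(z)) + entropy

# "quasi-stateful" keying: fold the step index into one base key, so every
# evaluation gets fresh but reproducible randomness while staying pure
base_key = jax.random.key(0)
for step in range(3):
    print(step, elbo(jax.random.fold_in(base_key, step), jnp.zeros(2), jnp.zeros(2)))
```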