Open bwengals opened 9 months ago
Thanks, @bwengals!
There is no concept of `Deterministic` here, which is sort of funny and unsatisfying -- bayeux just takes initial parameters and a differentiable log density.
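For reference, the core API is roughly this (a minimal sketch along the lines of the bayeux README -- note there is nowhere to hang a named deterministic):

```python
import bayeux as bx
import jax

# A bayeux model is just a log density plus a test point.
normal_density = bx.Model(
    log_density=lambda x: -x * x,
    test_point=1.0,
)
idata = normal_density.mcmc.blackjax_nuts(seed=jax.random.key(0))
```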
I'm not sure of a pleasant way of storing deterministics, either! If you have a function from a parameter to your deterministics, it is probably best for the user to apply that themselves, as in the sketch below.
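For instance (a hedged sketch -- `idata` is assumed to hold posterior draws for placeholder variables `alpha` and `beta`, with `x` a predictor; none of these names come from bayeux):

```python
import numpy as np
import xarray as xr

x = np.linspace(0, 1, 100)

# Apply the deterministic function to the posterior draws yourself;
# xarray broadcasts over the (chain, draw) dimensions automatically.
idata.posterior["mu"] = (
    idata.posterior["alpha"] + idata.posterior["beta"] * xr.DataArray(x, dims="n")
)
```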
I think these lines do the deterministics in PyMC. I'd be happy to include a snippet in the documentation -- either as a standalone page, or with the bayeux and pymc demo.
LMK what you think!
Ah of course,

> bayeux just takes initial parameters and a differentiable log density.

so that makes perfect sense that there's no mechanism to store deterministics. Maybe a snippet in the pymc demo is the easiest? Whatever you think makes sense! Thanks for pointing me to that spot in the code, I'll try it out on my end too.
It would be great to have this example/feature in the documentation 🙏
A possibility (which is not very elegant) is to sample the deterministic via `pm.sample_posterior_predictive`, as described in Out of model predictions with PyMC:
```python
import arviz as az
import bayeux as bx
import jax
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm

# Simulate data from a linear model.
x = np.linspace(0, 1, 100)
y = 1 + 2 * x + np.random.normal(0, 0.1, 100)

with pm.Model(coords={"n": range(x.size)}) as model:
    alpha = pm.Normal("alpha", 0, 1)
    beta = pm.Normal("beta", 0, 1)
    sigma = pm.HalfNormal("sigma", 1)
    mu = pm.Deterministic("mu", alpha + beta * x, dims=("n",))
    pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, dims=("n",))

# Sample with bayeux; the resulting idata has no "mu" in the posterior.
bx_model = bx.Model.from_pymc(model)
idata = bx_model.mcmc.blackjax_nuts(seed=jax.random.key(0))

# Recover the deterministic (and the posterior predictive) with PyMC.
with model:
    idata.extend(pm.sample_posterior_predictive(idata, var_names=["mu", "y_obs"]))

fig, ax = plt.subplots()
ax.plot(x, y, "o", label="data")
az.plot_hdi(
    x,
    idata.posterior_predictive["y_obs"],
    fill_kwargs={"alpha": 0.5, "label": "y_obs"},
    ax=ax,
)
az.plot_hdi(
    x,
    idata.posterior_predictive["mu"],
    fill_kwargs={"alpha": 0.8, "label": "mu"},
    ax=ax,
)
ax.legend()
ax.set(title="linear model", xlabel="x", ylabel="y")
```
Note that `idata.posterior_predictive` retains the model coordinates (the `n` dimension).
I have not managed to do it via the functions in https://github.com/pymc-devs/pymc/blob/main/pymc/sampling/jax.py.
That's interesting -- I'd like to continue with the stateless API, but I could imagine a workflow like:

```python
bx_model = bx.Model.from_pymc(model)
pytree = bx_model.mcmc.blackjax_nuts(seed=jax.random.key(0), return_pytree=True)
idata = bx.postprocess_pymc(pytree, model)
```
where `bx.postprocess_pymc` starts from https://github.com/pymc-devs/pymc/blob/main/pymc/sampling/jax.py#L581.
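A rough sketch of what such a helper could look like, loosely adapted from that PyMC module (`postprocess_pymc` itself is hypothetical; `get_jaxified_graph`, `get_default_varnames`, and the model properties are PyMC's):

```python
import jax
from pymc.sampling.jax import get_jaxified_graph
from pymc.util import get_default_varnames

def postprocess_pymc(raw_mcmc_samples, model):
    """Hypothetical: push raw (chain, draw)-shaped free-variable samples
    through the model graph to recover untransformed RVs and Deterministics."""
    vars_to_sample = list(
        get_default_varnames(model.unobserved_value_vars, include_transformed=False)
    )
    # Compile a JAX function from the free variables to everything we
    # want in the trace, then vmap over the chain and draw dimensions.
    jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
    result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)
    return dict(zip([v.name for v in vars_to_sample], result))
```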
what do you think?
Yes! That would be fantastic! I tried this path but realized I needed `raw_mcmc_samples`, so I thought it would be better handled by bayeux internally 👍
> bayeux just takes initial parameters and a differentiable log density.
Hi, I'm trying to implement a quasi-stateful RNG to be used in an ELBO, as in https://jax.readthedocs.io/en/latest/notebooks/vmapped_log_probs.html. Any ideas how to accomplish this in bayeux?
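For what it's worth, one stateless JAX pattern (a general sketch, not an existing bayeux feature; `make_stochastic_log_density` is hypothetical) is to derive a fresh subkey per evaluation with `jax.random.fold_in`:

```python
import jax
import jax.numpy as jnp

def make_stochastic_log_density(base_key, n_draws=8):
    # Quasi-stateful RNG: fold a step counter into a fixed base key so each
    # call gets fresh, reproducible noise while the function stays pure.
    def log_density(params, step):
        key = jax.random.fold_in(base_key, step)
        eps = jax.random.normal(key, (n_draws,))
        # Toy Monte Carlo objective standing in for an ELBO estimate.
        return -jnp.mean((params - eps) ** 2) / 2.0
    return log_density
```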
Hey, have been playing around with this a bit from PyMC, so glad this exists now! Unfortunately I'm not getting `Deterministic`s recorded in `idata.posterior`. Happy to attempt a PR if you point me where to start?