jax-ml / bayeux

State of the art inference for your bayesian models.
https://jax-ml.github.io/bayeux/
Apache License 2.0
162 stars 6 forks source link

NumPyro models run in NumPyro but fail when using bayeux #51

Open theorashid opened 4 months ago

theorashid commented 4 months ago

Here are two examples of models that run in NumPyro but not in bayeux. The first example runs but does not produce the correct answer. The second example does not run at all and raises a shape error related to the number of chains.

Versions used:
numpyro==0.15.0
bayeux-ml==0.1.12
import jax.numpy as jnp
from jax import random

import arviz as az
import bayeux as bx
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC

# Example 1: synthetic data, N draws from Normal(true_alpha, true_sigma).
N = 100
true_alpha = 1.1
true_sigma = 0.1

# NOTE(review): this key and the `random.key(0)` passed to the samplers below
# are both seeded with 0; harmless for a reproduction, but worth knowing.
key = random.PRNGKey(0)
data = true_alpha + true_sigma * random.normal(key=key, shape=(N,))

def model():
    """Normal likelihood with unknown location `alpha` and scale `sigma`.

    Takes no arguments: `data` is looked up in module scope each time the
    model is traced, and is attached to the site "y" via `obs=`.
    """
    alpha = numpyro.sample("alpha", dist.Normal(0, 3))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1))
    # Observed site: "y" is conditioned on `data` (obs=...), so only `alpha`
    # and `sigma` are latent variables to be inferred.
    numpyro.sample("y", dist.Normal(alpha, sigma), obs=data)

# Reference run with NumPyro's own NUTS: reported to sample only `alpha` and
# `sigma` and to recover the true parameters.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=500)
mcmc.run(random.key(0))
mcmc.print_summary()

# Same model through bayeux: reported to give incorrect estimates and to
# appear to sample the observed site "y" as well.
# Only `seed` is passed, so warmup/sample counts are bayeux's defaults rather
# than the 500/500 used in the NumPyro run above.
# NOTE(review): presumably the cause is in how `bx.Model.from_numpyro` treats
# observed sites; the cause is not visible here, so confirm upstream.
bx_model = bx.Model.from_numpyro(model)
idata = bx_model.mcmc.numpyro_nuts(seed=random.key(0))

# Feature request: it would also be nice to write the NumPyro model as
# `def model(data=None)` and pass the data through bayeux, e.g.
# `bx.Model.from_numpyro(model, data=data)`.
import jax.numpy as jnp
from jax import random

import arviz as az
import bayeux as bx
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC

# Example 2: synthetic linear-regression data,
# data = true_alpha + true_beta * x + Normal(0, true_sigma) noise.
N = 100
true_alpha = 1.1
true_sigma = 0.1
true_beta = 0.8

key = random.PRNGKey(0)
x = jnp.linspace(0, 1, N)  # N evenly spaced covariate values in [0, 1]
data = true_alpha + true_sigma * random.normal(key=key, shape=(N,)) + true_beta * x

def model():
    """Linear regression: y ~ Normal(alpha + beta * x, sigma).

    Takes no arguments: the covariate `x` and the observations `data` are
    module-level globals read when the model is traced. This definition
    rebinds the name `model` from the first example.
    """
    alpha = numpyro.sample("alpha", dist.Normal(0, 3))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1))
    beta = numpyro.sample("beta", dist.Normal(0, 3))
    # Expects scalar `alpha`/`beta` broadcast against the length-N `x`.
    mu = alpha + beta * x
    numpyro.sample("y", dist.Normal(mu, sigma), obs=data)

# Reference run with NumPyro's own NUTS and two chains: reported to run and
# recover the true parameters.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=2)
mcmc.run(random.key(0))
mcmc.print_summary()

# Same model through bayeux with two chains: reported to fail with
# "mul got incompatible shapes for broadcasting: (2,), (100,)".
# NOTE(review): (2,) matches num_chains=2 and (100,) matches `x` (N=100), so
# `beta` presumably reaches `beta * x` batched over chains instead of as a
# scalar; TODO confirm in bayeux's multi-chain numpyro_nuts path.
bx_model = bx.Model.from_numpyro(model)
idata = bx_model.mcmc.numpyro_nuts(seed=random.key(0), num_chains=2)

# Fails with: "mul got incompatible shapes for broadcasting: (2,), (100,)."
# The error appears only when running multiple chains (num_chains=2).