Two examples of models that run in NumPyro but not through bayeux. The first example runs but does not produce the correct answer. The second example does not run at all and fails with shape errors that depend on the number of chains.
numpyro==0.15.0
bayeux-ml==0.1.12
import jax.numpy as jnp
from jax import random
import arviz as az
import bayeux as bx
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC
# Simulated observations: N draws from Normal(true_alpha, true_sigma).
N = 100
true_alpha, true_sigma = 1.1, 0.1
key = random.PRNGKey(0)
data = true_sigma * random.normal(key, (N,)) + true_alpha
def model():
    """Intercept-only Gaussian model for the simulated observations.

    Latent sites: ``alpha`` (Normal(0, 3) prior) and ``sigma``
    (HalfNormal(1) prior). The site ``y`` is observed: it is conditioned on
    the module-level ``data`` array (looked up when the model is called)
    via ``obs=data``, so it is not a parameter to be sampled.
    """
    alpha = numpyro.sample("alpha", dist.Normal(0, 3))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1))
    # Observed site: inference should only ever sample alpha and sigma.
    numpyro.sample("y", dist.Normal(alpha, sigma), obs=data)
# Baseline: plain NumPyro NUTS. Only the latent sites (alpha, sigma) are
# sampled, and the posterior recovers the simulated parameters.
kernel = NUTS(model)
mcmc = MCMC(
    kernel,
    num_warmup=500,
    num_samples=500,
)
mcmc.run(random.key(0))
mcmc.print_summary()
# Same model sampled through bayeux. This runs, but the result is wrong:
# bayeux seems to also sample the observed site "y" instead of conditioning
# on it.
bx_model = bx.Model.from_numpyro(model)
idata = bx_model.mcmc.numpyro_nuts(seed=random.key(0))
# Feature request: it would be convenient to write the NumPyro model as
# `def model(data=None)` and pass the observations through bayeux, e.g.
# `bx.Model.from_numpyro(model, data=data)`.
import jax.numpy as jnp
from jax import random
import arviz as az
import bayeux as bx
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC
# Simulated regression data: data = true_alpha + noise + true_beta * x, with
# noise ~ Normal(0, true_sigma) and x an evenly spaced grid on [0, 1].
# PRNGKey(0) is reused, so the noise draws match the first example.
N = 100
true_alpha, true_sigma, true_beta = 1.1, 0.1, 0.8
key = random.PRNGKey(0)
x = jnp.linspace(0, 1, N)
data = (true_alpha + true_sigma * random.normal(key, (N,))) + true_beta * x
def model():
    """Linear-regression model: y ~ Normal(alpha + beta * x, sigma).

    Latent sites: ``alpha`` and ``beta`` (Normal(0, 3) priors) and ``sigma``
    (HalfNormal(1) prior). ``x`` and ``data`` are module-level arrays looked
    up when the model is called; ``y`` is observed via ``obs=data``.
    """
    alpha = numpyro.sample("alpha", dist.Normal(0, 3))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1))
    beta = numpyro.sample("beta", dist.Normal(0, 3))
    # Under plain NumPyro alpha and beta are scalar draws, so mu has the
    # shape of x.
    # NOTE(review): this product is presumably where the multi-chain bayeux
    # run fails -- a chain-batched beta of shape (num_chains,) cannot
    # broadcast against x of shape (N,). TODO confirm.
    mu = alpha + beta * x
    numpyro.sample("y", dist.Normal(mu, sigma), obs=data)
# Baseline: plain NumPyro NUTS with two chains. This runs and recovers
# alpha, beta and sigma.
kernel = NUTS(model)
mcmc = MCMC(
    kernel,
    num_warmup=500,
    num_samples=500,
    num_chains=2,
)
mcmc.run(random.key(0))
mcmc.print_summary()
# Same model and chain count through bayeux. This fails instead of sampling.
bx_model = bx.Model.from_numpyro(model)
idata = bx_model.mcmc.numpyro_nuts(seed=random.key(0), num_chains=2)
# Observed error:
#   mul got incompatible shapes for broadcasting: (2,), (100,).
# The (2,) matches num_chains=2 and the (100,) matches N = len(x), so the
# failing `mul` is presumably `beta * x` inside `model`, with the chain
# dimension reaching the model unbatched. TODO confirm against the traceback.
Two examples of models that run in NumPyro but not through bayeux. The first example runs but does not produce the correct answer. The second example does not run at all and fails with shape errors that depend on the number of chains.