arviz-devs / arviz

Exploratory analysis of Bayesian models with Python
https://python.arviz.org
Apache License 2.0

`from_numpyro` doesn't work for contributed NestedSampler #2391

Open nstarman opened 1 month ago

nstarman commented 1 month ago

Describe the bug

ArviZ's `from_numpyro` works with many NumPyro inference objects, but not with `numpyro.contrib.nested_sampling`, the officially supported bridge to JAXNS. The conversion fails because `NestedSampler.get_samples` requires two positional arguments, `rng_key` and `num_samples`, which `from_numpyro` does not pass.

To Reproduce

>>> import arviz
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.contrib.nested_sampling import NestedSampler

>>> true_coefs = jnp.array([1., 2., 3.])
>>> data = random.normal(random.PRNGKey(0), (2000, 3))
>>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(1))

>>> def model(data, labels):
...     coefs = numpyro.sample('coefs', dist.Normal(0, 1).expand([3]))
...     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
...     return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)),
...                           obs=labels)

>>> ns = NestedSampler(model)
>>> ns.run(random.PRNGKey(2), data, labels)

>>> arviz.from_numpyro(ns)
TypeError: NestedSampler.get_samples() missing 2 required positional arguments: 'rng_key' and 'num_samples'

Expected behavior

ArviZ should dispatch on the type of the NumPyro sampler and pass `rng_key` and `num_samples` to `NestedSampler.get_samples`. One option is for `from_numpyro` to accept an extra keyword argument for specifying these values.
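
A rough sketch of what that dispatch could look like, written against stand-in classes so it runs without NumPyro or JAXNS installed. The helper `draw_samples`, its `sample_kwargs` parameter, and the two `Fake...` classes are hypothetical names used only for illustration. None of them exist in ArviZ or NumPyro, and this does not reflect how `from_numpyro` is implemented internally.

```python
import inspect


def draw_samples(sampler, sample_kwargs=None):
    """Call sampler.get_samples(), forwarding user-supplied keyword arguments.

    Hypothetical helper: it inspects the signature of get_samples and raises
    a clear error if a required argument was not supplied by the caller.
    """
    sample_kwargs = dict(sample_kwargs or {})
    params = inspect.signature(sampler.get_samples).parameters

    # Collect parameters that have no default and can be passed by keyword.
    required = [
        name
        for name, p in params.items()
        if p.default is inspect.Parameter.empty
        and p.kind in (p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY)
    ]
    missing = [name for name in required if name not in sample_kwargs]
    if missing:
        raise TypeError(
            "get_samples() needs values for: " + ", ".join(missing)
        )
    return sampler.get_samples(**sample_kwargs)


# Stand-ins that mimic only the two call signatures discussed in this issue.
class FakeMCMC:
    def get_samples(self, group_by_chain=False):
        return {"x": [0.0, 1.0]}


class FakeNestedSampler:
    def get_samples(self, rng_key, num_samples):
        return {"x": [0.0] * num_samples}


mcmc_samples = draw_samples(FakeMCMC())
ns_samples = draw_samples(
    FakeNestedSampler(), {"rng_key": 0, "num_samples": 3}
)
```

With something along these lines, samplers whose `get_samples` needs no arguments would keep working unchanged, while a nested sampler would need the caller to supply `rng_key` and `num_samples` explicitly instead of failing deep inside the converter.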

SamuelSonoiki commented 1 week ago

Hi, I would like to work on this issue.