Joshuaalbert / jaxns

Probabilistic Programming and Nested sampling in JAX
https://jaxns.readthedocs.io/
Other

Can't get sampler to work! #82

Closed leviyevalex closed 1 year ago

leviyevalex commented 1 year ago

Describe the bug
`AttributeError: module 'jax' has no attribute 'Array'`

To Reproduce

```python
from jax.config import config

config.update("jax_enable_x64", True)

import pylab as plt
import tensorflow_probability.substrates.jax as tfp
from jax import random, numpy as jnp
from jax import vmap

import jaxns
from jaxns import ExactNestedSampler
from jaxns import Model
from jaxns import PriorModelGen, Prior
from jaxns import TerminationCondition
from jaxns import analytic_log_evidence

tfpd = tfp.distributions

# %%

def log_likelihood(x):
    X = x[np.newaxis, ...]
    return -1 * model.heterodyne_minusLogLikelihood(X)

def prior_model() -> PriorModelGen:
    x = yield Prior(tfpd.Uniform(low=model.lower_bound, high=model.upper_bound), name='x')
    return x

model = Model(prior_model=prior_model, log_likelihood=log_likelihood)

ns = exact_ns = ExactNestedSampler(model=model, num_live_points=200, num_parallel_samplers=1, max_samples=1e4)

termination_reason, state = exact_ns(random.PRNGKey(42), term_cond=TerminationCondition(live_evidence_frac=1e-4))
results = exact_ns.to_results(state, termination_reason)
```

Expected behaviour
The crash occurs before the sampler begins.

Screenshots
[Screenshot of the error; image not preserved]

Additional context
Python 3.10.4, JAX 0.3.15
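Since the error says that the installed `jax` module has no `Array` attribute, a reasonable first check is which versions of JAX and the packages built on it are actually installed. The sketch below uses only the standard library; the helper names (`installed_version`, `jax_exposes_array`, `environment_report`) are hypothetical and not part of jaxns or JAX.

```python
import importlib.metadata
import importlib.util


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def jax_exposes_array():
    """True/False if JAX is importable and does/doesn't define `jax.Array`; None if JAX is absent."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax  # imported lazily so the report still works without JAX
    return hasattr(jax, "Array")


def environment_report():
    # PyPI distribution names; tensorflow-probability is listed because the
    # reproduction imports its JAX substrate.
    names = ["jax", "jaxlib", "jaxns", "tensorflow-probability"]
    report = {name: installed_version(name) for name in names}
    report["jax.Array available"] = jax_exposes_array()
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

If `jax.Array available` comes out `False` alongside JAX 0.3.15, that would be consistent with the reported error and would suggest that some other installed package expects a newer JAX than the one installed.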

Joshuaalbert commented 1 year ago

Hi @leviyevalex, I think this error comes from another package you are using.

I couldn't get your code to run. Looking at your example, I notice a few errors:

If I fix these problems, it runs fine:

```python
from jax.config import config

config.update("jax_enable_x64", True)

import tensorflow_probability.substrates.jax as tfp
from jax import random, numpy as jnp

from jaxns import ExactNestedSampler
from jaxns import Model
from jaxns import PriorModelGen, Prior
from jaxns import TerminationCondition

tfpd = tfp.distributions

def log_likelihood(x):
    return -jnp.sum(x ** 2)

def prior_model() -> PriorModelGen:
    x = yield Prior(tfpd.Uniform(low=0., high=1.), name='x')
    return x

if __name__ == '__main__':
    model = Model(prior_model=prior_model,
                  log_likelihood=log_likelihood)

    exact_ns = ExactNestedSampler(model=model, num_live_points=200, num_parallel_samplers=1,
                                  max_samples=1e4)

    termination_reason, state = exact_ns(random.PRNGKey(42),
                                         term_cond=TerminationCondition(live_evidence_frac=1e-4))
    results = exact_ns.to_results(state, termination_reason)
    exact_ns.summary(results)
```
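The toy problem in the fixed example has a closed-form evidence, which gives a quick sanity check on the sampler's output. With `log_likelihood(x) = -x**2` and a Uniform(0, 1) prior (density 1 on [0, 1]), the evidence is `Z = integral from 0 to 1 of exp(-x**2) dx = (sqrt(pi) / 2) * erf(1)`, roughly 0.7468, so `log Z` is roughly -0.2919. The sketch below computes this with the standard library and cross-checks it with a trapezoid rule; the function names are hypothetical and are not jaxns functions.

```python
import math


def toy_log_evidence():
    """log Z for L(x) = exp(-x**2) under a Uniform(0, 1) prior.

    Z = integral_0^1 exp(-x^2) dx = (sqrt(pi) / 2) * erf(1).
    """
    return math.log(0.5 * math.sqrt(math.pi) * math.erf(1.0))


def trapezoid_log_evidence(n=10_000):
    """Cross-check Z with the composite trapezoid rule on [0, 1] using n intervals."""
    h = 1.0 / n
    total = 0.5 * (math.exp(0.0) + math.exp(-1.0))  # endpoint terms, weight 1/2
    for i in range(1, n):
        x = i * h
        total += math.exp(-x * x)  # interior points, weight 1
    return math.log(total * h)


if __name__ == "__main__":
    print(toy_log_evidence())
    print(trapezoid_log_evidence())
```

If the sampler is working, the log-evidence printed by `exact_ns.summary(results)` should be close to this value, up to sampling error.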

Also, judging from the error you posted, the problem appears to be in whatever code provides `heterodyne_minusLogLikelihood`.
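Independently of the `jax.Array` error, two things are visible in the reproduction code itself: `np` is used but never imported, and the name `model` refers first to the reporter's own object (which provides `heterodyne_minusLogLikelihood`, `lower_bound` and `upper_bound`) and is then reassigned to the jaxns `Model`, so any later call to `log_likelihood` would look up `heterodyne_minusLogLikelihood` on the jaxns object instead. The sketch below shows one way to adapt a batched minus-log-likelihood to the scalar `log_likelihood(x)` that `Model` receives. `HypotheticalPhysicsModel`, its toy quadratic likelihood, and `make_log_likelihood` are assumptions for illustration; the real `heterodyne_minusLogLikelihood` is not shown in the issue, and the sketch assumes it maps a `(batch, dim)` array to a `(batch,)` array and that `x` is a 1-D vector.

```python
import jax.numpy as jnp


class HypotheticalPhysicsModel:
    """Hypothetical stand-in for the reporter's own model object.

    Assumption: the method takes parameters with shape (batch, dim) and
    returns the *negative* log-likelihood of each row, shape (batch,).
    """

    def heterodyne_minusLogLikelihood(self, X):
        # Toy quadratic stand-in; the real method is not shown in the issue.
        return jnp.sum(X ** 2, axis=-1)


def make_log_likelihood(physics_model):
    """Wrap a batched minus-log-likelihood as a scalar log-likelihood."""

    def log_likelihood(x):
        X = x[jnp.newaxis, ...]  # add a batch axis: (dim,) -> (1, dim)
        # Negate, then take the single batch entry so a scalar is returned.
        return -physics_model.heterodyne_minusLogLikelihood(X)[0]

    return log_likelihood


# A distinct name keeps the physics object from being shadowed by jaxns.Model.
physics_model = HypotheticalPhysicsModel()
log_likelihood = make_log_likelihood(physics_model)
```

With a distinct name for the reporter's object, the prior bounds would likewise be read from `physics_model.lower_bound` and `physics_model.upper_bound`, and `log_likelihood` would be passed to `Model(prior_model=prior_model, log_likelihood=log_likelihood)` as in the fixed example.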

Joshuaalbert commented 1 year ago

Going to close this, as it looks like the problem was in the OP's code and it has been resolved. Feel free to reopen if anything is unclear.