leviyevalex closed this issue 1 year ago.
Hi @leviyevalex, I think this is from another package you are using.

I couldn't get your code to run. Looking at your example, I notice a few errors:

- `np.newaxis` should be `jnp.newaxis`, or simply `None`.
- `model` is defined elsewhere in your code, but it is also redefined at `model = Model(prior_model=prior_model, log_likelihood=log_likelihood)`. `heterodyne_minusLogLikelihood` is not a method of this `model`, so the likelihood definition wouldn't make sense. I'll just replace it with `-jnp.sum(x**2)`, since I don't know what your likelihood is.
- `model` is also used for the lower and upper bounds in `Prior(tfpd.Uniform(low=model.lower_bound, high=model.upper_bound), name='x')`, but that points to the wrong model class.
- `X` is used, but the argument is lowercase `x`.

If I fix these problems, it runs fine:
```python
from jax.config import config

config.update("jax_enable_x64", True)

import tensorflow_probability.substrates.jax as tfp
from jax import random, numpy as jnp

from jaxns import ExactNestedSampler
from jaxns import Model
from jaxns import PriorModelGen, Prior
from jaxns import TerminationCondition

tfpd = tfp.distributions


def log_likelihood(x):
    return -jnp.sum(x ** 2)


def prior_model() -> PriorModelGen:
    x = yield Prior(tfpd.Uniform(low=0., high=1.), name='x')
    return x


if __name__ == '__main__':
    model = Model(prior_model=prior_model,
                  log_likelihood=log_likelihood)

    exact_ns = ExactNestedSampler(model=model, num_live_points=200, num_parallel_samplers=1,
                                  max_samples=1e4)

    termination_reason, state = exact_ns(random.PRNGKey(42),
                                         term_cond=TerminationCondition(live_evidence_frac=1e-4))
    results = exact_ns.to_results(state, termination_reason)
    exact_ns.summary(results)
```
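On the first point: `jnp.newaxis`, like NumPy's `np.newaxis`, is just an alias for `None`. In the original snippet `np` is never imported (only `jnp` is), so indexing with `None` is the simplest fix. A minimal check using only `jax.numpy`, where the array values are arbitrary placeholders:

```python
import jax.numpy as jnp

x = jnp.array([0.1, 0.2, 0.3])  # a single parameter vector, shape (3,)

# jnp.newaxis is an alias for None, so these two indexings are identical:
# both add a leading axis of length 1.
X_newaxis = x[jnp.newaxis, ...]
X_none = x[None, ...]

print(X_newaxis.shape)  # (1, 3)
print(bool(jnp.array_equal(X_newaxis, X_none)))  # True
```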
Also, looking at the error you posted, it looks like the problem is in whatever code produces `heterodyne_minusLogLikelihood`.
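If your own likelihood is batched (it takes an array with a leading batch axis, as the `x[np.newaxis, ...]` in the original snippet suggests), one way to wire it in is to add the batch axis, call it, and take the single entry so the sampler gets a scalar, like `-jnp.sum(x ** 2)` in the example above. This is a sketch under assumptions: `batched_minus_log_likelihood` is a hypothetical stand-in for `heterodyne_minusLogLikelihood` (assumed to map shape `(n_batch, d)` to `(n_batch,)`), and keeping your own likelihood object under a name other than `model` avoids the clash with `model = Model(...)`.

```python
import jax.numpy as jnp


def batched_minus_log_likelihood(X: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical stand-in for a batched likelihood such as
    heterodyne_minusLogLikelihood: shape (n_batch, d) -> (n_batch,)."""
    return jnp.sum(X ** 2, axis=-1)


def log_likelihood(x: jnp.ndarray) -> jnp.ndarray:
    # Add a leading batch axis (None is equivalent to jnp.newaxis).
    X = x[None, ...]
    # Take the single batch entry so the result is a scalar, and flip the
    # sign to turn a minus-log-likelihood into a log-likelihood.
    return -batched_minus_log_likelihood(X)[0]


x = jnp.array([0.5, -1.0])
print(log_likelihood(x))  # -1.25
```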
Going to close this, as it looks like there was a problem with the OP's code and that it has been resolved. Feel free to open again if something's not clear.
**Describe the bug**
`AttributeError: module 'jax' has no attribute 'Array'`
**To Reproduce**

```python
from jax.config import config
config.update("jax_enable_x64", True)

import pylab as plt
import tensorflow_probability.substrates.jax as tfp
from jax import random, numpy as jnp
from jax import vmap

import jaxns
from jaxns import ExactNestedSampler
from jaxns import Model
from jaxns import PriorModelGen, Prior
from jaxns import TerminationCondition
from jaxns import analytic_log_evidence

tfpd = tfp.distributions

# %%
def log_likelihood(x):
    X = x[np.newaxis, ...]
    return -1 * model.heterodyne_minusLogLikelihood(X)

def prior_model() -> PriorModelGen:
    x = yield Prior(tfpd.Uniform(low=model.lower_bound, high=model.upper_bound), name='x')
    return x

model = Model(prior_model=prior_model, log_likelihood=log_likelihood)

ns = exact_ns = ExactNestedSampler(model=model, num_live_points=200, num_parallel_samplers=1, max_samples=1e4)

termination_reason, state = exact_ns(random.PRNGKey(42), term_cond=TerminationCondition(live_evidence_frac=1e-4))
results = exact_ns.to_results(state, termination_reason)
```
**Expected behaviour**
Crash occurs before sampler begins.
**Additional context**
Python 3.10.4, JAX 0.3.15
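The `AttributeError` means some code in the call path references `jax.Array`, which the installed JAX (0.3.15 here) does not provide. Consistent with the reply above, one plausible cause, stated here as an assumption, is another installed package written against a newer JAX. A minimal sketch for checking which JAX is actually being imported and whether it exposes `jax.Array` (the helper name `jax_array_status` is hypothetical):

```python
import importlib


def jax_array_status() -> str:
    """Report the imported JAX version and whether it exposes `jax.Array`."""
    try:
        jax = importlib.import_module("jax")
    except ImportError:
        return "JAX is not installed"
    # hasattr avoids triggering the same AttributeError as in the report.
    has_array = hasattr(jax, "Array")
    return f"JAX {jax.__version__}: jax.Array available = {has_array}"


if __name__ == "__main__":
    print(jax_array_status())
```

Running this in the same environment as the failing script shows whether the JAX it imports is the one you expect.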