Joshuaalbert / jaxns

Probabilistic Programming and Nested sampling in JAX
https://jaxns.readthedocs.io/
Apache License 2.0

Non-terminating on some challenging likelihoods #52

Closed: Joshuaalbert closed this issue 2 years ago

Joshuaalbert commented 2 years ago

In the dual moon example, the sampler hangs if more slices are taken.

Joshuaalbert commented 2 years ago

It hangs when dynamic or adaptive refinement is enabled. Only static sampling works, and only below a certain number of live points. Bisecting the problem, it starts to hang at num_live_points >= 2126.
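A threshold like this can be located with a plain lower-bound bisection. The sketch below is a minimal illustration, assuming failure is monotone in the number of live points (fine below the threshold, hanging at and above it). The predicate `fails` is hypothetical: in practice each probe would be a sampler run guarded by a wall-clock timeout, which is an assumption and not something jaxns provides here.

```python
def first_failing(lo, hi, fails):
    """Return the smallest n in [lo, hi] for which fails(n) is True.

    Assumes fails is monotone (False below some threshold, True from it
    upward) and that fails(hi) is True.
    """
    while lo < hi:
        mid = (lo + hi) // 2
        if fails(mid):
            # The threshold is at mid or below it.
            hi = mid
        else:
            # mid still works, so the threshold is strictly above it.
            lo = mid + 1
    return lo


# Hypothetical stand-in for "the run hangs with this many live points".
print(first_failing(1, 10_000, lambda n: n >= 2126))
```

Each probe halves the search interval, so bracketing a threshold in [1, 10000] takes about 14 runs.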

This works with the dual moon example:

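from jax import jit, random

# NestedSampler, log_likelihood and prior_chain are as in the dual moon example.
ns = NestedSampler(log_likelihood, prior_chain, samples_per_step=2125)
results = jit(ns)(random.PRNGKey(42), adaptive_evidence_stopping_threshold=None)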
Joshuaalbert commented 2 years ago

Enabling 64-bit precision by putting the following before any other imports resolves the problem. This hints that the hang is related to reaching the likelihood maximum at single precision, which squashes the samples onto a plateau.

import jax

# Enable double precision; this must run before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)
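The plateau hypothesis can be illustrated outside jaxns. The sketch below is a minimal NumPy illustration, assuming that live points late in a run sit very close to the likelihood maximum. NumPy's float32 stands in for JAX's default single precision. The log-likelihood, its constant offset, the live-point spread and the helper `plateau_stats` are all hypothetical choices for illustration, not the dual moon likelihood or jaxns internals.

```python
import numpy as np


# Hypothetical log-likelihood: a narrow peak on top of a large constant offset.
# The offset (-1000) is an illustrative value, not taken from jaxns or dual moon.
def log_likelihood(x, dtype):
    x = np.asarray(x, dtype=dtype)
    return dtype(-1000.0) - dtype(0.5) * x * x


def plateau_stats(dtype, n_live=1000, width=1e-3, seed=0):
    """Return (distinct log-L values, live points strictly above the lowest log-L)."""
    rng = np.random.default_rng(seed)
    # Live points packed close to the maximum at x = 0, as late in a run.
    live_x = rng.uniform(-width, width, size=n_live)
    log_l = log_likelihood(live_x, dtype)
    # Nested sampling discards the lowest-likelihood point and needs a
    # replacement satisfying the hard constraint log L > log L_min.
    log_l_min = log_l.min()
    n_distinct = int(np.unique(log_l).size)
    n_above = int(np.sum(log_l > log_l_min))
    return n_distinct, n_above


for dtype in (np.float32, np.float64):
    n_distinct, n_above = plateau_stats(dtype)
    print(f"{dtype.__name__}: {n_distinct} distinct log-L values, "
          f"{n_above} points above log L_min")
```

In float32 the peak's variation (at most 5e-7) is far below the spacing of representable values near 1000 (about 6e-5), so every live point gets exactly the same log-likelihood and none lies strictly above the minimum. A sampler that keeps proposing until it satisfies the strict constraint would then make no progress. In float64 the values remain distinguishable. This only illustrates the mechanism suggested in the comment; it does not show that this is exactly what happens inside jaxns.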