Closed by Joshuaalbert 2 years ago.
The sampler hangs when dynamic or adaptive refinement is enabled. Only static sampling works, and only below a certain number of live points. Bisecting the problem, the hang seems to start at num_live_points >= 2126.
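The bisection described above can be sketched as a lower-bound binary search. This is a generic sketch, not jaxns code. `first_failing` and the `fails` predicate are hypothetical names. In practice, `fails(n)` might run the sampler with `n` live points under a timeout. The search assumes the failure is monotone in `n`: it is False below some threshold and True at and above it.

```python
def first_failing(lo, hi, fails):
    """Return the smallest n in [lo, hi] for which fails(n) is True.

    Assumes fails is monotone in n and that fails(hi) is True.
    """
    while lo < hi:
        mid = (lo + hi) // 2
        if fails(mid):
            hi = mid  # threshold is at mid or below
        else:
            lo = mid + 1  # threshold is strictly above mid
    return lo


# Toy stand-in predicate (hypothetical): "hangs" from 2126 live points up.
threshold = first_failing(1, 4096, lambda n: n >= 2126)
print(threshold)  # 2126
```

Each probe halves the interval, so about a dozen runs cover 1 to 4096. This matters when every run of the real predicate may have to wait out a timeout.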
This works with the dual moon example:
ns = NestedSampler(log_likelihood, prior_chain, samples_per_step=2125)
results = jit(ns)(random.PRNGKey(42), adaptive_evidence_stopping_threshold=None)
Enabling 64-bit precision by putting the following before any other imports resolves the problem. This hints that the hang is related to hitting the likelihood maximum bound, which squishes the samples onto a plateau:
import jax
jax.config.update("jax_enable_x64", True)  # enable float64 globally; set at startup, before other imports
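The plateau explanation can be illustrated with a small sketch that uses only the standard library (not jaxns). The log-likelihood values are hypothetical: live points packed within about 0.002 of a maximum of 100. `array` with typecode `'f'` stores them as 32-bit floats. Near 100, adjacent float32 values are about 7.6e-6 apart. Points that are distinct in float64 therefore collapse into ties, which forms a likelihood plateau.

```python
import array


def count_distinct(values, typecode):
    """Store values as 'f' (float32) or 'd' (float64); count distinct results."""
    return len(set(array.array(typecode, values)))


# Hypothetical live-point log-likelihoods that have contracted close to the
# maximum: each is 1e-6 below the previous one.
log_l_max = 100.0
live_log_l = [log_l_max - k * 1e-6 for k in range(2126)]

n64 = count_distinct(live_log_l, "d")
n32 = count_distinct(live_log_l, "f")
print(f"distinct values in float64: {n64}, in float32: {n32}")
```

In float64 all 2126 values stay distinct. In float32 only a few hundred survive, so many live points share exactly the same likelihood. If the likelihood-constrained sampling step stalls on such ties, that would be consistent with x64 fixing the hang. This is a hypothesis, not a confirmed diagnosis.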
In the dual moon example, taking more slices also makes it hang.