joshspeagle / dynesty

Dynamic Nested Sampling package for computing Bayesian posteriors and evidences
https://dynesty.readthedocs.io/
MIT License

Samples can be lost if interrupted during add_live_points #490

Open ColmTalbot opened 1 day ago

ColmTalbot commented 1 day ago

Dynesty version: initially noticed in 2.1.4 (installed via conda) and verified on master.

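Describe the bug: When removing added live points, the last self.nlive points are removed. However, if the add_live_points method is interrupted early for some reason, fewer than self.nlive points have been added, so removing the last self.nlive entries also deletes genuine samples from the run.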

Setup: The application where this causes an issue is jobs being interrupted on a cluster scheduled with HTCondor, so reproducibility is difficult. I've verified that I can trigger the issue by manually changing add_live_points to yield only a subset of the current live points.

In case it is relevant, I'm not using the built-in dynesty checkpointing, but a different system that predates it.

Dynesty output: The main thing users will see is the following warning, triggered when attempting to make a runplot afterwards:

/Users/colmtalbot/mambaforge/envs/bilby-clean/lib/python3.11/site-packages/dynesty/plotting.py:243: UserWarning: The number of iterations and samples differ by an amount that isn't the number of final live points. `mark_final_live` has been disabled.
  warnings.warn("The number of iterations and samples differ "

Proposed solution: Since add_live_points doesn't increment self.it, I think it should be safe to change the line in _remove_live_points (linked in the original issue) to:

-                    del self.saved_run[k][-self.nlive:]
+                    del self.saved_run[k][self.it:]

I've tested this and it works fine (the runplot still complains). It seems to work for the dynamic sampler, but I'm less familiar with that, so I'm not sure if there are additional subtleties.
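The failure mode and the proposed fix can be illustrated with plain Python lists. This is a minimal sketch, not dynesty code: `saved`, `n_dead` and the helper functions are hypothetical stand-ins for `self.saved_run[k]`, the number of dead points already saved, and the removal logic. It assumes the counter used in the slice equals the number of dead points in the record; whether dynesty's `self.it` matches that count exactly, or is offset by one, would need checking against the sampler source.

```python
# Toy model of the bookkeeping discussed in this issue (hypothetical names,
# not dynesty's implementation).

def add_live_points(saved, live, n_done=None):
    """Append the final live points to `saved`.

    `n_done` emulates an interruption: only the first `n_done` points are
    appended (None means the loop runs to completion).
    """
    n = len(live) if n_done is None else n_done
    for point in live[:n]:
        saved.append(point)


def remove_last_nlive(saved, nlive):
    # Current behaviour: assumes exactly `nlive` points were appended.
    del saved[-nlive:]


def remove_from_counter(saved, n_dead):
    # Proposed behaviour: truncate back to the number of dead points, which
    # does not change while the live points are being appended.
    del saved[n_dead:]


nlive = 5
dead = [f"dead{i}" for i in range(8)]
live = [f"live{i}" for i in range(nlive)]

# Interrupted after only 2 of the 5 live points were appended.
run_a = list(dead)
add_live_points(run_a, live, n_done=2)
remove_last_nlive(run_a, nlive)
print(len(run_a))  # 5: three genuine dead points were deleted

run_b = list(dead)
add_live_points(run_b, live, n_done=2)
remove_from_counter(run_b, n_dead=len(dead))
print(len(run_b))  # 8: only the appended live points were removed
```

When the add loop completes, both removal strategies give the same result; they only differ after a partial add.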

Happy to open a PR if you're happy with this solution @segasai

segasai commented 1 day ago

Thanks for the report! Since you are talking about interruptions, are you relying on dynesty's checkpointing, then? Regarding the PR, I don't yet fully understand the problem; I'll need to think a bit about it.

Another problem is that I have generally avoided touching the plotting code in the past (because it's not that well tested), so I'm less familiar with it.

ColmTalbot commented 1 day ago

This is being run through Bilby, and we have our own checkpointing method; it basically comes down to pickling the dynesty sampler object.

segasai commented 1 day ago

Okay. Then it's a bit of a problem. In the dynesty checkpointing that I implemented, I specifically do not checkpoint during add_live_points(), because it's a pain to preserve the state correctly there. My first suggestion would be to do the same (and since add_live_points takes a tiny fraction of the run time, that's not an issue). I'll need to think more about your suggested change, but I'd be reluctant to make it unless it's clear that it's a dynesty bug rather than an issue with the checkpointing scheme.
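One way an external checkpointing scheme (such as pickling the sampler from Bilby) could follow the suggestion of never checkpointing inside add_live_points is to defer any checkpoint request that arrives there. This is a minimal sketch, assuming the caller wraps the code that consumes add_live_points in a critical section; `DeferredCheckpointer`, `critical` and `request` are hypothetical names, not part of dynesty.

```python
from contextlib import contextmanager


class DeferredCheckpointer:
    """Hypothetical helper: postpone checkpoints requested while the sampler
    is inside a section (e.g. adding the final live points) where its state
    is not consistent."""

    def __init__(self, save):
        self._save = save        # callable that actually writes the checkpoint
        self._depth = 0          # > 0 while inside a critical section
        self._pending = False    # a request arrived during a critical section

    @contextmanager
    def critical(self):
        self._depth += 1
        try:
            yield
        finally:
            self._depth -= 1
        # Only reached if the section finished without an exception, i.e. the
        # state is consistent again, so honour any deferred request now.
        if self._depth == 0 and self._pending:
            self._pending = False
            self._save()

    def request(self):
        """Checkpoint immediately if safe, otherwise when the section ends."""
        if self._depth > 0:
            self._pending = True
        else:
            self._save()


saves = []
ckpt = DeferredCheckpointer(lambda: saves.append("checkpoint"))
with ckpt.critical():
    ckpt.request()   # e.g. a signal arrives while live points are being added
    print(saves)     # []: the request is deferred
print(saves)         # ['checkpoint']
```

If the critical section is itself aborted by an exception, the deferred save is skipped, so the previous (consistent) checkpoint is what remains on disk.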

ColmTalbot commented 1 day ago

The main reason I haven't switched to the new dynesty checkpoint system is that it doesn't allow on-demand checkpoints (presumably for this reason), which are needed when running on shared computing resources that may remove the job at random times.

I would really appreciate it if you would consider the one-line change proposed.

ColmTalbot commented 1 day ago

I'll try to generate a minimal working example (MWE).

segasai commented 1 day ago

I will need to look at the change in more detail, as I need to be convinced that it's correct for dynesty on its own. But in general, I don't believe dynesty can be correctly restored if interrupted at an arbitrary point. I'd be happy to discuss what kind of checkpointing would work for you, and maybe adjust what we currently have (that's a separate discussion). E.g. if you are given a certain amount of time to do the checkpoint, one could change the current 'every N sec' scheme to something like 'checkpoint if signaled'.
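The 'every N sec' and 'checkpoint if signaled' schemes can be combined in a small policy object that the sampling loop polls only between iterations, where the state is consistent. A sketch under that assumption; `CheckpointPolicy` and `poll` are hypothetical names, and the injectable clock exists only to make the behaviour easy to test.

```python
import time


class CheckpointPolicy:
    """Hypothetical policy: checkpoint every `every_sec` seconds, or as soon as
    possible after `signal()` is called. `poll()` is meant to be called only
    at points where the sampler state is consistent."""

    def __init__(self, every_sec=None, clock=time.monotonic):
        self.every_sec = every_sec   # None disables periodic checkpoints
        self.clock = clock
        self.last = clock()
        self.signaled = False

    def signal(self):
        # Cheap enough to call from a signal handler: it only sets a flag.
        self.signaled = True

    def poll(self):
        """Return True (and reset the timer and flag) if a checkpoint is due."""
        now = self.clock()
        due = self.every_sec is not None and now - self.last >= self.every_sec
        if due or self.signaled:
            self.last = now
            self.signaled = False
            return True
        return False


t = [0.0]  # fake clock, in seconds
policy = CheckpointPolicy(every_sec=60, clock=lambda: t[0])
t[0] = 30.0
print(policy.poll())  # False
policy.signal()
print(policy.poll())  # True: signaled
t[0] = 89.0
print(policy.poll())  # False: only 59 s since the last checkpoint
t[0] = 90.0
print(policy.poll())  # True: 60 s elapsed
```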

ColmTalbot commented 1 day ago

Here's an example that emulates being interrupted by a manual keyboard interrupt.

Stopping during add_live_points is quite rare, as it is fairly fast. In practice I'm seeing it happen for a few % of long-running analyses that get interrupted up to a few tens of times each, so maybe a little less than one percent of the time. Admittedly, in Bilby we call this method more often than you would probably recommend, so that we can make plots to track progress.
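Why a fast step still bites analyses that are interrupted many times can be seen with a simple model. Assume interruptions arrive at times unrelated to the sampler's progress, so each one lands inside add_live_points with probability roughly equal to the fraction `f` of wall time spent there. A job interrupted `n` times is then hit at least once with probability 1 - (1 - f)^n, which is close to n·f for small f. The numbers below are purely illustrative, not measurements from this issue.

```python
def prob_job_hit(f, n):
    """Probability that at least one of n independent interruptions lands in
    a step that occupies a fraction f of the wall time."""
    return 1.0 - (1.0 - f) ** n


# Illustrative values only: 0.1% of wall time, 20 interruptions per job.
print(round(prob_job_hit(0.001, 20), 4))  # 0.0198, close to 20 * 0.001
```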

If there were an option to make trace/run plots during a checkpoint (or on some other cadence), I think that could simplify how we use dynesty quite significantly.

We have a 'checkpoint if signaled' method implemented, which is similar in spirit to the example below, except that it catches a signal rather than a KeyboardInterrupt; it could probably be ported quite easily.

import sys

import dynesty
import numpy as np
from dynesty.plotting import runplot

# setup taken from one of the examples in the repository
rstate = np.random.default_rng(56101)

m_true = -0.9594
b_true = 4.294
f_true = 0.534

N = 50
x = np.sort(10 * rstate.uniform(size=N))
yerr = 0.1 + 0.5 * rstate.uniform(size=N)
y_true = m_true * x + b_true
y = y_true + np.abs(f_true * y_true) * rstate.normal(size=N)
y += yerr * rstate.normal(size=N)

def loglike(theta):
    m, b, lnf = theta
    model = m * x + b
    inv_sigma2 = 1.0 / (yerr**2 + model**2 * np.exp(2 * lnf))

    return -0.5 * (np.sum((y-model)**2 * inv_sigma2 - np.log(inv_sigma2)))

# prior transform
def prior_transform(utheta):
    um, ub, ulf = utheta
    m = 5.5 * um - 5.
    b = 10. * ub
    lnf = 11. * ulf - 10.

    return m, b, lnf

# If a checkpoint exists, restore it and try making a runplot as a test of the
# failure; we should see a warning if the bug was triggered.
try:
    dsampler = dynesty.utils.restore_sampler("checkpoint.pkl")
    dres = dsampler.results
    print(dres["niter"], len(dres["logl"]))
    if dsampler.added_live:
        try:
            runplot(dres)
        except ValueError:
            pass
except FileNotFoundError:
    dsampler = dynesty.NestedSampler(
        loglike, prior_transform, ndim=3, bound='multi', sample='rwalk', rstate=rstate, nlive=1000)
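try:
    # run_nested finishes by appending the final live points to the saved
    # samples, so strip them before resuming in the next call.
    for _ in range(10):
        if dsampler.added_live:
            dsampler._remove_live_points()
        dsampler.run_nested(checkpoint_file="checkpoint.pkl", dlogz=0.1, resume=True, maxiter=1000)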
except KeyboardInterrupt:
    # If we've interrupted while adding live points, we land in the bug
    print(dsampler.added_live)
    dynesty.utils.save_sampler(dsampler, "checkpoint.pkl")
    sys.exit(1)

For testing, I added a sleep inside the loop in add_live_points to make it easier to interrupt it there.
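Catching a signal instead of KeyboardInterrupt has the advantage that the handler can merely record the request, leaving the sampling loop to checkpoint at its next safe point rather than being torn out of whatever it was doing. A minimal sketch, assuming the batch system announces eviction with SIGTERM (check the scheduler's configuration); `EvictionFlag` is a hypothetical helper, and in CPython signal handlers can only be installed from the main thread.

```python
import signal


class EvictionFlag:
    """Hypothetical helper: remember that a termination signal arrived instead
    of raising an exception at an arbitrary point in the sampler."""

    def __init__(self, signum=signal.SIGTERM):
        self.received = False
        self._signum = signum
        # signal.signal returns the previously installed handler.
        self._previous = signal.signal(signum, self._handler)

    def _handler(self, signum, frame):
        self.received = True   # do no real work inside the handler

    def restore(self):
        signal.signal(self._signum, self._previous)


flag = EvictionFlag()
signal.raise_signal(signal.SIGTERM)  # simulate the scheduler (Python 3.8+)
print(flag.received)                 # True
flag.restore()
```

A sampling loop would then check `flag.received` between iterations and, if it is set, save the sampler and exit.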

ColmTalbot commented 1 day ago

I think calling add_live_points less often is possible on my end, but I still think that this is a bug in dynesty that should be addressed.

segasai commented 1 day ago

Thank you for providing an example.

Two points here:

1) From what you described, it seems you can just use dynesty's existing checkpointer. Checkpoint every 60 seconds and it's guaranteed to work, with no need to fiddle with added points, and you'll be able to make plots with it as well. I assume the overhead of checkpointing is tiny, so it shouldn't be an issue. If the pickle overhead is an issue, one could imagine a solution where dynesty stores a consistent copy of its state inside the sampler without pickling it, and then pickles it upon request (or you can pickle it externally).

2) In your code, the expectation of getting a consistent sampler state upon interruption at a random place in the code is not justified. There are plenty of places in the code where, if you interrupt there, you'll get an inconsistent object; you may not notice it, or it will fail only rarely, but there are tons of places like that. Guaranteeing consistency upon interruption at a random place would require dynesty to operate like a transaction, i.e. atomically; that's not how the code is structured, and it cannot provide such guarantees. Therefore, even attempting this approach is not useful (based on my understanding of dynesty's code).
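The "store a consistent state, pickle on request" idea from point 1 can be sketched as a holder that is updated only at safe points and written out only when asked. This is a hypothetical illustration (`SnapshotHolder`, `update` and `dump` are not dynesty APIs); `copy.deepcopy` stands in for whatever cheaper state capture dynesty might use, since copying a whole sampler is not free either. The write goes to a temporary file that is then renamed, so an interruption during the dump cannot leave a truncated checkpoint behind.

```python
import copy
import os
import pickle
import tempfile


class SnapshotHolder:
    """Hypothetical sketch: keep an in-memory copy of an object taken at a
    moment when its state is known to be consistent, and pickle it only when
    a checkpoint is actually requested."""

    def __init__(self):
        self._snapshot = None

    def update(self, obj):
        # Call only at safe points (never inside add_live_points). The
        # attribute is rebound only after deepcopy finishes, so an
        # interrupted update leaves the previous snapshot intact.
        self._snapshot = copy.deepcopy(obj)

    def dump(self, path):
        if self._snapshot is None:
            raise RuntimeError("no consistent snapshot has been taken yet")
        # Write to a temporary file and rename it, so a crash mid-write
        # never leaves a truncated checkpoint at `path`.
        tmp = path + ".tmp"
        with open(tmp, "wb") as fh:
            pickle.dump(self._snapshot, fh)
        os.replace(tmp, path)


state = {"it": 10, "saved": [1, 2, 3]}
holder = SnapshotHolder()
holder.update(state)
state["saved"].append(4)  # later change, e.g. part-way through an update
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pkl")
holder.dump(path)
with open(path, "rb") as fh:
    restored = pickle.load(fh)
print(restored)  # {'it': 10, 'saved': [1, 2, 3]}
```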

So I think what needs to be done is to understand how your use case can fit with the approach dynesty already uses (i.e. where a consistent state of the sampler can be obtained and saved at specific moments).
If you want regularly made plots, I'd say the best way is to use the checkpointed sampler and make them externally. If it is really critical to have this internally, one could think of some kind of callback mechanism, i.e. you provide a function that is called either at some interval or at each checkpoint.
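The callback idea can be sketched as a loop that invokes a user-supplied function only after a completed iteration, where the state is consistent. `run_with_callback`, `step` and `every` are hypothetical names for illustration, not a dynesty API; in practice the callback might make a runplot from the current results or write a checkpoint.

```python
from typing import Callable, Optional


def run_with_callback(
    step: Callable[[int], None],
    n_iter: int,
    callback: Optional[Callable[[int], None]] = None,
    every: int = 100,
) -> None:
    """Hypothetical sampling loop: `step` does one iteration, and `callback`
    (e.g. plotting or checkpointing) runs every `every` completed iterations."""
    for it in range(1, n_iter + 1):
        step(it)
        # The iteration has finished here, so the state is consistent.
        if callback is not None and it % every == 0:
            callback(it)


calls = []
run_with_callback(lambda it: None, n_iter=250, callback=calls.append, every=100)
print(calls)  # [100, 200]
```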