johannesulf / nautilus

Neural Network-Boosted Importance Nested Sampling for Bayesian Statistics
https://nautilus-sampler.readthedocs.io
MIT License

[Feature request] Limit maximum runtime of `sampler.run()` #46

Closed: timothygebhard closed this issue 6 months ago

timothygebhard commented 7 months ago

Hi Johannes!

First things first, thanks a lot for your work — nautilus is undoubtedly my favorite nested sampling library! 🙂

I have one request / suggestion / question: because of how the cluster I run nautilus on works, I need to be able to limit the maximum runtime of `run()` and then restart the job from a checkpoint if necessary. So far, I have been using an approach inspired by this solution, but I noticed that it is a bit too radical: if the timeout is reached while nautilus is writing data to an HDF5 file, I end up with a corrupted checkpoint from which I cannot resume (usually because of some shape mismatch). Therefore, I was wondering if you could imagine adding a `max_runtime` argument to the `run()` method that would allow a more graceful way of limiting the runtime? I'd be fine with this being a "soft" limit that is only enforced whenever a checkpoint is created.
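To make the idea of a "soft" limit concrete, here is a toy loop. It is not nautilus internals; `run_with_soft_limit`, `step`, `checkpoint`, and `converged` are hypothetical names. The deadline is only checked once a checkpoint has been fully written, so a stop can never land in the middle of a write. The price is that a run may overshoot the limit by up to one step plus one checkpoint.

```python
import time
from typing import Callable


def run_with_soft_limit(
    step: Callable[[], None],
    checkpoint: Callable[[], None],
    converged: Callable[[], bool],
    max_runtime: float,
    clock: Callable[[], float] = time.monotonic,
) -> bool:
    """Repeat `step` + `checkpoint` until `converged()` or the soft limit.

    The elapsed time is only checked after `checkpoint()` has returned, so
    the loop never stops while a checkpoint is being written. Returns True
    if the run converged and False if it stopped because of the time limit.
    """
    start = clock()
    while not converged():
        step()
        checkpoint()  # always runs to completion before the time check
        if clock() - start >= max_runtime:
            return converged()  # the last step may have just converged
    return True
```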

Alternatively, it would probably already help if I could manually create a checkpoint after interrupting the `run()` call. Is it enough to do something like the following, or would this lose information or still produce corrupted checkpoints?

```python
# time_limit: user-defined timeout context manager (not part of nautilus)
with time_limit(max_runtime):
    sampler.run(...)
sampler.write(sampler.filepath, overwrite=True)
```
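For context, a typical signal-based `time_limit` helper looks roughly like the sketch below. This is an assumption about what the linked solution does, and it only works on Unix and in the main thread. It shows why the approach is so abrupt: the handler raises wherever the main thread happens to be, including partway through an HDF5 write. With such a helper, the exception also propagates out of the `with` block, so a following `sampler.write(...)` line is only reached if the exception is caught.

```python
import signal
import time
from contextlib import contextmanager


@contextmanager
def time_limit(seconds: float):
    """Raise TimeoutError in the main thread after `seconds` (Unix only)."""

    def handler(signum, frame):
        # Raised wherever the main thread currently is, e.g. midway
        # through writing a checkpoint file.
        raise TimeoutError(f"timed out after {seconds} s")

    previous = signal.signal(signal.SIGALRM, handler)
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)  # cancel a pending alarm
        signal.signal(signal.SIGALRM, previous)  # restore the old handler
```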

I noticed that there is also `write_shell_update()`, and I wasn't sure what to do about it, because the shell information doesn't seem easily accessible after killing `run()`...

Any other idea would of course be welcome, too!

Thanks a lot in advance,
Timothy

johannesulf commented 7 months ago

Hi Timothy, thanks for reaching out! You're right that interrupting nautilus midway can lead to corrupted HDF5 checkpoint files. Implementing a time limit shouldn't be too difficult, but I'm wondering whether the `n_like_max` keyword argument of `run()` would already solve your problem. Once nautilus reaches the maximum number of likelihood calls given by `n_like_max`, it stops. By default that maximum is infinity, but you can set it to a lower value.

I think the solution you proposed in your code may produce errors. Your use of the `write` function makes sense. However, interrupting the sampler midway may cause other problems if important function calls are cut off in the middle. So let me know if `n_like_max` solves your problem.
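A toy illustration of the failure mode described here, with plain Python dicts standing in for HDF5 datasets (`write_checkpoint`, `is_consistent`, and `Interrupted` are hypothetical names): if a timeout strikes between two related writes, the stored arrays disagree in length, which is the kind of shape mismatch that later prevents resuming.

```python
class Interrupted(Exception):
    """Stand-in for the exception a timeout handler would raise."""


def write_checkpoint(store: dict, points: list, log_l: list,
                     interrupt_midway: bool = False) -> None:
    """Toy writer that appends to two related 'datasets' one after another."""
    store["points"] = store.get("points", []) + points
    if interrupt_midway:
        raise Interrupted("timeout between the two writes")
    store["log_l"] = store.get("log_l", []) + log_l


def is_consistent(store: dict) -> bool:
    """Both datasets must have one entry per sampled point."""
    return len(store.get("points", [])) == len(store.get("log_l", []))
```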

timothygebhard commented 7 months ago

I'm not sure if `n_like_max` solves my problem: after all, I do want to run nautilus until convergence (albeit over multiple jobs), and I don't know a priori how many likelihood evaluations that will take. Also, as far as I can tell, the `run()` method doesn't return anything that indicates whether it finished because it actually converged or because the limit was reached.

I have already tried my hand at a first implementation of the `max_runtime` argument, but I might very well have overlooked something or taken an approach that's too simplistic...

johannesulf commented 7 months ago

Yes, I think it would be good for nautilus to return whether the run finished because the maximum number of likelihood calls (or the time limit) was reached or because it finished normally. In the meantime, I think `n_like_max` would still solve your problem. Let's say you can afford 1000 calls per job. You can then make nautilus run those 1000 calls in each job by specifying `n_like_max=sampler.n_like+1000`. Additionally, you'd know that nautilus is finished if `sampler.n_like` did not increase after calling `run()`. Does that make sense?
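A minimal sketch of the per-job logic described above, relying only on what this thread states: `sampler.n_like` counts the likelihood calls made so far, and `run()` accepts `n_like_max`. The helper name `run_one_job` and the default budget of 1000 calls are illustrative.

```python
def run_one_job(sampler, calls_per_job: int = 1000, **run_kwargs) -> bool:
    """Advance `sampler` by at most `calls_per_job` likelihood calls.

    Returns True if the sampler had already converged, i.e. `run()` added
    no new likelihood calls, and False if more jobs are needed.
    """
    n_like_before = sampler.n_like
    # Cap this job's work relative to the calls made so far.
    sampler.run(n_like_max=sampler.n_like + calls_per_job, **run_kwargs)
    return sampler.n_like == n_like_before
```

With this design, convergence is only detected by the job after the one that actually converges: assuming, as described above, that `run()` on an already converged sampler adds no calls, that extra job finishes almost immediately and returns True.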

timothygebhard commented 7 months ago

That does make sense, and I think I have managed to write a wrapper around the `run()` call that implements this logic. My main concern now is that I need to estimate the number of likelihood calls up front, because ultimately I'm still bounded by total execution time, not by the number of likelihood calls. I currently do this by evaluating the likelihood once, measuring the time, dividing my target maximum runtime by that time, and multiplying by the size of my `Pool`. This seems like a rather crude estimate of the number of likelihood calls per job, though, because it ignores all the overhead.
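The estimate described here can be written out as follows. `time_likelihood`, `calls_per_job`, and the `overhead_fraction` knob are hypothetical. The overhead fraction is a guess that reserves part of the runtime for work other than likelihood calls (for example, training the neural networks and writing checkpoints), and multiplying by `n_workers` assumes the pool parallelizes perfectly.

```python
import time
from typing import Any, Callable


def time_likelihood(likelihood: Callable[[Any], float], theta: Any,
                    n_calls: int = 5) -> float:
    """Average wall-clock time of one likelihood call, in seconds."""
    start = time.perf_counter()
    for _ in range(n_calls):  # several calls smooth out timing noise
        likelihood(theta)
    return (time.perf_counter() - start) / n_calls


def calls_per_job(t_call: float, max_runtime: float, n_workers: int,
                  overhead_fraction: float = 0.3) -> int:
    """Likelihood calls that fit into `max_runtime` seconds on `n_workers`."""
    usable = (1.0 - overhead_fraction) * max_runtime  # time left for likelihoods
    return int(usable / t_call * n_workers)
```

For example, a 0.5 s likelihood on 8 workers with a one-hour job and 25% reserved for overhead gives `calls_per_job(0.5, 3600, 8, overhead_fraction=0.25) == 43200`. The crude part remains the overhead fraction, which could be calibrated after the first job by comparing its wall-clock time with the increase in `sampler.n_like`.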

Another idea I had was to do something like this:

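```python
import time

start_time = time.time()
finished = False  # set to True once a run() call adds no likelihood calls
while time.time() - start_time < max_runtime:
    before = sampler.n_like
    sampler.run(..., n_like_max=sampler.n_like + 100)  # or another small-ish number
    if sampler.n_like == before:  # nothing new was sampled, so nautilus has converged
        finished = True
        break
```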

but I'm not sure how much overhead stopping and restarting the sampler this often would introduce. (I assume that a checkpoint is created every time `run()` finishes, and of course I'd rather not spend more time creating checkpoints than actually sampling...)
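One caveat with the loop above is that the time check happens only before each chunk, so the final chunk can run past `max_runtime`. A hedged variant (`run_in_chunks` is a hypothetical helper that relies only on `sampler.n_like` and the `n_like_max` argument from this thread) uses the duration of the previous chunk to decide whether another chunk is likely to fit. Larger chunks reduce the relative cost of stopping, restarting, and checkpointing, at the price of a coarser stopping point.

```python
import time
from typing import Callable


def run_in_chunks(sampler, max_runtime: float, chunk: int = 100,
                  clock: Callable[[], float] = time.monotonic,
                  **run_kwargs) -> bool:
    """Run `sampler` in chunks of `chunk` likelihood calls.

    Stops before starting a chunk that would probably exceed `max_runtime`,
    using the duration of the previous chunk as the estimate. Returns True
    if the sampler converged (a chunk added no likelihood calls).
    """
    start = clock()
    last_duration = 0.0
    while clock() - start + last_duration < max_runtime:
        n_like_before = sampler.n_like
        t0 = clock()
        sampler.run(n_like_max=sampler.n_like + chunk, **run_kwargs)
        last_duration = clock() - t0
        if sampler.n_like == n_like_before:
            return True
    return False
```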

johannesulf commented 7 months ago

I think what you wrote will probably work. I'll look into implementing the new features in the next week or so. It shouldn't be too difficult and I can see how this would make things easier in certain situations.

johannesulf commented 7 months ago

Hi @timothygebhard, I implemented the requested functionality in the `timeout` branch. Please let me know if it fits your needs. I'll then add a couple of unit tests, merge it into main, and release it as part of version 1.0.3.
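A sketch of how a cluster job might use the new feature. It assumes the feature is exposed as a `timeout` argument of `run()` in seconds and that `run()` returns True only when sampling finished normally; check the documentation of your installed nautilus version for the exact name and return value. `run_job` and the exit-code convention are hypothetical.

```python
EXIT_CONVERGED = 0
EXIT_RESUBMIT = 3  # hypothetical code that a job wrapper treats as "resubmit me"


def run_job(sampler, max_runtime: float, **run_kwargs) -> int:
    """Run `sampler` for roughly `max_runtime` seconds and pick an exit code.

    Assumes `run()` accepts `timeout` (in seconds) and returns True only if
    sampling finished normally rather than because a limit was reached.
    """
    success = sampler.run(timeout=max_runtime, **run_kwargs)
    return EXIT_CONVERGED if success else EXIT_RESUBMIT
```

A job script could end with `sys.exit(run_job(sampler, max_runtime=3.5 * 3600))`, and the batch wrapper would resubmit whenever it sees exit code 3. Assuming the sampler is created with a checkpoint file so that it resumes, each resubmitted job continues where the previous one stopped.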

timothygebhard commented 6 months ago

Hi @johannesulf! I just wanted to let you know that I've been running inference with the updated version for the past couple of days now and the timeout feature is working very nicely so far 🥳 Thanks a lot for implementing it so quickly!

johannesulf commented 6 months ago

Wonderful! Please don't hesitate to reach out if you have any further questions or suggestions.