TuringLang / AbstractMCMC.jl

Abstract types and interfaces for Markov chain Monte Carlo methods
https://turinglang.org/AbstractMCMC.jl
MIT License

Make `AbstractMCMC.step` function handle `rng` as part of state #116

Open yebai opened 1 year ago

yebai commented 1 year ago

For continuing MCMC sampling from a previous stopping point, we need to store the rng as part of the sampling state.

https://github.com/TuringLang/AdvancedMH.jl/blob/e1741179e2505da57945d47b7b1debbf3f0e848b/src/mh-core.jl#L83

https://github.com/TuringLang/AdvancedMH.jl/blob/e1741179e2505da57945d47b7b1debbf3f0e848b/src/mh-core.jl#L90
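A minimal sketch of the idea, using only the stdlib (not the AbstractMCMC API): wrap the sampler state together with a snapshot of the RNG so that a step can be resumed from the exact point in the random stream. `StateWithRNG`, `toy_step`, and `resumable_step` are all illustrative names, and the one-dimensional random-walk "step" is a stand-in for a real `AbstractMCMC.step` implementation.

```julia
# Hypothetical sketch: bundle the RNG into the sampler state so sampling
# can resume from the exact point in the random stream.
using Random

struct StateWithRNG{S,R<:Random.AbstractRNG}
    inner::S   # sampler-specific state, e.g. the current sample
    rng::R     # RNG snapshot taken after the step
end

# Toy stand-in for a `step`-like function: one random-walk proposal.
toy_step(rng::Random.AbstractRNG, x::Float64) = x + randn(rng)

function resumable_step(state::StateWithRNG)
    rng = copy(state.rng)            # don't mutate the stored snapshot
    x = toy_step(rng, state.inner)
    return x, StateWithRNG(x, copy(rng))
end

# Resuming twice from the same saved state reproduces the same draw.
s0 = StateWithRNG(0.0, Xoshiro(42))
x1, _ = resumable_step(s0)
x2, _ = resumable_step(s0)
@assert x1 == x2
```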

devmotion commented 1 year ago

Isn't that a more general issue/question that is not specific to AdvancedMH? With other samplers you also have to save the RNG if you want to continue with exactly the same stream of random numbers. But one can easily do that by passing an explicit RNG object and storing it separately when stopping sampling, I think? I would also assume that in many cases it does not matter if one continues with a different or differently seeded RNG, as long as the two streams of random numbers are not correlated and e.g. the seeds are sampled randomly.
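This suggestion can be sketched with the stdlib alone: pass an explicit RNG into the sampling call and serialize it yourself when stopping. Here `run_sampler!` is a hypothetical placeholder for something like `sample(rng, model, sampler, N)`.

```julia
# Sketch: explicit RNG passed in, stored separately on stop.
using Random, Serialization

# Placeholder for a sampling call that consumes the RNG.
run_sampler!(rng) = [randn(rng) for _ in 1:3]

rng = Xoshiro(1)
first_run = run_sampler!(rng)

# Stopping: serialize the advanced RNG next to the samples.
buf = IOBuffer()
serialize(buf, rng)

# Resuming (possibly in a new session): the restored RNG continues the
# stream exactly where an uninterrupted run would have.
restored = deserialize(seekstart(buf))
reference = copy(rng)              # what an uninterrupted run would use
continued = run_sampler!(restored)
@assert continued == run_sampler!(reference)
```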

yebai commented 1 year ago

> Isn't that a more general issue/question that is not specific for AdvancedMH?

Yes, this is a more general issue ideally solved by AbstractMCMC.

> But one can easily do that by passing an explicit RNG object and storing it separately when stopping sampling, I think? I would have assumed as well that in many cases it does not matter if one continues with a different RNG or differently seeded RNG, as long as the two streams of random numbers are not correlated and e.g. the seeds are sampled randomly.

That indeed works for most cases. But consider a special case where we want to transfer the sampling process between machines. For example, we run a model for 10 minutes, save the states to disk, and wait for the user to perform some convergence checks (or other actions). Later the user might decide to continue the sampling process for another 10 minutes. Under such circumstances, we don't know in advance how many MCMC steps we will run under the given time constraints. So the rng has to be returned and stored, together with using a serializable, stable generator such as StableRNG, to guarantee full reproducibility. One natural place to "checkpoint" these rng states is the AbstractMCMC.step function, I think.
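The checkpointing use case might look like the following sketch, using stdlib `Serialization` and `Xoshiro`. For reproducibility across Julia versions one would substitute a stable generator from StableRNGs.jl; the `state` tuple here is a toy stand-in for real MCMC state.

```julia
# Sketch: checkpoint sampler state + RNG to disk, resume elsewhere later.
using Random, Serialization

rng = Xoshiro(2024)
state = (sample = randn(rng), iteration = 1)   # toy stand-in for MCMC state

# Checkpoint after the time budget runs out.
path = tempname()
serialize(path, (state = state, rng = copy(rng)))

# ... later, possibly on another machine ...
ckpt = deserialize(path)
a = randn(ckpt.rng)   # next draw from the restored stream
b = randn(rng)        # next draw an uninterrupted run would produce
@assert a == b
```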

devmotion commented 1 year ago

I think that use case is also related to https://github.com/TuringLang/AbstractMCMC.jl/issues/109.

The annoying part about handling it in step is that every sampler package has to adjust for it. I wonder if it would be sufficient to handle it in bundle_samples and pass it there, together with e.g. the final state. This seems sufficient if one uses sample. And if one uses the iterator or transducer, one handles states, RNG etc. manually anyway, so saving the RNG should be trivial?
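The "manual handling" point can be made concrete with a toy iterator-style loop: since the RNG and state are already threaded explicitly, snapshotting the RNG at any iteration is a one-liner. `toy_transition` stands in for an actual `AbstractMCMC.step` call.

```julia
# Sketch: iterator-style sampling already threads the RNG explicitly,
# so checkpointing it per iteration is trivial.
using Random

toy_transition(rng, x) = x + randn(rng)

function run_chain(rng, x0, n)
    xs = Float64[]
    snaps = Xoshiro[]
    x = x0
    for _ in 1:n
        push!(snaps, copy(rng))   # RNG state *before* this step
        x = toy_transition(rng, x)
        push!(xs, x)
    end
    return xs, snaps
end

function resume(rng, x0, n)
    x = x0
    out = Float64[]
    for _ in 1:n
        x = toy_transition(rng, x)
        push!(out, x)
    end
    return out
end

xs, snaps = run_chain(Xoshiro(7), 0.0, 5)

# Resuming from the snapshot taken before step 3 reproduces steps 3..5.
replayed = resume(copy(snaps[3]), xs[2], 3)
@assert replayed == xs[3:5]
```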

yebai commented 1 year ago

> I wonder if it would be sufficient to handle it in bundle_samples and pass it there, together with e.g. the final state.

That could work well. We can treat these rng states as meta information and store them in chains, possibly together with other sampler states (e.g. HMC preconditioning matrix, leapfrog step size).
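A sketch of the "metadata in chains" idea: attach the final RNG and other sampler state to the returned chain as an info payload. The `Chain` struct here is purely illustrative; in practice something like the `info` field of MCMCChains.jl's `Chains` could serve this role.

```julia
# Sketch: store the final RNG (and other sampler state) as chain metadata.
using Random

struct Chain{T}
    samples::Vector{T}
    info::NamedTuple   # e.g. (rng = ..., step_size = ..., metric = ...)
end

rng = Xoshiro(11)
samples = [randn(rng) for _ in 1:100]
chain = Chain(samples, (rng = copy(rng), step_size = 0.1))

# Resuming: pull the RNG back out of the chain metadata.
resumed = copy(chain.info.rng)
@assert randn(resumed) == randn(rng)   # same stream as the original run
```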

yebai commented 1 year ago

Related https://github.com/TuringLang/AdvancedHMC.jl/issues/314

Cc @JaimeRZP: maybe we can switch to a non-mutating rng, then update AbstractMCMC.step to return the new rng state.
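One way the non-mutating idea could look: a `step`-like function takes an immutable RNG value and returns the advanced RNG alongside the sample, in the counter-based style of JAX's split keys or Random123. `ToyCounterRNG`, `advance`, `draw`, and `pure_step` are all illustrative, not a proposed API.

```julia
# Sketch: a pure, non-mutating step that returns the new RNG state.
using Random

struct ToyCounterRNG
    seed::UInt64
    counter::UInt64
end

# Advancing is just incrementing the counter; no mutation anywhere.
advance(r::ToyCounterRNG) = ToyCounterRNG(r.seed, r.counter + 1)

# Deterministic draw from (seed, counter), via a freshly seeded generator.
draw(r::ToyCounterRNG) = randn(Xoshiro(hash((r.seed, r.counter))))

function pure_step(r::ToyCounterRNG, x::Float64)
    sample = x + draw(r)
    return sample, advance(r)   # new RNG state is part of the return value
end

r0 = ToyCounterRNG(UInt64(1), UInt64(0))
x1, r1 = pure_step(r0, 0.0)
x1b, _ = pure_step(r0, 0.0)     # pure: same inputs give same outputs
@assert x1 == x1b
```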