TuringLang / AdvancedHMC.jl

Robust, modular and efficient implementation of advanced Hamiltonian Monte Carlo algorithms
https://turinglang.org/AdvancedHMC.jl/
MIT License

Make `HMCState` store `rng` in the `AbstractMCMC` interface. #314

Closed. yebai closed this issue 1 year ago

yebai commented 1 year ago

https://github.com/TuringLang/AdvancedHMC.jl/blob/82de3ff64c88316f7ebdf1e407e3077b3870c8af/src/abstractmcmc.jl#L166

We should add rng to the HMC state for continuing HMC sampling from a previous stopping point.

yebai commented 1 year ago

@JaimeRZP can we fix this along with the efforts in #325?

In general, we want to capture the entire sampling state after an AHMC.step call and pass it to the next step call, so it is resumable and fully reproducible.
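A minimal sketch of what "capture the entire sampling state and pass it to the next step call" could look like, using only the `Random` standard library. `ToyState`, `toy_step` and `run_chain` are hypothetical names made up for illustration, and the transition is a plain random-walk move rather than HMC; this is not the AdvancedHMC implementation.

```julia
using Random

# Hypothetical toy state: the current position plus a snapshot of the RNG.
struct ToyState{R<:AbstractRNG}
    x::Float64
    rng::R
end

# One toy transition (a random-walk move, not HMC).
# The RNG held by the incoming state is copied, so the input state is never mutated.
function toy_step(state::ToyState)
    rng = copy(state.rng)
    x = state.x + randn(rng)
    return x, ToyState(x, rng)
end

# Run n transitions and return the samples together with the final state.
function run_chain(state::ToyState, n::Integer)
    xs = Float64[]
    for _ in 1:n
        x, state = toy_step(state)
        push!(xs, x)
    end
    return xs, state
end

init = ToyState(0.0, MersenneTwister(42))

# Ten steps in one go ...
full, _ = run_chain(init, 10)

# ... versus five steps, a checkpoint, and five more steps from the checkpoint.
first_half, checkpoint = run_chain(init, 5)
second_half, _ = run_chain(checkpoint, 5)
```

Under these assumptions the resumed run reproduces the uninterrupted one exactly, because the returned state holds everything the next step needs. The price is one RNG copy per step.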

JaimeRZP commented 1 year ago

Done!

torfjelde commented 1 year ago

But `rng` is an argument to `AbstractMCMC.step`, @yebai; why do we also want to include it in the state? This is contrary to all other implementations of the `AbstractMCMC` interface.

If you want it for reproducibility, then a) I'm a bit uncertain what scenario you have in mind, and b) you'd have to deepcopy the rng at every step.
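To illustrate point b), here is a small sketch with a standard-library `MersenneTwister`. It only shows the difference between storing a reference to a mutable RNG and storing a copy; the variable names are illustrative.

```julia
using Random

rng = MersenneTwister(1)
stored = rng                 # a reference to the same object, not a snapshot
snapshot = deepcopy(rng)     # an independent copy of the current RNG state

first_draw = rand(rng)       # advancing rng also advances `stored`

replayed = rand(snapshot)    # the snapshot replays the draw that was just made
next_draw = rand(stored)     # the reference continues from the mutated state
```

A state that merely holds `stored` would therefore not record where the RNG was when the state was created; only the copy does.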

yebai commented 1 year ago

My original thought was to use a non-mutating rng, so we need to return the new rng state after each step call and pass it to the next step call. The current design assumes a mutable rng state shared by all step calls, which is not ideal.

See, e.g., https://juliarandom.github.io/RandomNumbers.jl/stable/man/random123/

EDIT: we can use counter-based rngs -- saving the state is equivalent to saving the counter value, which is cheap.
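The idea can be sketched with a toy counter-based generator. This is not Random123 and not suitable for real sampling: `ToyCounterRNG`, `mix64` and `draw` are hypothetical names, and the mixing function is a SplitMix64-style finalizer chosen only to keep the example self-contained.

```julia
# Toy counter-based generator: each output is a hash of (key, counter),
# so the whole RNG state is two integers.
struct ToyCounterRNG
    key::UInt64
    counter::UInt64
end

# SplitMix64-style bit mixer (UInt64 arithmetic wraps around on overflow).
function mix64(x::UInt64)
    z = x + 0x9e3779b97f4a7c15
    z = (z ⊻ (z >> 30)) * 0xbf58476d1ce4e5b9
    z = (z ⊻ (z >> 27)) * 0x94d049bb133111eb
    return z ⊻ (z >> 31)
end

# Non-mutating draw: returns a uniform value in [0, 1) and the *next* RNG value.
function draw(rng::ToyCounterRNG)
    bits = mix64(rng.key ⊻ mix64(rng.counter))
    u = (bits >> 11) * 2.0^-53          # keep the top 53 bits
    return u, ToyCounterRNG(rng.key, rng.counter + one(UInt64))
end

r0 = ToyCounterRNG(UInt64(7), UInt64(0))
u1, r1 = draw(r0)
u2, r2 = draw(r1)

# Resuming only requires the key and the saved counter value.
resumed = ToyCounterRNG(r1.key, r1.counter)
u2_again, _ = draw(resumed)
```

Since `draw` returns a new immutable value instead of mutating its argument, saving the RNG inside a sampler state amounts to saving `key` and `counter`.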

torfjelde commented 1 year ago

I'm fully with you that keeping track of the rng is potentially an idea to explore, but my argument is that this is not the current way we're doing things in AbstractMCMC, and so we should not start doing this here.

If we decide to add rng to the state, then we should do this everywhere; not just in one of the packages.

Moreover, the rng is also passed to the callback, so the callback function is free to do whatever it wants with the rng, e.g. save it. So even for resuming chains, I don't see why the rng needs to be added to the state itself; it's already available.
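A rough sketch of the callback approach. `toy_sample` is a hypothetical stand-in for a sampling loop, and its three-argument callback is an assumption made for this example, not the actual `AbstractMCMC` callback signature.

```julia
using Random

# Hypothetical sampling loop: the shared RNG is mutated in place and
# handed to the callback after every iteration.
function toy_sample(rng::AbstractRNG, n::Integer; callback=nothing)
    xs = Float64[]
    for i in 1:n
        push!(xs, randn(rng))
        callback === nothing || callback(rng, xs[end], i)
    end
    return xs
end

# Callback that stores a copy of the RNG after each iteration.
snapshots = MersenneTwister[]
save_rng = (rng, sample, i) -> push!(snapshots, copy(rng))

full = toy_sample(MersenneTwister(3), 6; callback=save_rng)

# Restart from the RNG saved after iteration 3 and draw the remaining samples.
tail = toy_sample(copy(snapshots[3]), 3)
```

Here the sampler state is left untouched; the user opts in to RNG snapshots through the callback and pays for the copies only when they are wanted.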

> EDIT: we can use counter-based rngs -- saving the state is equivalent to saving the counter value, which is cheap.

Am fully aware this is an option, but then you're suggesting always forcing usage of Random123? This is clearly not ideal.