Closed by yebai 1 year ago.
@JaimeRZP can we fix this along with the efforts in #325?
In general, we want to capture the entire sampling state after an AHMC.step call and pass it to the next step call, so that sampling is resumable and fully reproducible.
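To make the proposal concrete, here is a minimal Python sketch. It is not AdvancedHMC or AbstractMCMC code, and WalkState, init_state, step and run are hypothetical names. The step function takes the full sampling state, including a snapshot of the rng, and returns a new state without mutating its input. A toy random-walk update stands in for an actual HMC transition. Running 5 steps in one go gives the same samples as running 3 steps, saving the state, and resuming for 2 more.

```python
import random
from typing import NamedTuple


class WalkState(NamedTuple):
    x: float          # current position of the chain
    rng_state: tuple  # snapshot taken with random.Random.getstate()


def init_state(seed, x0=0.0):
    return WalkState(x0, random.Random(seed).getstate())


def step(state):
    """Advance one step and return (sample, new_state); `state` is left untouched."""
    rng = random.Random()
    rng.setstate(state.rng_state)          # rebuild the rng from the saved snapshot
    x = state.x + rng.uniform(-1.0, 1.0)   # toy random-walk move, always accepted
    return x, WalkState(x, rng.getstate())


def run(state, n):
    samples = []
    for _ in range(n):
        x, state = step(state)
        samples.append(x)
    return samples, state


full, _ = run(init_state(1234), 5)
first, saved = run(init_state(1234), 3)  # stop after 3 steps, keep `saved`
rest, _ = run(saved, 2)                  # resume from the saved state
assert full == first + rest
```

Note that the snapshot here is the whole Mersenne Twister state (about 625 integers), and it is copied at every step.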
Done!
But rng is already an argument to AbstractMCMC.step.
@yebai, why do we also want to include it in the state? This is contrary to all other implementations of the AbstractMCMC interface.
If you want it for reproducibility, then a) I'm a bit uncertain what scenario you have in mind, and b) you'd have to deepcopy the rng at every step.
My original thought was to use a non-mutating rng, so each step call would return the new rng state, which is then passed to the next step call. The current design instead mutates a single rng state shared by all step calls, which is not ideal.
See, e.g., https://juliarandom.github.io/RandomNumbers.jl/stable/man/random123/
EDIT: we can use counter-based rngs; saving the rng state is then equivalent to saving the counter value, which is cheap.
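A minimal Python sketch of the counter-based idea, under stated assumptions: the generator below uses a SplitMix64-style mix of (key, counter) as a stand-in for Random123's Philox/Threefry generators, and CounterRNG, rand_uint64 and rand_float are hypothetical names. Each draw is a pure function of the key and the counter and returns an rng with the counter incremented, so the only state that needs saving to resume is one integer.

```python
from typing import NamedTuple

MASK64 = (1 << 64) - 1
GOLDEN_GAMMA = 0x9E3779B97F4A7C15


def _mix64(z):
    # SplitMix64 finalizer: a bijective avalanche mix of a 64-bit word
    z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9 & MASK64
    z = (z ^ (z >> 27)) * 0x94D049BB133111EB & MASK64
    return z ^ (z >> 31)


class CounterRNG(NamedTuple):
    key: int      # fixed seed
    counter: int  # the only part that changes between draws


def rand_uint64(rng):
    """Return (value, new_rng); the input rng is never mutated."""
    value = _mix64((rng.key + rng.counter * GOLDEN_GAMMA) & MASK64)
    return value, rng._replace(counter=rng.counter + 1)


def rand_float(rng):
    """Uniform float in [0, 1) built from the top 53 bits of one draw."""
    u, rng = rand_uint64(rng)
    return (u >> 11) * (1.0 / (1 << 53)), rng


rng0 = CounterRNG(key=42, counter=0)
a, rng1 = rand_float(rng0)
b, rng2 = rand_float(rng1)

# Resuming needs only the saved counter value, not a copy of a large state.
saved_counter = rng1.counter
b_again, _ = rand_float(CounterRNG(key=42, counter=saved_counter))
assert b == b_again
```

This cheapness depends on the generator being counter-based; a stateful generator such as Mersenne Twister has no single counter to save.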
I fully agree that keeping track of the rng is an idea worth exploring, but my argument is that this is not how we currently do things in AbstractMCMC, so we should not start doing it here.
If we decide to add rng to the state, then we should do so everywhere, not just in one of the packages.
Moreover, the rng is also passed to the callback, so the callback function is free to do whatever it wants with the rng, e.g. save it. So even for resuming chains, I don't see why the rng needs to be added to the state itself; it's already available.
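For comparison, a minimal Python sketch of the callback route under the current design. The rng is a separate argument that the sampler mutates in place, and a callback receives it after every step. Here sample and save_checkpoint are hypothetical names, and the callback signature is simplified; it is not AbstractMCMC's actual callback signature. Because the rng keeps changing after the callback returns, the callback has to deepcopy it to get a usable checkpoint.

```python
import copy
import random


def sample(rng, x0, n, callback=None):
    """Hypothetical sampler: rng is a separate argument, mutated in place, not part of the state."""
    x = x0
    for i in range(n):
        x = x + rng.uniform(-1.0, 1.0)  # toy random-walk move standing in for an HMC step
        if callback is not None:
            callback(rng, i, x)
    return x


checkpoints = []


def save_checkpoint(rng, i, x):
    # deepcopy is required: the sampler keeps mutating `rng` after this returns
    checkpoints.append((i, x, copy.deepcopy(rng)))


rng = random.Random(1234)
final = sample(rng, 0.0, 5, callback=save_checkpoint)

# Resume from the checkpoint taken after step index 2, i.e. after 3 completed steps.
i, x, rng_copy = checkpoints[2]
resumed = sample(rng_copy, x, 5 - (i + 1))
assert resumed == final
```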
Regarding counter-based rngs: I'm fully aware this is an option, but are you then suggesting we always force the use of Random123? That is clearly not ideal.
https://github.com/TuringLang/AdvancedHMC.jl/blob/82de3ff64c88316f7ebdf1e407e3077b3870c8af/src/abstractmcmc.jl#L166
We should add rng to the HMC state so that HMC sampling can be continued from a previous stopping point.