sethaxen opened this issue 1 year ago
The API is semi-exposed: not official, but it has been stable for a long time. See here. It is documented, sort of; see also `mcmc_keep_warmup`.
Here is an MWE:
```julia
using DynamicHMC, LogDensityTestSuite, ForwardDiff, LogDensityProblemsAD, Random
ℓ = StandardMultivariateNormal(5)
∇ℓ = ADgradient(:ForwardDiff, ℓ)
rng = Random.GLOBAL_RNG
wu = DynamicHMC.default_warmup_stages()
function extract_initialization(state)
    (; Q, κ, ϵ) = state.final_warmup_state
    (; q = Q.q, κ, ϵ)
end
state1 = DynamicHMC.mcmc_keep_warmup(rng, ∇ℓ, 0; warmup_stages = wu[1:1])
state2 = DynamicHMC.mcmc_keep_warmup(rng, ∇ℓ, 0; warmup_stages = wu[2:2],
                                     initialization = extract_initialization(state1))
state3 = DynamicHMC.mcmc_keep_warmup(rng, ∇ℓ, 0; warmup_stages = wu[3:3],
                                     initialization = extract_initialization(state2))
# just keep doing this, and run the last stage with as many samples as you need
```
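The "just keep doing this" step can be written as a loop. What follows is a minimal sketch that assumes the same packages as the MWE. Every stage except the last runs with zero draws. The last stage runs with the number of draws you want, so its adaptation and the sampling happen in one call. `stagewise_mcmc` is a hypothetical name and is not part of DynamicHMC.

```julia
using DynamicHMC, LogDensityTestSuite, ForwardDiff, LogDensityProblemsAD, Random

# Pull the state needed to resume from a `mcmc_keep_warmup` result (same as the MWE).
function extract_initialization(state)
    (; Q, κ, ϵ) = state.final_warmup_state
    (; q = Q.q, κ, ϵ)
end

# Hypothetical helper: run each warmup stage in its own call, threading the
# adapted state through, and keep `N` draws in the call for the last stage.
function stagewise_mcmc(rng, ∇ℓ, N; warmup_stages = DynamicHMC.default_warmup_stages())
    initialization = ()
    for stage in warmup_stages[1:(end - 1)]
        state = DynamicHMC.mcmc_keep_warmup(rng, ∇ℓ, 0; warmup_stages = (stage,),
                                            initialization)
        initialization = extract_initialization(state)
    end
    # Last stage: adapt once more, then keep N posterior draws.
    DynamicHMC.mcmc_keep_warmup(rng, ∇ℓ, N; warmup_stages = warmup_stages[end:end],
                                initialization)
end

ℓ = StandardMultivariateNormal(5)
∇ℓ = ADgradient(:ForwardDiff, ℓ)
result = stagewise_mcmc(Random.default_rng(), ∇ℓ, 1000)
draws = result.inference.tree_statistics  # one entry per posterior draw
```

The loop sits inside a function because a top-level `for` loop in a script would not update the global `initialization` without a `global` declaration.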
Please keep the issue open even if this answers your question: I would like to expose this part of the API, and I occasionally use it too.
Thanks! This seems to work well. Here is how I time the entire warm-up phase and the sampling phase separately:
```julia
using DynamicHMC, Random

# Extract the state needed to resume sampling from a `final_warmup_state`.
function extract_initialization(state)
    (; Q, κ, ϵ) = state
    return (; q=Q.q, κ, ϵ)
end

# Run each warmup stage in its own call, threading the adapted state through.
function dhmc_warmup(
    rng::Random.AbstractRNG,
    ℓ;
    initialization=(),
    warmup_stages=DynamicHMC.default_warmup_stages(),
    kwargs...,
)
    initialization_final = foldl(warmup_stages; init=initialization) do init, stage
        result = DynamicHMC.mcmc_keep_warmup(
            rng, ℓ, 0; warmup_stages=(stage,), initialization=init, kwargs...
        )
        return extract_initialization(result.final_warmup_state)
    end
    return initialization_final
end

# Sample from an already-tuned state, skipping warmup entirely.
function dhmc_sample(rng::Random.AbstractRNG, ℓ, ndraws; initialization, kwargs...)
    return DynamicHMC.mcmc_with_warmup(
        rng, ℓ, ndraws; warmup_stages=(), initialization, kwargs...
    )
end
```
Then I call each of these functions with `@timed`.
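If you do not need to time individual stages, one option is to run the whole warmup in a single `mcmc_keep_warmup` call with zero draws, then make a sampling-only `mcmc_with_warmup` call. What follows is a sketch. It assumes the same target density as the MWE and Julia 1.7+ for property destructuring. `@timed` is from Base. It returns a named tuple whose `value` field holds the result and whose `time` field holds the elapsed seconds. The first timed call also includes compilation time, so repeat the run if you want steady-state numbers.

```julia
using DynamicHMC, LogDensityTestSuite, ForwardDiff, LogDensityProblemsAD, Random

ℓ = StandardMultivariateNormal(5)
∇ℓ = ADgradient(:ForwardDiff, ℓ)
rng = Random.default_rng()

# Phase 1: all default warmup stages in one call, keeping no posterior draws.
warmup_timing = @timed DynamicHMC.mcmc_keep_warmup(rng, ∇ℓ, 0)
(; Q, κ, ϵ) = warmup_timing.value.final_warmup_state
initialization = (; q = Q.q, κ, ϵ)

# Phase 2: no further adaptation, only sampling from the tuned state.
sampling_timing = @timed DynamicHMC.mcmc_with_warmup(rng, ∇ℓ, 1000;
                                                     warmup_stages = (),
                                                     initialization)

println("warmup: ", warmup_timing.time, " s, sampling: ", sampling_timing.time, " s")
```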
I would like to be able to time the warmup and the post-warmup sampling separately. A stretch goal would be to time each individual warmup stage. Is this possible using only API functions?
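For the stretch goal, one possible approach uses only `mcmc_keep_warmup` and `mcmc_with_warmup`: wrap each single-stage call in `@timed`. What follows is a sketch under the same assumptions as the MWE. `timed_warmup` is a hypothetical helper. Each stage is labelled only by the name of its type, so repeated stages of the same type (for example, successive tuning windows) may share a label. The first stage's time also absorbs most of the compilation cost.

```julia
using DynamicHMC, LogDensityTestSuite, ForwardDiff, LogDensityProblemsAD, Random

# Hypothetical helper: run each warmup stage separately and record its wall time.
# Returns the adapted state (usable as `initialization`) and per-stage timings.
function timed_warmup(rng, ∇ℓ; warmup_stages = DynamicHMC.default_warmup_stages())
    initialization = ()
    timings = Pair{Symbol,Float64}[]
    for stage in warmup_stages
        t = @timed DynamicHMC.mcmc_keep_warmup(rng, ∇ℓ, 0; warmup_stages = (stage,),
                                               initialization)
        (; Q, κ, ϵ) = t.value.final_warmup_state
        initialization = (; q = Q.q, κ, ϵ)
        push!(timings, nameof(typeof(stage)) => t.time)
    end
    initialization, timings
end

ℓ = StandardMultivariateNormal(5)
∇ℓ = ADgradient(:ForwardDiff, ℓ)
rng = Random.default_rng()

initialization, stage_times = timed_warmup(rng, ∇ℓ)
for (name, seconds) in stage_times
    println(rpad(string(name), 24), round(seconds; digits = 4), " s")
end

# Sampling only, starting from the adapted state.
sampling = @timed DynamicHMC.mcmc_with_warmup(rng, ∇ℓ, 1000; warmup_stages = (),
                                              initialization)
println("sampling: ", round(sampling.time; digits = 4), " s")
```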