tpapp / DynamicHMC.jl

Implementation of robust dynamic Hamiltonian Monte Carlo methods (NUTS) in Julia.

Timing warmup and sampling #177

Open sethaxen opened 1 year ago

sethaxen commented 1 year ago

I would like to be able to time warmup and post-warmup sampling separately. A stretch goal would be the ability to time each individual stage of the warmup. Is this possible using only API functions?

tpapp commented 1 year ago

The API for this is only semi-exposed rather than official, but it has been stable for a long time. See here. It is documented, sort of; see also mcmc_keep_warmup.

Here is an MWE:

using DynamicHMC, LogDensityTestSuite, ForwardDiff, LogDensityProblemsAD, Random

ℓ = StandardMultivariateNormal(5)        # test log density in 5 dimensions
∇ℓ = ADgradient(:ForwardDiff, ℓ)         # add gradients via ForwardDiff
rng = Random.GLOBAL_RNG

wu = DynamicHMC.default_warmup_stages()  # a tuple of warmup stages

# turn the final warmup state of one run into the `initialization` of the next
function extract_initialization(state)
    (; Q, κ, ϵ) = state.final_warmup_state
    (; q = Q.q, κ, ϵ)
end

# run one warmup stage per call, with 0 posterior draws
state1 = DynamicHMC.mcmc_keep_warmup(rng, ∇ℓ, 0; warmup_stages = wu[1:1])
state2 = DynamicHMC.mcmc_keep_warmup(rng, ∇ℓ, 0; warmup_stages = wu[2:2],
                                     initialization = extract_initialization(state1))
state3 = DynamicHMC.mcmc_keep_warmup(rng, ∇ℓ, 0; warmup_stages = wu[3:3],
                                     initialization = extract_initialization(state2))
# just keep doing this, and run the last stage with as many samples as you need

Please keep the issue open even if this answers your question; I would like to expose this part of the API, and occasionally I use it too.
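For the stretch goal (timing each warmup stage individually), the "just keep doing this" step in the MWE above can be written as a loop that wraps each single-stage `mcmc_keep_warmup` call in `@timed`. This is a minimal sketch, not part of DynamicHMC: `timed_warmup_stages` is a hypothetical helper, and it assumes the same semi-exposed behavior the MWE relies on (`mcmc_keep_warmup` returning a `final_warmup_state` with fields `Q`, `κ`, `ϵ`, and accepting `(; q, κ, ϵ)` as `initialization`).

```julia
using DynamicHMC, LogDensityTestSuite, LogDensityProblemsAD, ForwardDiff, Random

# Hypothetical helper (not part of DynamicHMC): runs the warmup stages one at a
# time via `mcmc_keep_warmup`, records the wall-clock time of each stage, and
# returns the final `initialization` together with the per-stage times.
function timed_warmup_stages(rng::Random.AbstractRNG, ∇ℓ;
                             warmup_stages = DynamicHMC.default_warmup_stages())
    initialization = ()          # empty: let the first stage choose its own start
    stage_times = Float64[]
    for stage in warmup_stages
        # 0 posterior draws, so only this single warmup stage is run
        timing = @timed DynamicHMC.mcmc_keep_warmup(rng, ∇ℓ, 0;
                                                    warmup_stages = (stage,),
                                                    initialization)
        push!(stage_times, timing.time)
        # hand the adapted position, kinetic energy, and step size to the next stage
        (; Q, κ, ϵ) = timing.value.final_warmup_state
        initialization = (; q = Q.q, κ, ϵ)
    end
    return initialization, stage_times
end

ℓ = StandardMultivariateNormal(5)
∇ℓ = ADgradient(:ForwardDiff, ℓ)
rng = Random.default_rng()

init, times = timed_warmup_stages(rng, ∇ℓ)
for (i, t) in enumerate(times)
    println("warmup stage $i: ", round(t; digits = 3), " s")
end
```

Note that the first run in a fresh session also measures JIT compilation, so for steady-state numbers it is worth calling the function once before timing it.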

sethaxen commented 1 year ago

Thanks! This seems to work well! Here is how I separately time the entire warm-up phase and the sampling phase:

using DynamicHMC, Random

# turn a final warmup state into an `initialization` for the next call
function extract_initialization(state)
    (; Q, κ, ϵ) = state
    return (; q=Q.q, κ, ϵ)
end

# run the warmup stages one at a time, feeding each stage's final state into the
# next, and return the initialization produced by the last stage
function dhmc_warmup(
    rng::Random.AbstractRNG,
    ℓ;
    initialization=(),
    warmup_stages=DynamicHMC.default_warmup_stages(),
    kwargs...,
)
    initialization_final = foldl(warmup_stages; init=initialization) do init, stage
        result = DynamicHMC.mcmc_keep_warmup(
            rng, ℓ, 0; warmup_stages=(stage,), initialization=init, kwargs...
        )
        return extract_initialization(result.final_warmup_state)
    end
    return initialization_final
end

# draw `ndraws` posterior samples with no warmup, starting from `initialization`
function dhmc_sample(rng::Random.AbstractRNG, ℓ, ndraws; initialization, kwargs...)
    return DynamicHMC.mcmc_with_warmup(
        rng, ℓ, ndraws; warmup_stages=(), initialization, kwargs...
    )
end

Then I call each of these functions with @timed.
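A self-contained sketch of the same two-phase timing with `@timed` is below. As a simplification, it runs all default warmup stages in a single `mcmc_keep_warmup` call rather than folding over them one stage at a time, then samples with `warmup_stages = ()` the way `dhmc_sample` above does. The test density, the draw count of 1000, and the variable names are illustrative assumptions.

```julia
using DynamicHMC, LogDensityTestSuite, LogDensityProblemsAD, ForwardDiff, Random

ℓ = StandardMultivariateNormal(5)
∇ℓ = ADgradient(:ForwardDiff, ℓ)
rng = Random.default_rng()

# Warmup phase: run all default warmup stages and keep 0 posterior draws.
warmup_timing = @timed DynamicHMC.mcmc_keep_warmup(rng, ∇ℓ, 0)

# Build the initialization for sampling from the final warmup state.
(; Q, κ, ϵ) = warmup_timing.value.final_warmup_state
initialization = (; q = Q.q, κ, ϵ)

# Sampling phase: no warmup stages, start from the adapted state.
sampling_timing = @timed DynamicHMC.mcmc_with_warmup(rng, ∇ℓ, 1000;
                                                     initialization,
                                                     warmup_stages = ())

println("warmup:   ", round(warmup_timing.time; digits = 3), " s")
println("sampling: ", round(sampling_timing.time; digits = 3), " s")
```

Besides `time`, the value returned by `@timed` also carries `bytes` and `gctime`, which can help separate allocation overhead from sampling cost.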