TuringLang / AdvancedHMC.jl

Robust, modular and efficient implementation of advanced Hamiltonian Monte Carlo algorithms
https://turinglang.org/AdvancedHMC.jl/
MIT License

Don't pass `Hamiltonian` around, but instead `metric` and `DifferentiableDensityModel`. #273

Open torfjelde opened 3 years ago

torfjelde commented 3 years ago

A lot of the code in AHMC "unnecessarily" requires a `Hamiltonian`, i.e. a struct containing `metric`, `ℓπ`, and `∂ℓπ∂θ`, to be passed around. This is a bit awkward for the following reasons:

IMO it's more composable if we instead pass around `metric` and `DifferentiableDensityModel` (which contains the latter two), or, more generally, an `AbstractMCMC.AbstractModel`, and construct the `Hamiltonian` only when needed, e.g. at the call-site of `AdvancedHMC.step`. We could maybe even add a default implementation for `DifferentiableDensityModel`:

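```julia
using AdvancedHMC: AbstractIntegrator, AbstractMetric, DifferentiableDensityModel, Hamiltonian, PhasePoint
import AdvancedHMC: step  # extend AHMC's integrator `step` rather than shadow it

function step(int::AbstractIntegrator, model::DifferentiableDensityModel, metric::AbstractMetric, z::P, n_steps::Int=1; kwargs...) where {P<:PhasePoint}
    h = Hamiltonian(model, metric)  # proposed (model, metric) constructor; would need to be added
    return step(int, h, z, n_steps; kwargs...)
end
```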
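To make the proposal concrete, here is a minimal, self-contained sketch in plain Julia with no AdvancedHMC dependency. Every name in it (`DensityModel`, `DiagMetric`, `Leapfrog`, `hstep`, `energy`) is a hypothetical stand-in that only mirrors the role of the corresponding AdvancedHMC type; it is not the real API. `hstep` stands in for `AdvancedHMC.step` and is renamed only to avoid clashing with `Base.step`. The Hamiltonian-based method stays the core, and a thin (model, metric) method builds the `Hamiltonian` at the call-site and delegates to it.

```julia
using LinearAlgebra: dot

# Hypothetical stand-ins that mirror the roles of AdvancedHMC's types.
# None of these are the real AdvancedHMC API.
abstract type AbstractModel end
abstract type AbstractMetric end
abstract type AbstractIntegrator end

# Plays the role of `DifferentiableDensityModel`: log density and its gradient.
struct DensityModel{F,G} <: AbstractModel
    ℓπ::F
    ∂ℓπ∂θ::G
end

# Diagonal Euclidean metric, storing the diagonal of the inverse mass matrix.
struct DiagMetric <: AbstractMetric
    M⁻¹::Vector{Float64}
end

# The bundle of metric, ℓπ and ∂ℓπ∂θ that would only be built when needed.
struct Hamiltonian{M<:AbstractMetric,F,G}
    metric::M
    ℓπ::F
    ∂ℓπ∂θ::G
end

# Call-site constructor from (model, metric), as suggested in the issue.
Hamiltonian(model::DensityModel, metric::AbstractMetric) =
    Hamiltonian(metric, model.ℓπ, model.∂ℓπ∂θ)

# H(θ, r) = -ℓπ(θ) + rᵀ M⁻¹ r / 2
energy(h::Hamiltonian, θ, r) = -h.ℓπ(θ) + 0.5 * dot(r, h.metric.M⁻¹ .* r)

struct Leapfrog <: AbstractIntegrator
    ϵ::Float64
end

# Core method: takes a fully built Hamiltonian and runs n_steps leapfrog steps.
function hstep(lf::Leapfrog, h::Hamiltonian, θ::Vector{Float64}, r::Vector{Float64}, n_steps::Int=1)
    ϵ = lf.ϵ
    for _ in 1:n_steps
        r = r .+ (ϵ / 2) .* h.∂ℓπ∂θ(θ)    # half step in momentum
        θ = θ .+ ϵ .* (h.metric.M⁻¹ .* r)  # full step in position
        r = r .+ (ϵ / 2) .* h.∂ℓπ∂θ(θ)    # half step in momentum
    end
    return θ, r
end

# Thin convenience method: accept model + metric, build the Hamiltonian, delegate.
function hstep(lf::Leapfrog, model::AbstractModel, metric::AbstractMetric,
               θ::Vector{Float64}, r::Vector{Float64}, n_steps::Int=1)
    h = Hamiltonian(model, metric)
    return hstep(lf, h, θ, r, n_steps)
end
```

With a standard-normal target, `hstep(lf, model, metric, θ, r, 10)` and `hstep(lf, Hamiltonian(model, metric), θ, r, 10)` return the same result; the first simply spares callers from assembling the `Hamiltonian` themselves, and supporting another model type only requires a new `Hamiltonian(model, metric)` method.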

This would allow for further extensions down the road, e.g. allowing some other `AbstractModel` to be used. It would also mean that we can share more code between the implementation of the AbstractMCMC.jl interface and the "standard" interface of AHMC.