TuringLang / AdvancedHMC.jl

Robust, modular and efficient implementation of advanced Hamiltonian Monte Carlo algorithms
https://turinglang.org/AdvancedHMC.jl/
MIT License

No glue code #319

Closed JaimeRZP closed 1 year ago

JaimeRZP commented 1 year ago

Hi All!

My name is Jaime Ruiz Zapatero and I am a 3rd-year PhD student in astrophysics at Oxford.

@yebai and I have been discussing possible ways of interfacing Turing and AdvancedHMC in a more general manner, without the glue code currently present inside Turing. This is motivated by the limitations of the current interface, already discussed in this PR by @sethaxen. The fundamental idea is to extract a LogDensityProblems object from a generic Turing model, which is then used to build a Hamiltonian object for AdvancedHMC. This can be done as follows:

ctxt = model.context
vi = DynamicPPL.VarInfo(model, ctxt)
# Wrap the model's log density as a LogDensityProblems object with AD gradients
ℓ = LogDensityProblemsAD.ADgradient(DynamicPPL.LogDensityFunction(vi, model, ctxt))
# Combine the target density with a kinetic energy specified by the metric
hamiltonian = AdvancedHMC.Hamiltonian(metric, ℓ)

where model is a conditioned Turing model.
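
For context, this hamiltonian then plugs straight into AdvancedHMC's usual sampling loop, along the lines of the package README (a sketch; initial_θ, n_samples and n_adapts are assumed to be defined by the user):

# Find a reasonable initial step size and build a NUTS kernel
initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
integrator = Leapfrog(initial_ϵ)
kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))

# Draw n_samples draws, adapting during the first n_adapts iterations
samples, stats = sample(hamiltonian, kernel, initial_θ, n_samples, adaptor, n_adapts)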

I have written a Neal's funnel example of how this interface can work, which you can find in this notebook. There I have also drafted some placeholder functions to wrap the interface so that it is user-friendly. Essentially, I have added an additional signature to sample and created a wrapper structure for the sampler ingredients (a sketch is given below). To do so I have tried to follow the idea proposed in this issue by @yebai.
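
To make the wrapper idea concrete, a minimal sketch of what I mean (HMCIngredients and the new sample method are placeholder names I am making up here, not an existing API):

# Hypothetical container bundling the sampler ingredients
struct HMCIngredients{M,K,A}
    metric::M     # e.g. DiagEuclideanMetric(d)
    kernel::K     # e.g. an HMCKernel with a NUTS trajectory
    adaptor::A    # e.g. StanHMCAdaptor(...)
end

# Hypothetical additional signature: sample a conditioned Turing model directly
function AdvancedHMC.sample(model::DynamicPPL.Model, ing::HMCIngredients, initial_θ, n_samples, n_adapts)
    ctxt = model.context
    vi = DynamicPPL.VarInfo(model, ctxt)
    ℓ = LogDensityProblemsAD.ADgradient(DynamicPPL.LogDensityFunction(vi, model, ctxt))
    h = AdvancedHMC.Hamiltonian(ing.metric, ℓ)
    return AdvancedHMC.sample(h, ing.kernel, initial_θ, n_samples, ing.adaptor, n_adapts)
end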

This example is already fully functional. However, it is missing one important aspect: Turing models can have priors with hard boundaries, whereas HMC samplers work best in unconstrained continuous spaces. The solution is to map the bounded prior space to the unbounded sampling space. Turing does this internally, but to avoid relying too much on Turing one can also do the following:

ctxt = model.context
vi = DynamicPPL.VarInfo(model, ctxt)
vi_t = Turing.link!!(vi, model) # This transforms the variables to the unbounded space

# By passing vi_t as opposed to vi to ℓ, the log density function will input and output the transformed variables
ℓ = LogDensityProblemsAD.ADgradient(DynamicPPL.LogDensityFunction(vi_t, model, ctxt))
hamiltonian = AdvancedHMC.Hamiltonian(metric, ℓ)
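
The sampler's initial point should then also live in the unbounded space; if I am not mistaken, it can be read off the linked VarInfo directly:

# Flatten the linked VarInfo into a vector: an initial point in the unbounded space
initial_θ = vi_t[:]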

Then the generated samples can be transformed back to the prior space using:


# Collect the prior distribution of each variable from the VarInfo metadata
function _get_dists(vi)
    mds = values(vi.metadata)
    return [md.dists[1] for md in mds]
end

dists = _get_dists(vi)
dist_lengths = [length(dist) for dist in dists]
vsyms = _name_variables(vi, dist_lengths)  # another helper drafted in the notebook

# Split a flat parameter vector into one block per variable
function _reshape_params(x::AbstractVector)
    xx = []
    idx = 0
    for dist_length in dist_lengths
        push!(xx, x[idx+1:idx+dist_length])
        idx += dist_length
    end
    return xx
end

# Map parameters from the bounded prior space to the unbounded sampling space
function transform(x)
    x = _reshape_params(x)
    xt = [Bijectors.link(dist, par) for (dist, par) in zip(dists, x)]
    return vcat(xt...)
end

# Map parameters from the unbounded sampling space back to the bounded prior space
function inv_transform(xt)
    xt = _reshape_params(xt)
    x = [Bijectors.invlink(dist, par) for (dist, par) in zip(dists, xt)]
    return vcat(x...)
end

Note that the Jacobian of the transformation is already accounted for, since the log density is evaluated with the linked vi_t. Also, this might not be the most elegant way of doing the transformation.
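
For example, draws produced in the unbounded space (the samples vector returned by sample above) can be mapped back via:

# Map each unbounded draw back to the original, bounded parameter space
constrained_samples = [inv_transform(s) for s in samples]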

Question: Where would you suggest writing these functions in the code base? My guess would be within the sampler, but I don't want to mess with the design.

Question: Where would you suggest incorporating the transformation functions?

Once we have interfaced with the internal sampling method, we should also be able to use AbstractMCMC to do the sampling, following what Turing does. I have already written something like this for a micro-canonical HMC sampler I have been working on, MicrocanonicalHMC.jl. The most involved step is to overload the AbstractMCMC.step function with an AdvancedHMC equivalent, roughly as sketched below.
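
Roughly, what I have in mind (HMCSampler and HMCState are hypothetical names, and AdvancedHMC.transition / AdvancedHMC.adapt! are internal functions whose exact signatures may differ between versions):

using AbstractMCMC, AdvancedHMC, Random

# Hypothetical sampler wrapper carrying the adaptor
struct HMCSampler{A} <: AbstractMCMC.AbstractSampler
    adaptor::A
end

# Hypothetical state carried between AbstractMCMC.step calls
struct HMCState{T,H,K}
    i::Int           # iteration counter, needed for adaptation
    transition::T    # last AdvancedHMC.Transition (holds the phase point)
    hamiltonian::H
    kernel::K
end

function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    sampler::HMCSampler,
    state::HMCState;
    n_adapts::Int=0,
    kwargs...,
)
    i = state.i + 1
    # Simulate one Hamiltonian trajectory from the current phase point
    t = AdvancedHMC.transition(rng, state.hamiltonian, state.kernel, state.transition.z)
    # Adapt the step size and mass matrix during the warm-up phase
    h, kernel, _ = AdvancedHMC.adapt!(
        state.hamiltonian, state.kernel, sampler.adaptor,
        i, n_adapts, t.z.θ, t.stat.acceptance_rate,
    )
    return t, HMCState(i, t, h, kernel)
end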

All the best, Jaime

JaimeRZP commented 1 year ago

I have now coded the changes into the AdvancedHMC src. Users should be able to sample a Turing model with AdvancedHMC directly using:

n_samples, n_adapts = 10_000, 1_000
sample(model, metric, proposal, initial_θ, n_samples, adaptor, n_adapts)

or, even simpler,

samples, stats = sample(model, 0.1, 0.95, n_samples, n_adapts; initial_θ=initial_θ)

which will use NUTS by default.

The unbounded-to-bounded space transforms are still missing.