TuringLang / AdvancedHMC.jl

Robust, modular and efficient implementation of advanced Hamiltonian Monte Carlo algorithms
https://turinglang.org/AdvancedHMC.jl/
MIT License

Implement Riemannian HMC #305

Closed: xukai92 closed this 1 year ago

xukai92 commented 1 year ago

This is a draft implementation of Riemannian HMC. There is a lot to discuss, so I put the high-level points here and left the more specific ones as comments in the code.

To-dos

How to play with this PR

I provided a notebook (with the same content as the test/experimental/riemannian_hmc.jl file) for playing with the code. The notebook includes some simple validation of the implementation and also demonstrates the current numerical issue with the SoftAbs metric [1]. I highly recommend trying it.


[1] Betancourt, M., 2013, August. A general metric for Riemannian manifold Hamiltonian Monte Carlo. In International Conference on Geometric Science of Information (pp. 327-334). Springer, Berlin, Heidelberg.

torfjelde commented 1 year ago

One thing that is worth discussing wrt. this PR is how much of this functionality we should put into the metric types themselves vs. just implementing it all as a change of variables.

If we instead treat these position-dependent metrics as a change of variables for the potential, we don't have to deal with non-separable Hamiltonians internally, which should hopefully make the internals simpler.

Are there any obvious downsides to this?
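For context, this is why position-dependent metrics make the Hamiltonian non-separable. With target density π(θ), potential U(θ) = −log π(θ) and metric G(θ), the standard Riemannian-manifold HMC Hamiltonian (up to an additive constant) and its position gradient are:

```math
H(\theta, r) = U(\theta) + \tfrac{1}{2}\log\det G(\theta) + \tfrac{1}{2}\, r^\top G(\theta)^{-1} r,
\qquad
\frac{\partial H}{\partial \theta_i} = \frac{\partial U}{\partial \theta_i}
  + \tfrac{1}{2}\operatorname{tr}\!\Big(G^{-1}\frac{\partial G}{\partial \theta_i}\Big)
  - \tfrac{1}{2}\, r^\top G^{-1}\frac{\partial G}{\partial \theta_i} G^{-1} r .
```

The last term couples θ and r, so the momentum update depends on r itself and the explicit leapfrog scheme no longer applies. For a constant metric G(θ) ≡ M, both the log-determinant term and the r-dependent part of ∂H/∂θ drop out, and H is separable again.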

xukai92 commented 1 year ago

I also think so, as I put a note here: https://github.com/TuringLang/AdvancedHMC.jl/blob/master/src/metric.jl#L116. Ideally we could interface with Bijectors.jl here and support more than just affine transformations.
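As a concrete instance of the affine case: for a constant dense metric M = LLᵀ (for example with L a Cholesky factor), HMC on θ with kinetic energy ½rᵀM⁻¹r is equivalent to identity-metric HMC on z = Lᵀθ with momentum q = L⁻¹r, because

```math
\theta = L^{-\top} z,\quad r = L q
\;\;\Longrightarrow\;\;
\tfrac{1}{2}\, r^\top M^{-1} r = \tfrac{1}{2}\, q^\top q,
\qquad
\nabla_z U(L^{-\top} z) = L^{-1} \nabla_\theta U(\theta),
\qquad
\dot z = L^\top M^{-1} r = q .
```

so each leapfrog step in (z, q) maps exactly onto one in (θ, r). A position-dependent G(θ) generally cannot be absorbed by a single fixed linear map, which is why going beyond the affine case needs a more general bijection.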

That said, the best route to implementing this in AHMC is probably to do an internal refactoring PR first and then build on it here. I expect that would require quite a bit of discussion.

Besides, I think most of the callback changes introduced in this PR to handle non-separable Hamiltonians would still be needed there.

torfjelde commented 1 year ago

> I also think so as I put a note here

Ah, didn't catch that; cool! Btw, are we aware of any downsides to just doing a reparameterization for something like RHMC, given that the original paper didn't do this but instead went the route of the generalized leapfrog, which requires a root-solve to determine the implicitly defined step?

xukai92 commented 1 year ago

I don't quite get the question. What the original paper does can also be seen as a form of reparameterization, and we are doing the same thing together with the generalized leapfrog, whose implicit updates are solved by fixed-point iterations.
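A minimal sketch of one generalized leapfrog step whose two implicit equations are solved by fixed-point iteration. It assumes the caller supplies hypothetical functions `∂H∂θ(θ, r)` and `∂H∂r(θ, r)` for a possibly non-separable Hamiltonian and a fixed iteration count `n_fixed_point`; it is not the code in this PR, and it does not check convergence (the iterations only contract for small enough ϵ).

```julia
# Hypothetical sketch (not this PR's code): one generalized leapfrog step for a
# possibly non-separable Hamiltonian H(θ, r). `∂H∂θ` and `∂H∂r` are user-supplied
# functions (θ, r) -> gradient vector; both implicit equations are solved with a
# fixed number of fixed-point iterations (no convergence check).
function generalized_leapfrog(∂H∂θ, ∂H∂r, θ, r, ϵ; n_fixed_point=10)
    # Implicit half step: r_half = r - ϵ/2 * ∂H/∂θ(θ, r_half)
    r_half = copy(r)
    for _ in 1:n_fixed_point
        r_half = r .- ϵ / 2 .* ∂H∂θ(θ, r_half)
    end
    # Implicit full step: θ_new = θ + ϵ/2 * (∂H/∂r(θ, r_half) + ∂H/∂r(θ_new, r_half))
    v = ∂H∂r(θ, r_half)
    θ_new = copy(θ)
    for _ in 1:n_fixed_point
        θ_new = θ .+ ϵ / 2 .* (v .+ ∂H∂r(θ_new, r_half))
    end
    # Explicit half step: r_new = r_half - ϵ/2 * ∂H/∂θ(θ_new, r_half)
    r_new = r_half .- ϵ / 2 .* ∂H∂θ(θ_new, r_half)
    return θ_new, r_new
end

# Usage: for the separable H(θ, r) = θ'θ/2 + r'r/2 each fixed-point loop is exact
# after one iteration, so the step coincides with standard leapfrog.
θ1, r1 = generalized_leapfrog((θ, r) -> θ, (θ, r) -> r, [1.0, 0.0], [0.0, 1.0], 0.1)
# θ1 ≈ [0.995, 0.1], r1 ≈ [-0.09975, 0.995]
```

Reducing to standard leapfrog in the separable case is a cheap sanity check for an implementation like this before trying a position-dependent metric.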

xukai92 commented 1 year ago

Updates as of ed15539

  1. Introduced types to control how the Hessian is mapped (https://github.com/TuringLang/AdvancedHMC.jl/blob/kx/rhmc-draft/research/src/riemannian_hmc.jl#L136-L155). This gives an easy way to access internal parameters of the SoftAbs map and to specialize the internals on the map type (IdentityMap or SoftAbsMap).
    • PS: I made this choice mostly for ease of implementation, and I'm happy to consider other designs.
  2. Fixed a bug in constructing J (https://github.com/TuringLang/AdvancedHMC.jl/blob/kx/rhmc-draft/research/src/riemannian_hmc.jl#L258) by adding a missing coefficient to the csch term.
  3. Removed the incorrect handling of numerical issues; any such handling is now left to users (e.g. https://github.com/TuringLang/AdvancedHMC.jl/blob/kx/rhmc-draft/research/tests/riemannian_hmc.jl#L40).
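  4. Performed a Geweke test for RHMC with MCMCDebugging.jl, which indicates that there are still bugs in the RHMC sampler with the SoftAbs metric:
    [image: Geweke test results for RHMC with the SoftAbs metric]
    The test uses the model below:
    ```julia
    using Turing

    # Funnel-shaped model: θ[1] sets the scale of both θ[2] and x.
    @model function TuringFunnel(θ=missing, x=missing)
        if ismissing(θ)
            # θ not given: allocate it so both entries are sampled
            θ = Vector(undef, 2)
        end
        θ[1] ~ Normal(0, 3)
        s = exp(θ[1] / 2)
        θ[2] ~ Normal(0, s)
        x ~ Normal(0, s)
        return θ, x
    end
    ```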

PS: For reference, the test results for a standard HMC sampler on the same model are below:
[image: Geweke test results for standard HMC]

I'm looking into the code to find the bug, but if anyone has ideas for debugging it or intuition about where it might be, please let me know.
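For anyone helping debug, here is a minimal sketch of the SoftAbs map from [1] as we understand it: eigendecompose the Hessian and replace each eigenvalue λ by λ·coth(αλ), which is bounded below by 1/α and approaches |λ| as |αλ| grows. The matrix J is taken here to be the divided differences of that map, with the derivative coth(αλ) − αλ·csch²(αλ) on the diagonal. The function names, the return tuple, and the small-|αλ| series cutoff (1e-4) are illustrative assumptions, not the PR's `SoftAbsMap` internals.

```julia
using LinearAlgebra

# Hypothetical sketch of the SoftAbs map (names are illustrative). Each Hessian
# eigenvalue λ is replaced by λ coth(αλ), which is ≥ 1/α and tends to |λ|.
function softabs(λ, α)
    x = α * λ
    # λ coth(αλ) → 1/α as λ → 0, but 0 * Inf = NaN in floating point,
    # so use the series 1/α + αλ²/3 for small |αλ|
    return abs(x) < 1e-4 ? inv(α) + α * λ^2 / 3 : λ * coth(x)
end

# d/dλ [λ coth(αλ)] = coth(αλ) - αλ csch²(αλ); its series near 0 is 2αλ/3
function dsoftabs(λ, α)
    x = α * λ
    return abs(x) < 1e-4 ? 2x / 3 : coth(x) - x * csch(x)^2
end

function softabs_metric(H::AbstractMatrix, α::Real)
    F = eigen(Symmetric(H))
    λ, Q = F.values, F.vectors
    softλ = softabs.(λ, α)
    G = Q * Diagonal(softλ) * Q'   # positive-definite metric
    n = length(λ)
    J = similar(λ, n, n)
    for i in 1:n, j in 1:n
        if λ[i] == λ[j]
            # exact ties (including the diagonal) use the derivative
            J[i, j] = dsoftabs(λ[i], α)
        else
            # divided difference; ill-conditioned for close but unequal λ
            J[i, j] = (softλ[i] - softλ[j]) / (λ[i] - λ[j])
        end
    end
    return G, Q, softλ, J
end

# With a large α the map is ≈ |λ|, so this G is ≈ [2 0; 0 3]
G, _, softλ, _ = softabs_metric([2.0 0.0; 0.0 -3.0], 1e6)
```

The series branches matter because a naive λ·coth(αλ) returns NaN at λ = 0, which is one way numerical issues like the one mentioned for SoftAbs can surface.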

xukai92 commented 1 year ago

Notes from the meeting with Hong on 23 Feb

Potential optimization

Tests

yebai commented 1 year ago

@xukai92 I did some minor refactoring. I am happy to merge this PR as-is and revisit some design issues (e.g. ∂H∂θ_cache) when the research is complete.