tpapp / DynamicHMC.jl

Implementation of robust dynamic Hamiltonian Monte Carlo methods (NUTS) in Julia.

Implement fix for NUTS missing u-turns #115

Closed: sethaxen closed this issue 3 years ago

sethaxen commented 4 years ago

A recent discussion on the Stan Discourse resulted in this PR in Stan and this PR in PyMC3. Both change the NUTS termination criterion to handle a previously undiscovered case in which it failed to catch a U-turn. Has this been implemented here yet? (I haven't yet looked into what the fix entails.)

Self-contained example that demonstrates the problem

This post gave an example model that exhibits the failure. I've reproduced it here:

using DynamicHMC, LogDensityProblems, Distributions, Random
import LogDensityProblems: capabilities, dimension, logdensity, logdensity_and_gradient

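# N-dimensional standard normal target; the dimension is a type parameter
struct StdNormalProblem{N} end

StdNormalProblem(N::Int) = StdNormalProblem{N}()
# log density: sum of independent standard normal log pdfs
(p::StdNormalProblem)(θ) = sum(logpdf.(Normal(), θ))
capabilities(::Type{<:StdNormalProblem}) = LogDensityProblems.LogDensityOrder{1}()
dimension(p::StdNormalProblem{N}) where {N} = N
logdensity(p::StdNormalProblem, θ) = p(θ)
# the gradient of the standard normal log density is -θ
logdensity_and_gradient(p::StdNormalProblem, θ) = (logdensity(p, θ), -θ)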

max_depth = 12
rng = MersenneTwister(13)
# 20 independent runs of 1000 draws each; count the trees that reach max_depth
nsat = sum(1:20) do _
    results = mcmc_with_warmup(
        rng,
        StdNormalProblem(200),
        1000;
        algorithm = DynamicHMC.NUTS(max_depth = max_depth),
        reporter = NoProgressReport(),
    )
    # number of draws in this run whose tree depth saturated
    sum(getfield.(results.tree_statistics, :depth) .≥ max_depth)
end

On my machine, nsat is 9, i.e. 9/20000 (0.045%) of the trajectories reached the maximum tree depth without detecting a U-turn. That is comparable to the rate in the Stan example before the fix was merged (0.12%).
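For anyone comparing numbers across machines or seeds, the percentage above can be computed as in this minimal sketch. The helper name `saturated_fraction` is made up for illustration and is not part of DynamicHMC; it only assumes a vector of tree depths like the one extracted from `results.tree_statistics` above.

```julia
# Hypothetical helper (not part of DynamicHMC): fraction of draws whose
# tree depth reached max_depth.
saturated_fraction(depths, max_depth) = count(≥(max_depth), depths) / length(depths)

# Arithmetic behind the figure quoted above: 20 runs × 1000 draws, 9 saturated.
n_runs, n_draws, nsat = 20, 1000, 9
percent = 100 * nsat / (n_runs * n_draws)   # ≈ 0.045

# Toy depth vector, just to show the helper in use: 1 of 4 draws saturates.
toy_fraction = saturated_fraction([3, 5, 12, 4], 12)
```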

tpapp commented 4 years ago

Thanks for bringing this up, and the very thorough issue report. It is on my radar and I plan to fix this. Feel free to ping me if I don't get to it before January.

tpapp commented 3 years ago

@sethaxen: apologies that this took such a long time; the fix is simple, but I wanted to understand the math first. Thanks again for suggesting this.
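For readers landing here later, below is a rough sketch of the kind of change involved, as far as I understand the Stan fix. It is not the code merged in DynamicHMC.jl. When two adjacent subtrees are merged, the U-turn criterion is checked not only across the whole merged trajectory but also across the junction between the subtrees. The names `Span`, `is_turning`, and `is_turning_merged` are hypothetical, and a unit metric is assumed so that p♯ = p.

```julia
using LinearAlgebra: dot

# Hypothetical summary of a (sub)trajectory; unit metric assumed (p♯ = p).
struct Span
    p_first::Vector{Float64}  # momentum at the earliest point
    p_last::Vector{Float64}   # momentum at the latest point
    ρ::Vector{Float64}        # sum of the momenta over all points
end

# Generalized criterion: turning if either end momentum points against ρ.
is_turning(p_first, p_last, ρ) = dot(p_first, ρ) < 0 || dot(p_last, ρ) < 0

# Merge check with the two extra cross-subtree checks (assumed form of the fix).
function is_turning_merged(left::Span, right::Span)
    # original check: across the whole merged trajectory
    is_turning(left.p_first, right.p_last, left.ρ + right.ρ) ||
        # extra check: left subtree plus the first point of the right subtree
        is_turning(left.p_first, right.p_first, left.ρ + right.p_first) ||
        # extra check: last point of the left subtree plus the right subtree
        is_turning(left.p_last, right.p_last, right.ρ + left.p_last)
end

# Artificial example where only an extra check fires.
left = Span([1.0, 0.0], [1.0, 0.0], [2.0, 0.0])
right = Span([-1.0, 0.0], [3.0, 0.0], [2.0, 0.0])
whole_only = is_turning(left.p_first, right.p_last, left.ρ + right.ρ)
with_extra = is_turning_merged(left, right)
```

In this toy case the end-to-end check alone misses the reversal at the junction, while the merged check with the extra terms catches it. The momenta are made up and do not come from a real trajectory.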