TuringLang / AdvancedHMC.jl

Robust, modular and efficient implementation of advanced Hamiltonian Monte Carlo algorithms
https://turinglang.org/AdvancedHMC.jl/
MIT License

Draft of Nutpie/Nuts-rs mass matrix adaption #312

Open · svilupp opened this pull request 1 year ago

svilupp commented 1 year ago

NOT READY!

Quick draft of the new mass matrix adaptation being tested/used by the PyMC team, posted for discussion. Attempt at #311.

Notes:

It's a draft -- my VS Code made several formatting changes that I'd need to unwind, and I need to take out the changes to Project.toml (it was only used for running the examples)

svilupp commented 1 year ago

I did a quick benchmark on a simple linear regression model and I don't see much difference in how quickly it adapts (perhaps the problem is too easy to show any difference).

These are the first 100 tuning draws and the values that were sampled. The hypothesis was that if Nutpie is better, it would hone in on the right values faster, but as you can see it takes only c. 10 draws to get there for both algorithms -- i.e., hard to compare. (See the attached comparison.png.)

# Example for Nuts-rs / Nutpie Adaptor
using AdvancedHMC, ForwardDiff
using LinearAlgebra
using Distributions
using Plots
const A = AdvancedHMC
using LogDensityProblems, TransformVariables, TransformedLogDensities, Parameters
using AbstractMCMC: LogDensityModel

# Example taken from https://www.tamaspapp.eu/DynamicHMCExamples.jl/latest/example_linear_regression/
"""
Linear regression model ``y ∼ Xβ + ϵ``, where ``ϵ ∼ N(0, σ²)`` IID.

Weakly informative prior for `β`, half-T for `σ`.
"""
struct LinearRegressionProblem{TY<:AbstractVector,TX<:AbstractMatrix,
    Tν<:Real}
    "Observations."
    y::TY
    "Covariates"
    X::TX
    "Degrees of freedom for prior."
    ν::Tν
end
function (problem::LinearRegressionProblem)(θ)
    @unpack y, X, ν = problem   # extract the data
    @unpack β, σ = θ            # works on the named tuple too
    ϵ_distribution = Normal(0, σ) # the error term
    ℓ_error = mapreduce((y, x) -> logpdf(ϵ_distribution, y - dot(x, β)), +,
        y, eachrow(X))    # likelihood for error
    ℓ_σ = logpdf(TDist(ν), σ)             # prior for σ
    ℓ_β = loglikelihood(Normal(0, 10), β) # prior for β
    ℓ_error + ℓ_σ + ℓ_β
end
# Random data
n_samples, n_adapts = 2_000, 1_000
N = 100
β = [-10.0, 2.0, -1.0, -0.3, 3, 0.1, 8]
σ = 2
X = hcat(ones(N), randn(N, length(β) - 1));
y = X * β .+ randn(N) .* σ;
p = LinearRegressionProblem(y, X, 1.0);
@info string("log density: ", p((β=β, σ=σ)))
# Transform it to unconstrained space
function problem_transformation(p::LinearRegressionProblem)
    as((β=as(Array, size(p.X, 2)), σ=asℝ₊))
end
t = problem_transformation(p)
P = TransformedLogDensity(t, p)
initial_θ = ones(length(β) + 1)
metric = DiagEuclideanMetric(length(initial_θ))

# Sampling w Default
hamiltonian = Hamiltonian(metric, P, ForwardDiff)
initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
integrator = Leapfrog(initial_ϵ)
proposal = NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator)
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))
# sample(LogDensityModel(P), proposal, metric, adaptor, 100)
@time samples1, stats1 = sample(hamiltonian, proposal, initial_θ, n_samples, adaptor, n_adapts; drop_warmup=false, progress=true);

# NUTPIE
# https://github.com/pymc-devs/nuts-rs/blob/main/src/adapt_strategy.rs
hamiltonian = Hamiltonian(metric, P, ForwardDiff)
initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
integrator = Leapfrog(initial_ϵ)
proposal = NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator)
pc = A.ExpWeightedWelfordVar(size(metric))
adaptor = A.NutpieHMCAdaptor(pc, StepSizeAdaptor(0.8, integrator))
@time samples2, stats2 = sample(hamiltonian, proposal, initial_θ, n_samples, adaptor, n_adapts; drop_warmup=false, progress=true);

# # Plots
# plot how many tuning draws it takes to get to the real values
compare_coefficient = let samples1 = samples1, samples2 = samples2, β = β, σ = σ
    function plotter(idx)
        pl = plot(getindex.(samples1, idx), xlim=(0, 100), label="Default", color=palette(:default)[1])
        plot!(pl, getindex.(samples2, idx), xlim=(0, 100), label="Nutpie/Nuts-rs", color=palette(:default)[2])
        if idx <= length(β)
            hline!(pl, β[[idx]], color=:red, linestyle=:dash, label="True value")
        else # if it's not beta, it must be sigma
            hline!(pl, [log(σ)], color=:red, linestyle=:dash, label="True value (unconstrained σ)") # σ is sampled as log(σ) under the asℝ₊ transform
        end
        return pl
    end
end
k = 8
pl = plot([compare_coefficient(i) for i in 1:k]..., layout=(k, 1), size=(600, 400 + k * 100),
    plot_title="Comparison of components", legend=:bottomright)
savefig(pl, "comparison.png")
aseyboldt commented 1 year ago

From a very quick look:

It's translated from Rust; a few of the choices could be challenged (e.g. the update frequencies and window sizes, but also the fact that the diagonal mass matrix is updated on every iteration).

If you experiment with those, I'd love to see the results. :-)

I don't think the cost of updating the mass matrix will matter that much in real-world problems, however. It still only happens once per draw, i.e. every couple of gradient evals, and gradient evals are usually significantly more expensive than a diagonal mass-matrix update. Updating the mass matrix in every iteration does, however, lead to a bit of trouble: since we do it every time, we strictly speaking break the MCMC sampler, which can slightly bias the results, which can then bias the mass matrix estimate, which in turn makes sampling efficiency worse after tuning. It seemed to me that the advantages outweigh the costs, but I don't actually know for sure.
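
To give a feel for the cost: a diagonal estimator only carries a per-dimension running mean and variance, so each update is O(D). Below is a minimal sketch of an exponentially weighted running-variance update in that spirit; the names `ExpWeightedVar` and `update_variance!` and the decay factor are illustrative, not the PR's actual `ExpWeightedWelfordVar` API.

# Illustrative sketch, not the PR's API: an exponentially weighted running
# mean/variance. Updating a diagonal mass matrix like this is O(D) per draw,
# i.e. cheap compared to a single gradient evaluation.
mutable struct ExpWeightedVar
    mean::Vector{Float64}    # running mean of the draws
    var::Vector{Float64}     # running variance estimate, usable as diag(M⁻¹)
    decay::Float64           # factor < 1 that down-weights old draws
end
ExpWeightedVar(D::Int; decay=0.95) = ExpWeightedVar(zeros(D), ones(D), decay)

function update_variance!(est::ExpWeightedVar, θ::AbstractVector)
    w = est.decay
    @. est.mean = w * est.mean + (1 - w) * θ
    @. est.var  = w * est.var  + (1 - w) * (θ - est.mean)^2
    return est.var
end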

aseyboldt commented 1 year ago

Oh, and I'm not sure, but doesn't this look biased? (See the attached figure.)

svilupp commented 1 year ago

Quick update:

Results: Poor! I'm unable to get Nutpie to sample properly.

Initial results for the Diamonds model below (incl. the code to reproduce the example)

[Figure: ESS per gradient evaluation, mean ESS, number of gradient evaluations, max Rhat, divergences, and elapsed time for each adaptor variant on the Diamonds model]

Code to reproduce the above chart

using Pkg;
using BridgeStan, AdvancedHMC, PosteriorDB, Random, StanLogDensityProblems, LogDensityProblems
using AdvancedHMC: WelfordVar, ExpWeightedWelfordVar, NutpieHMCAdaptor, NutpieHMCAdaptorNoSwitch, NutpieHMCAdaptorNoGradInit, NutpieHMCAdaptorNoSwitchNoGradInit
# using DataFramesMeta
using MCMCDiagnosticTools
using Statistics: mean # used in the result summaries below
using Folds
using StatsPlots

# # Benchmark setup
# Seth Axen's amazing wrapper based on https://github.com/mlcolab/PathfinderBenchmarks.jl/blob/main/src/dynamichmc.jl
# wrapper to count the number of function evaluations and gradient evaluations
mutable struct EvalCountingProblem{P}
    const prob::P
    num_evals::Int
    num_grad_evals::Int
end
EvalCountingProblem(prob) = EvalCountingProblem(prob, 0, 0)

function LogDensityProblems.capabilities(::Type{<:EvalCountingProblem{P}}) where {P}
    return LogDensityProblems.capabilities(P)
end

function LogDensityProblems.dimension(prob::EvalCountingProblem)
    return LogDensityProblems.dimension(prob.prob)
end

function LogDensityProblems.logdensity(prob::EvalCountingProblem, x)
    prob.num_evals += 1
    return LogDensityProblems.logdensity(prob.prob, x)
end

function LogDensityProblems.logdensity_and_gradient(prob::EvalCountingProblem, x)
    prob.num_grad_evals += 1
    return LogDensityProblems.logdensity_and_gradient(prob.prob, x)
end
zero!(prob::EvalCountingProblem) = (prob.num_evals = 0; prob.num_grad_evals = 0)
Base.show(io::IO, m::MIME"text/plain", prob::EvalCountingProblem) = (ioinner = IOBuffer(); show(ioinner, m, prob.prob); print(io, "EvalCountingProblem($(take!(ioinner)|>String))"))

function ess_rhat(x)
    ess, rhat_bulk = MCMCDiagnosticTools.ess_rhat_bulk(x; maxlag=typemax(Int))
    rhat_tail = MCMCDiagnosticTools.rhat_tail(x)
    rhat = max.(rhat_bulk, rhat_tail)
    return (; ess, rhat)
end

#####################################
### Estimation routines
function sample_one_chain_(adaptor_type::Type{TA},metric_adaptor_type::Type{TMA}, rng_chain, prob; n_samples=1000, n_adapts=200,verbose=false) where {TA, TMA}
    D = LogDensityProblems.dimension(prob)
    initial_θ = rand(rng_chain, D)
    metric = DiagEuclideanMetric(D)
    count_prob = EvalCountingProblem(prob)
    hamiltonian = Hamiltonian(metric, count_prob)
    initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
    integrator = Leapfrog(initial_ϵ)
    proposal = NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator)
    # clear the counter after the step size is found
    zero!(count_prob)
    # provide an adaptor of choice
    metric_adaptor = metric_adaptor_type(size(metric))
    adaptor = adaptor_type(metric_adaptor, StepSizeAdaptor(0.8, integrator))
    samples, stats = sample(rng_chain, hamiltonian, proposal, initial_θ, n_samples, adaptor, n_adapts; drop_warmup=false, progress=false,verbose)
    samples = mapreduce(permutedims, vcat, samples) |> x -> reshape(x, (n_samples, 1, D))
    return samples, stats, count_prob.num_evals, count_prob.num_grad_evals
end

function sample_one_chain(adaptor_type,metric_adaptor_type, seed, rng, prob; n_samples=1000, n_adapts=200, n_tries=100,verbose=false)
    rng_chain = deepcopy(rng)
    Random.seed!(rng_chain, seed)
    i = 0
    while i < n_tries
        try
            samples, stats, num_evals, num_grad_evals = sample_one_chain_(adaptor_type,metric_adaptor_type, rng_chain, prob; n_samples, n_adapts,verbose)
            return samples, stats, num_evals, num_grad_evals
        catch err
            # @warn err
            @warn "Failed to sample, trying again ($i/$n_tries)"
            i += 1
        end
    end
    @error "Failed failed to sample after $n_tries tries"
end

function run_scenario(adaptor_type,metric_adaptor_type, rng, prob; n_samples=1000, n_adapts=200, n_chains=4, verbose=false)
    seeds = rand(rng, UInt, n_chains)
    # measure time
    time = @elapsed results = Folds.collect(sample_one_chain(adaptor_type,metric_adaptor_type, seed, rng, prob; n_samples, n_adapts, verbose) for seed in seeds);
    # extract results
    samples = hcat(getindex.(results, 1)...)
    stats = hcat(getindex.(results, 2)...)
    num_evals = getindex.(results, 3) |> sum
    num_grads_evals = getindex.(results, 4) |> sum
    # alternative
    # num_grads_evals = getindex.(stats, :n_steps) |> sum
    divergences = getindex.(stats, :numerical_error) |> sum
    ess, rhat = ess_rhat(samples)
    #
    ess_per_grad_mean = mean(ess) ./ num_grads_evals
    ess_mean = mean(ess)
    ess_min = minimum(ess) 
    ess_max = maximum(ess)
    rhat_max = maximum(rhat)
    @info "Results: $(ess_per_grad_mean) ESS/grad eval in $(time) seconds with $(divergences) divergences (max Rhat: $rhat_max)"
    return  (;ess_mean,ess_max,ess_min,num_grads_evals,ess_per_grad_mean,rhat_max, divergences, time)
end

# # Explore posteriorDB
# posterior_name = "diamonds-diamonds"
# post = posterior(pdb, posterior_name)
# mod = model(post)
# data = dataset(post)
# info(post)

# impl = implementation(mod, "stan")
# mod_code = load(impl)
# println(mod_code)

# load(data)
# ref = reference_posterior(post)
# info(ref)
# ref = DataFrame(load(ref))
# prob = StanProblem(post, ".", force=true, make_args=["STAN_THREADS=true"])
# LogDensityProblems.capabilities(prob)
# rng = Random.default_rng();
# LogDensityProblems.logdensity(prob, initial_θ)

# # Run Scenario: Diamond Model
pdb = database()
posterior_name = "diamonds-diamonds"
post = posterior(pdb, posterior_name)
prob = StanProblem(post, ".", force=true, make_args=["STAN_THREADS=true"])

# Set parameters
n_adapts = 1000
n_samples = 1000
n_chains = 4
n_tries = 100

# STAN DEFAULT
# rng = Random.default_rng();
rng = Random.MersenneTwister(1234);
res1=run_scenario(StanHMCAdaptor,WelfordVar, rng, prob; n_samples, n_adapts,n_chains,verbose=false);

# NUTPIE Variants
# https://github.com/pymc-devs/nuts-rs/blob/main/src/adapt_strategy.rs
rng = Random.MersenneTwister(1234);
res2=run_scenario(NutpieHMCAdaptor,ExpWeightedWelfordVar, rng, prob; n_samples, n_adapts, n_chains,verbose=false);
rng = Random.MersenneTwister(1234);
res3=run_scenario(NutpieHMCAdaptorNoGradInit,ExpWeightedWelfordVar, rng, prob; n_samples, n_adapts, n_chains,verbose=false);
rng = Random.MersenneTwister(1234);
res4=run_scenario(NutpieHMCAdaptorNoSwitch,ExpWeightedWelfordVar, rng, prob; n_samples, n_adapts, n_chains,verbose=false);
rng = Random.MersenneTwister(1234);
res5=run_scenario(NutpieHMCAdaptorNoSwitchNoGradInit,ExpWeightedWelfordVar, rng, prob; n_samples, n_adapts, n_chains,verbose=false);

# Plot it
function plot_bar(res_array,metric,labels=["Stan-like","Nutpie","Nutpie no grad init","Nutpie no switch","Nutpie no switch no grad init"];
    kwargs...)
    pl = bar(labels,getfield.(res_array,metric);xrotation=15,left_margin=5Plots.mm,bottom_margin=5Plots.mm,kwargs...)
end

pl = let res_array = [res1,res2,res3,res4,res5]
    plot(
        plot_bar(res_array,:ess_per_grad_mean; title="ESS per gradient evaluation",
            ylabel="ESS/grad eval",legend=false),
        plot_bar(res_array,:ess_mean; title="ESS (mean)",
            ylabel="ESS (mean)",legend=false), 
        plot_bar(res_array,:num_grads_evals; title="# of Gradient Eval.",
            ylabel="Grad. evaluations",legend=false),
        plot_bar(res_array,:rhat_max; title="Max Rhat",
            ylabel="Rhat (max)",legend=false),
        plot_bar(res_array,:divergences; title="# of Divergences",
            ylabel="Divergences",legend=false),
        plot_bar(res_array,:time; title="Time Elapsed",
            ylabel="Time (s)",legend=false),
            size=(900,600),layout=(2,3),titlefontsize=12)
end

# savefig(pl,"diamonds_20230212.png")
aseyboldt commented 1 year ago

I don't know for sure, but bad step size adaptation might explain what you are seeing. If the mass matrix itself is good but the final step size doesn't match it, you might see lots of divergences. If that's the problem, the actual mean acceptance rate after tuning would not match the target.
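
One way to check that against the runs above (a sketch reusing `stats2` and `n_adapts` from the earlier linear-regression example, and assuming AdvancedHMC's usual NUTS stat fields `acceptance_rate`, `step_size` and `numerical_error`):

# Compare the post-warmup mean acceptance rate with the 0.8 target given to the
# StepSizeAdaptor; a large gap points at step size adaptation rather than the
# mass matrix. (Sketch only; stats2 includes warmup because drop_warmup=false.)
using Statistics: mean
post_warmup = stats2[(n_adapts + 1):end]
mean_accept = mean(getproperty.(post_warmup, :acceptance_rate))
final_step  = post_warmup[end].step_size
n_div       = sum(getproperty.(post_warmup, :numerical_error))
@info "Post-warmup diagnostics" mean_accept target = 0.8 final_step n_div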

And shouldn't this be a final step size adaptation, instead of mass matrix adaptation? https://github.com/TuringLang/AdvancedHMC.jl/pull/312/files#diff-622665792b73235f2c5b58233a3eb82abaea68b4133c51bf9471d7bac99c10d3R118

svilupp commented 1 year ago

> I don't know for sure, but bad step size adaptation might explain what you are seeing. If the mass matrix itself is good but the final step size doesn't match it, you might see lots of divergences. If that's the problem, the actual mean acceptance rate after tuning would not match the target.

Good tip! Thanks - I'll look into it. My assumption so far is that the mass matrix is bad, so it gives too much kinetic energy (hence, the divergences), but I'll look out for the step sizes too!

> And shouldn't this be a final step size adaptation, instead of mass matrix adaptation? https://github.com/TuringLang/AdvancedHMC.jl/pull/312/files#diff-622665792b73235f2c5b58233a3eb82abaea68b4133c51bf9471d7bac99c10d3R118

I've removed these files, they were just background artefacts. The one you reference is a ChatGPT rewrite of your codebase - I think the line you referenced is this one.

The actual implementation is here, where it first updates the step size (same as the default in this package) and then updates the mass matrix as per Nutpie here (it's decoupled from the switch!, which happens 4 lines earlier).

This adapt! call invokes the adaptation method of the variance accumulator, which is just a bundle of four Welford variance accumulators (WelfordVar) - two for draws, two for grads. The different lingo here is that the push! methods add new samples to the accumulators, while update! updates the variance estimate in adaptor.exp_variance_draw.var (I tried to adhere a bit to your naming) here.
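
For readers who haven't opened the diff, here is a simplified, hypothetical sketch of the shape being described: a bundle of running-variance accumulators with push!/update! semantics. The names `RunningVar` and `DrawGradAccumulator` are made up for illustration (the real code uses `ExpWeightedWelfordVar` with exponential weighting and windowing), and the sqrt(draw variance / grad variance) combination is my reading of nuts-rs rather than a quote of the PR.

# Hypothetical sketch of the structure described above (not the PR's code):
# two accumulators for draws and two for grads; push! feeds in new samples,
# update! refreshes the variance estimate used for the diagonal mass matrix.
mutable struct RunningVar
    n::Int
    mean::Vector{Float64}
    m2::Vector{Float64}   # sum of squared deviations (plain Welford, no weighting)
end
RunningVar(D::Int) = RunningVar(0, zeros(D), zeros(D))

function Base.push!(rv::RunningVar, x::AbstractVector)
    rv.n += 1
    Δ = x .- rv.mean
    rv.mean .+= Δ ./ rv.n
    rv.m2 .+= Δ .* (x .- rv.mean)
    return rv
end
variance(rv::RunningVar) = rv.n > 1 ? rv.m2 ./ (rv.n - 1) : ones(length(rv.mean))

struct DrawGradAccumulator
    exp_variance_draw::RunningVar     # current window, draws
    exp_variance_grad::RunningVar     # current window, grads
    exp_variance_draw_bg::RunningVar  # background window, draws
    exp_variance_grad_bg::RunningVar  # background window, grads
end
DrawGradAccumulator(D::Int) =
    DrawGradAccumulator(RunningVar(D), RunningVar(D), RunningVar(D), RunningVar(D))

function update!(M⁻¹_diag::Vector{Float64}, acc::DrawGradAccumulator)
    # nuts-rs-style heuristic: per-dimension sqrt of draw variance over grad variance
    M⁻¹_diag .= sqrt.(variance(acc.exp_variance_draw) ./ variance(acc.exp_variance_grad))
    return M⁻¹_diag
end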