cscherrer / SampleChainsDynamicHMC.jl

MIT License

Soss drawing poor samples from simple posterior #12

Closed sethaxen closed 3 years ago

sethaxen commented 3 years ago

Soss is failing to target the correct (analytically known) posterior in the following simple example using SampleChainsDynamicHMC. I ruled out any (obvious) bugs in MeasureTheory and DynamicHMC by using the two of them directly to sample from the same posterior.

# setup
using Pkg
Pkg.add([
    Pkg.PackageSpec(name="Soss", version="0.20"),
    Pkg.PackageSpec(name="SampleChainsDynamicHMC", version="0.3"),
    Pkg.PackageSpec(name="MeasureTheory", version="0.9"),
    Pkg.PackageSpec(name="TransformVariables", version="0.4"),
    Pkg.PackageSpec(name="LogDensityProblems", version="0.10"),
    Pkg.PackageSpec(name="DynamicHMC", version="3"),
    Pkg.PackageSpec(name="StatsPlots", version="0.14"),
])

using Soss, MeasureTheory, Random

# exact posterior is μ ~ Normal(μ=y/2, σ=inv(sqrt(2)))
mod = Soss.@model begin
    μ ~ Normal(μ=0, σ=1)
    y ~ Normal(μ=μ, σ=1)
end
y = 0.2

# draw samples using Soss
using SampleChainsDynamicHMC
rng = MersenneTwister(42)
chains = Soss.sample(rng, mod() | (; y=y), dynamichmc(), 1_000, 4)
post_soss = [collect(chain.μ) for chain in getchains(chains)]

# draw samples using DynamicHMC directly
using TransformVariables, LogDensityProblems, DynamicHMC
struct TestProblem{T}
    y::T
end
function (problem::TestProblem)(θ)
    μ = θ.μ
    logdensity(Normal(μ=0, σ=1), μ) + logdensity(Normal(μ=μ, σ=1), problem.y)
end
p = TestProblem(y)
trans = as((μ = asℝ,))
P = TransformedLogDensity(trans, p)
∇P = ADgradient(:ForwardDiff, P)
rng = MersenneTwister(87)
post_dhmc = [first.(mcmc_with_warmup(rng, ∇P, 1000).chain) for _ in 1:4]

# draw samples from exact posterior
rng = MersenneTwister(7)
post_exact = [randn(rng, 1000) ./ sqrt(2) .+ y/2 for _ in 1:4]

# plot ECDFs
using StatsPlots
plot()
for (i, post) in enumerate(post_exact)
    ecdfplot!(post; label=i == 1 ? "Exact" : "", primary=i==1)
end
for (i, post) in enumerate(post_soss)
    ecdfplot!(post; label=i == 1 ? "Soss" : "", primary=i==1)
end
for (i, post) in enumerate(post_dhmc)
    ecdfplot!(post; label=i == 1 ? "DynamicHMC" : "", primary=i==1)
end
plot!()

[Figure: ECDFs of the exact, Soss, and DynamicHMC posterior draws]
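For reference, the exact posterior quoted in the code comment follows from completing the square in the product of the prior and likelihood densities. This is a short derivation sketch using the same symbols as the model (μ, y); normalizing constants are dropped throughout.

```latex
\begin{align*}
p(\mu \mid y) &\propto \exp\!\left(-\tfrac{1}{2}\mu^2\right)\exp\!\left(-\tfrac{1}{2}(y-\mu)^2\right) \\
&= \exp\!\left(-\tfrac{1}{2}\left(2\mu^2 - 2y\mu + y^2\right)\right) \\
&\propto \exp\!\left(-\left(\mu - \tfrac{y}{2}\right)^2\right)
 = \exp\!\left(-\frac{\left(\mu - y/2\right)^2}{2\cdot\tfrac{1}{2}}\right)
\end{align*}
```

So the posterior is Normal with mean y/2 and variance 1/2, i.e. σ = 1/√2. For y = 0.2 that is a mean of 0.1 and σ ≈ 0.707.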

cscherrer commented 3 years ago

Ouch, thanks for letting me know about this. This is high priority, obviously. I'll dig into it and let you know what I find.

cscherrer commented 3 years ago

Since the code it's calling is now so simple,

function sample(rng::AbstractRNG, 
    m::ConditionalModel,
    config::DynamicHMCConfig, 
    nsamples::Int=1000,
    nchains::Int=4)

    ℓ(x) = Soss.logdensity(m, x)
    tr = xform(m)

    chains = newchain(rng, nchains, config, ℓ, tr)
    sample!(chains, nsamples - 1)
    return chains
end

I wonder if the problem might be in SampleChainsDynamicHMC.

cscherrer commented 3 years ago

Soss logdensity looks ok:

xx = range(-10,10,length=10000)
dens = [MeasureTheory.density(mod() | (y=0.2,), (μ=x,)) for x in xx]
ecdf = cumsum(dens)
ecdf ./= last(ecdf)

plot!(xx,ecdf, label="normalized cumsum of Soss density evaluations")
xlims!(-2,3)

[Figure: normalized cumulative sum of Soss density evaluations overlaid on the ECDF plot]
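The same grid idea can be turned into a numeric check rather than a visual one. The sketch below is an assumption-laden stand-in: `logpost` is a hand-written unnormalized log posterior for this model (not the Soss or MeasureTheory call above), and the grid moments are compared against the analytic values y/2 and 1/√2.

```julia
# Hand-written unnormalized log posterior for
#   μ ~ Normal(0, 1),  y ~ Normal(μ, 1)
# (additive constants dropped). Hypothetical helper, not a Soss function.
logpost(μ, y) = -μ^2 / 2 - (y - μ)^2 / 2

y = 0.2
xx = range(-10, 10, length=10_001)

# Evaluate the density on the grid and normalize the weights to sum to 1.
w = exp.(logpost.(xx, y))
w ./= sum(w)

# Grid approximations of the posterior mean and standard deviation.
grid_mean = sum(w .* xx)
grid_std = sqrt(sum(w .* (xx .- grid_mean) .^ 2))

println((grid_mean, grid_std))
```

If the density being evaluated is right, `grid_mean` should sit very close to y/2 and `grid_std` very close to 1/√2. Swapping `logpost` for the model's own log density gives the same kind of check on the model itself.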

cscherrer commented 3 years ago

In DynamicHMC here, mcmc_next_step is used like this:

    for i in 1:N
        Q, tree_statistics[i] = mcmc_next_step(steps, Q)
        chain[i] = Q.q
        report(mcmc_reporter, i)
    end

But SampleChainsDynamicHMC is currently using it like this:

function SampleChains.step!(chain::DynamicHMCChain)
    Q, tree_stats = DynamicHMC.mcmc_next_step(getfield(chain, :meta), getfield(chain, :state))
end

So I think the problem might be that we're failing to update the state.
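To illustrate why a state that is never updated gives poor samples, here is a toy sketch. Everything in it is hypothetical: `next_state` is an AR(1) transition with stationary distribution Normal(0, 1), standing in for a real sampler step, and the two loops are not the package's code.

```julia
using Random, Statistics

# Hypothetical transition: a pure function that returns the next state.
# Its stationary distribution is Normal(0, 1).
next_state(rng, q; ρ=0.9) = ρ * q + sqrt(1 - ρ^2) * randn(rng)

# Stale pattern: the new state is computed and recorded,
# but `state` is never updated, so every step restarts from q0.
function run_stale(rng, q0, n)
    state = q0
    draws = Vector{Float64}(undef, n)
    for i in 1:n
        q = next_state(rng, state)
        draws[i] = q
    end
    return draws
end

# Updated pattern: the returned state replaces the old one each iteration.
function run_updated(rng, q0, n)
    state = q0
    draws = Vector{Float64}(undef, n)
    for i in 1:n
        state = next_state(rng, state)
        draws[i] = state
    end
    return draws
end

stale = run_stale(MersenneTwister(1), 5.0, 10_000)
updated = run_updated(MersenneTwister(1), 5.0, 10_000)

println((mean(stale), std(stale)))
println((mean(updated), std(updated)))
```

In this toy, the stale draws are all one step away from the starting point, so they cluster around ρ·q0 = 4.5 with spread sqrt(1 - ρ²) ≈ 0.44 instead of settling into Normal(0, 1). The updated chain forgets its starting point and approaches the stationary distribution.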

cscherrer commented 3 years ago

Think I got it fixed

sethaxen commented 3 years ago

What was the problem?

cscherrer commented 3 years ago

GitHub is collapsing it, but see https://github.com/cscherrer/SampleChainsDynamicHMC.jl/issues/12#issuecomment-880293340

I had to make the state mutable. I'm using a zero-dimensional array for now, just to try it out; I haven't benchmarked it against other approaches.
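A minimal sketch of the zero-dimensional-array idea, under stated assumptions: `ToyChain` and `transition` are hypothetical names for illustration, not the package's actual chain type or step function. The struct stays immutable, while the 0-dimensional array it holds acts as a one-element mutable box.

```julia
# Hypothetical chain type. The struct is immutable, but the
# zero-dimensional array stored in it can have its single element replaced.
struct ToyChain{T}
    state::Array{T,0}
end

# fill(x) with no dimensions builds a zero-dimensional array holding x.
ToyChain(q0::Real) = ToyChain(fill(q0))

# Hypothetical transition standing in for a real sampler step.
transition(q) = q / 2 + 1

function step!(chain::ToyChain)
    q = transition(chain.state[])  # read the current state with []
    chain.state[] = q              # write the new state back in place
    return q
end

chain = ToyChain(0.0)
step!(chain)
step!(chain)
println(chain.state[])
```

A `Ref` field or a `mutable struct` would be other ways to get the same effect; which is fastest here is exactly the benchmarking question left open above.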