mschauer / ZigZagBoomerang.jl

Sleek implementations of the ZigZag, Boomerang and other assorted piecewise deterministic Markov processes for Markov chain Monte Carlo, including sticky PDMPs for variable selection.
MIT License

Integration with trace-based PPLs #34

Open femtomc opened 4 years ago

femtomc commented 4 years ago

I'd love to figure out a way to utilize your kernels. This is part of a larger conversation about how trace-based PPLs can use optimized kernel libraries (or inherit from AbstractMCMC).

For now, I'd like to try writing a one-step kernel, similar to https://github.com/femtomc/Jaynes.jl/blob/master/src/inference/hmc.jl, which calls your kernels. This will be the reference issue for this work.

mschauer commented 4 years ago

What helps on our side is that the kernels constitute a Markov transition once you add the sample time to the state: transitions (t1, x1, v1) -> (t2, x2, v2) -> (t3, x3, v3) form a rejection-free Markov chain with time t, position x and momentum v. Accumulating the posterior samples from the (ti, xi)'s is a bit more involved, but that should not make the one-step kernel more complicated.
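As an aside, the chaining of such transitions can be illustrated with a self-contained toy (a 1D ZigZag targeting a standard normal, written from scratch here rather than using the package API): each step moves the augmented state (t, x, v) to the next event and flips the velocity, and averages are taken over the continuous piecewise-linear trajectory.

```julia
using Random

# Toy 1D ZigZag for π(x) ∝ exp(-x^2/2), so ∇ϕ(x) = x.
# One transition of the augmented state (t, x, v): move along x + v*s
# until the next event time, then flip the velocity. With v = ±1 the
# event rate is λ(s) = max(0, v*(x + v*s)) = max(0, v*x + s), which can
# be inverted exactly (no rejection, no Poisson thinning needed here).
function zigzag_step(rng, t, x, v)
    E = randexp(rng)                        # Exp(1) increment of the integrated rate
    τ = -v*x + sqrt(max(0.0, v*x)^2 + 2E)   # first event time, by inversion
    t + τ, x + v*τ, -v                      # new time, position, flipped velocity
end

# Chain the transitions and estimate E[x] and E[x^2] by integrating the
# piecewise-linear path between events.
function zigzag_moments(rng; T = 10_000.0)
    t, x, v = 0.0, 0.0, 1.0
    m1 = m2 = 0.0
    while t < T
        t′, x′, v′ = zigzag_step(rng, t, x, v)
        τ = t′ - t
        m1 += x*τ + v*τ^2/2                 # ∫ (x + v s) ds over [0, τ]
        m2 += x^2*τ + v*x*τ^2 + τ^3/3       # ∫ (x + v s)^2 ds over [0, τ]
        t, x, v = t′, x′, v′
    end
    m1/t, m2/t
end
```

Running `zigzag_moments(MersenneTwister(1))` should return estimates close to 0 and 1, the moments of the standard normal target.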

mschauer commented 4 years ago
# imports assumed for this sketch
using Jaynes: UnconstrainedSelection, CallSite, get_score, get_choice_gradients, selection, update, array
using ZigZagBoomerang
using ZigZagBoomerang: Boomerang, poisson_time
using LinearAlgebra: I
using SparseArrays: sparse
using Distributions: MvNormal

function boomerang(sel::K, cl::C) where {K <: UnconstrainedSelection, C <: CallSite}

    p_mod_score = get_score(cl)  # score of the current trace (not used below)
    sel_values, choice_grads = get_choice_gradients(sel, cl, 1.0)
    sel_values_ref = Ref(sel_values)
    cl_ref = Ref(cl)

    x = array(sel_values, Float64)
    d = length(x)
    Flow = Boomerang(sparse(I, d, d), zeros(d), 1.0)
    N = MvNormal(d, 1.0)
    θ = rand(N)
    t = 0.0
    ∇ϕx = copy(θ)
    acc = num = 0
    function ∇ϕ!(y, x, sel, cl_ref, sel_values_ref)
        sel_values_ref[] = selection(sel_values_ref[], x)
        cl_ref[] = update(sel_values_ref[], cl_ref[])[2]
        sel_values_ref[], choice_grads = get_choice_gradients(sel, cl_ref[], 1.0)
        y .= array(choice_grads, Float64)
    end

    Ξ = ZigZagBoomerang.Trace(t, x, θ, Flow) # should persist between calls
    τref = T = ZigZagBoomerang.waiting_time_ref(Flow)
    c = 100.0
    a, b = ZigZagBoomerang.ab(x, θ, c, Flow)
    t′ = t + poisson_time(a, b)
    while t < T
        t, x, θ, (acc, num), c, a, b, t′, τref = ZigZagBoomerang.pdmp_inner!(Ξ, ∇ϕ!, ∇ϕx, t, x, θ, c, a, b, t′, τref, (acc, num), Flow, sel, cl_ref, sel_values_ref; adapt=false)
    end

    cl_ref[]
end

The part where I send Refs for sel_values and cl_ref through pdmp_inner! is a bit clumsy; maybe you can clarify how to obtain the gradient of the model at a current vector x from the model.
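One hypothetical way around threading the Refs through pdmp_inner! is to capture them in a closure, so the gradient callback keeps a plain `∇ϕ!(y, x)` signature. A minimal sketch, with a stand-in quadratic potential in place of the Jaynes calls (selection/update/get_choice_gradients):

```julia
# Sketch: build a gradient closure over mutable model state, instead of
# passing Refs as extra trailing arguments through the sampler.
function make_gradϕ(state::Ref)
    function ∇ϕ!(y, x)
        state[] = x    # stand-in for updating the call site at the new x
        y .= x         # stand-in gradient: ∇ϕ(x) = x for ϕ(x) = |x|²/2
        y
    end
end
```

Usage: `∇ϕ! = make_gradϕ(Ref(zeros(2)))` yields a two-argument callback that still sees (and mutates) the captured state on every call; the real version would capture `cl_ref` and `sel_values_ref` the same way.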

mschauer commented 4 years ago

See https://github.com/femtomc/Jaynes.jl/blob/master/src/inference/boo.jl for the Boomerang version.

mschauer commented 4 years ago

I also translated this to Gen.jl (mostly to familiarize myself with it): https://github.com/mschauer/ZigZagBoomerang.jl/blob/master/scripts/genpdmp.jl