femtomc opened this issue 4 years ago
What works in our favor is that the kernels constitute a Markov transition if you add the sample time to the state, so transitions `(t1, x1, v1) -> (t2, x2, v2) -> (t3, x3, v3)` form a rejection-free Markov chain with time `t`, position `x`, and momentum `v`. Accumulating the posterior samples from the `(ti, xi)`'s is a bit more involved, but that should not make the one-step kernel itself more complicated.
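To make the augmented-state idea concrete, here is a minimal sketch of the state a rejection-free one-step kernel would carry; the `PDMPState` name and the `step` signature are illustrative, not ZigZagBoomerang API.

```julia
# Illustrative only: the augmented state that turns one PDMP event into a Markov transition.
struct PDMPState{T<:Real,V<:AbstractVector}
    t::T   # sample time
    x::V   # position
    θ::V   # momentum / velocity
end

# A one-step kernel is then a map  step(s::PDMPState) -> PDMPState  with no rejections;
# the deterministic flow between events and the Poisson clocks live inside the kernel
# (e.g. inside ZigZagBoomerang.pdmp_inner! in the snippet below).
```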
```julia
using Jaynes: UnconstrainedSelection, CallSite  # other Jaynes helpers (get_score, get_choice_gradients, array, selection, update) assumed in scope
using ZigZagBoomerang: Boomerang, poisson_time
using Distributions: MvNormal
using LinearAlgebra: I
using SparseArrays: sparse

function boomerang(sel::K, cl::C) where {K <: UnconstrainedSelection, C <: CallSite}
    p_mod_score = get_score(cl)
    sel_values, choice_grads = get_choice_gradients(sel, cl, 1.0)
    sel_values_ref = Ref(sel_values)
    cl_ref = Ref(cl)
    x = array(sel_values, Float64)                    # current position as a flat vector
    d = length(x)
    Flow = Boomerang(sparse(I, d, d), zeros(d), 1.0)  # Boomerang flow: unit metric, zero mean, refreshment rate 1.0
    N = MvNormal(d, 1.0)
    θ = rand(N)                                       # initial momentum
    t = 0.0
    ∇ϕx = copy(θ)                                     # buffer for the gradient at x
    acc = num = 0

    # Gradient callback used by the Poisson event rate: write x back into the
    # selection, rerun the model, and read out the choice gradients.
    function ∇ϕ!(y, x, sel, cl_ref, sel_values_ref)
        sel_values_ref[] = selection(sel_values_ref[], x)
        cl_ref[] = update(sel_values_ref[], cl_ref[])[2]
        sel_values_ref[], choice_grads = get_choice_gradients(sel, cl_ref[], 1.0)
        y .= array(choice_grads, Float64)
    end

    Ξ = ZigZagBoomerang.Trace(t, x, θ, Flow)          # should persist between calls
    τref = T = ZigZagBoomerang.waiting_time_ref(Flow) # time of the next momentum refreshment
    c = 100.0                                         # bound constant for the Poisson clock
    a, b = ZigZagBoomerang.ab(x, θ, c, Flow)
    t′ = t + poisson_time(a, b)                       # first proposed event time
    while t < T
        t, x, θ, (acc, num), c, a, b, t′, τref = ZigZagBoomerang.pdmp_inner!(
            Ξ, ∇ϕ!, ∇ϕx, t, x, θ, c, a, b, t′, τref, (acc, num), Flow,
            sel, cl_ref, sel_values_ref; adapt = false)
    end
    cl_ref[]
end
```
The part where I send `Ref`s for `sel_values` and `cl_ref` through `pdmp_inner!` is a bit clumsy; maybe you can clarify how to get the gradient of the model for a current vector `x` from the model.
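One way to avoid threading the `Ref`s through `pdmp_inner!` might be a closure that captures them instead. This is only a sketch, assuming the same Jaynes helpers (`selection`, `update`, `get_choice_gradients`, `array`) as above, and assuming `pdmp_inner!` calls the gradient callback as `∇ϕ!(y, x)` when no extra trailing arguments are passed.

```julia
# Sketch under the assumptions above: ∇ϕ!(y, x) closes over the call site and selection,
# so no Refs need to be passed through pdmp_inner! as trailing arguments.
function make_∇ϕ!(sel, cl, sel_values)
    cl_ref = Ref(cl)
    sel_values_ref = Ref(sel_values)
    function ∇ϕ!(y, x)
        sel_values_ref[] = selection(sel_values_ref[], x)                           # write x into the choices
        cl_ref[] = update(sel_values_ref[], cl_ref[])[2]                            # rerun the model at x
        sel_values_ref[], choice_grads = get_choice_gradients(sel, cl_ref[], 1.0)   # differentiate the score
        y .= array(choice_grads, Float64)
        return y
    end
    return ∇ϕ!, cl_ref
end
```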
See https://github.com/femtomc/Jaynes.jl/blob/master/src/inference/boo.jl for the Boomerang version in Jaynes.
I also translated this to Gen.jl (mostly to familiarize myself with it): https://github.com/mschauer/ZigZagBoomerang.jl/blob/master/scripts/genpdmp.jl
I'd love to figure out a way to utilize your kernels. This is part of a larger conversation about how trace-based PPLs can use optimized kernel libraries (or inherit from AbstractMCMC).
For now, I'd like to try to write a one-step kernel similar to https://github.com/femtomc/Jaynes.jl/blob/master/src/inference/hmc.jl which calls your kernels. This will be the reference issue for this work.
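As a starting point, here is a hedged sketch of what such a one-step kernel could look like, reusing the loop body of the `boomerang` function above; the `boomerang_step` name, the `state` NamedTuple, and its fields are assumptions, not existing Jaynes or ZigZagBoomerang API. The gradient callback is passed in explicitly since it closes over the model.

```julia
# Sketch only: one pdmp_inner! event per call, with everything that must persist
# between calls bundled into a NamedTuple `state` (field names are assumptions).
function boomerang_step(∇ϕ!, sel, cl_ref, sel_values_ref, state)
    t, x, θ, (acc, num), c, a, b, t′, τref = ZigZagBoomerang.pdmp_inner!(
        state.Ξ, ∇ϕ!, state.∇ϕx, state.t, state.x, state.θ, state.c,
        state.a, state.b, state.t′, state.τref, (state.acc, state.num), state.Flow,
        sel, cl_ref, sel_values_ref; adapt = false)
    newstate = merge(state, (t = t, x = x, θ = θ, acc = acc, num = num,
                             c = c, a = a, b = b, t′ = t′, τref = τref))
    return cl_ref[], newstate
end
```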