mschauer / ZigZagBoomerang.jl

Sleek implementations of the ZigZag, Boomerang and other assorted piecewise deterministic Markov processes for Markov Chain Monte Carlo, including Sticky PDMPs for variable selection
MIT License
100 stars 7 forks source link

Take non-equidistant samples and adapt mass matrix to Fisher information #114

Closed mschauer closed 2 years ago

mschauer commented 2 years ago

Addresses #113. Supersedes #112.

mschauer commented 2 years ago

This is how it would work with your iterator, @cscherrer:

# Unit diagonal metric (identity mass matrix) of dimension `d`.
M = PDMats.PDiagMat(fill(1.0, d))

# Positional arguments of `BouncyParticle`, in order:
#   1. graphical structure (not supplied),
#   2. MAP estimate (unused),
#   3. momentum refreshment rate, also used as the sample saving rate,
#   4. momentum correlation 1 - 1/n, so momentum only changes gradually
#      during refreshment/momentum updates,
#   5. the metric `M`,
#   6. a legacy slot (not supplied).
Z = BouncyParticle(missing, missing, rate, 1 - 1/n, M, missing)

and then `collect_sampler` is changed into:

"""
    collect_sampler2(t, sampler, n; adapt_mass=true, progress=true, progress_stops=20, ra_offset=0)

Run the iterable `sampler` and collect `n` transformed samples, optionally adapting the
diagonal metric while sampling.

The first chain entry is `t(sampler.u0[2][1])`. Entries `2:n` are `t(val[2])` for the
successive values `val` produced by `iterate(sampler, ...)`.

# Arguments
- `t`: function applied to each sample before it is stored.
- `sampler`: iterable sampler. Its fields `u0` and `F.U` (the metric) are read.
- `n`: number of samples to collect; must be at least 2.

# Keywords
- `adapt_mass`: if `true`, after every sample set `M.diag .= 1 ./ m`, where `M = sampler.F.U`
  and `m` is a running average of `state[2].^2` started at `1 ./ M.diag`. The velocity is
  rescaled so that its whitened value is unchanged. This mutates `sampler.F.U` in place.
- `progress`: show a `ProgressMeter` progress bar.
- `progress_stops`: length of the progress bar; one tick is reserved for cleanup.
- `ra_offset`: offset added to the running-average denominator. With `ra_offset = 0` the
  first update discards the initial metric; larger values give it more weight.

# Returns
`(tv, (; uT, acc, total, bound))`. Here `tv` is the chain, `uT = state[1]`,
`acc = state[3][1]`, `total = state[3][2]` and `bound = state[4].c`, taken from the last
state fed to `iterate`.

# Throws
- `ArgumentError` if `n < 2` or if `sampler` yields no values. Previously both cases ended
  in an `UndefVarError` for `state`.
"""
function collect_sampler2(t, sampler, n; adapt_mass=true, progress=true, progress_stops=20, ra_offset=0)
    # The loop below must run at least once, otherwise `state` is never assigned.
    n >= 2 || throw(ArgumentError("n must be at least 2, got $n"))

    prg = progress ? Progress(progress_stops, 1) : missing
    stops = ismissing(prg) ? 0 : max(prg.n - 1, 0) # allow one stop for cleanup
    # Samples between progress ticks; `Inf` means the bar is never advanced in the loop.
    tick_every = stops > 0 ? n / stops : Inf
    nstop = tick_every

    x1 = t(sampler.u0[2][1])
    tv = chainvec(x1, n)
    ϕ = iterate(sampler)
    ϕ === nothing && throw(ArgumentError("sampler produced no values"))
    j = 1
    local state
    M = sampler.F.U
    if adapt_mass
        m = 1 ./ M.diag
    end
    while ϕ !== nothing && j < n
        j += 1
        val, state = ϕ
        tv[j] = t(val[2])
        if adapt_mass
            # Running average of state[2]^2, with the denominator shifted by ra_offset.
            @. m = m + (state[2]^2 - m) / (ra_offset + j - 1)
            # The metric and velocity change below. Marking the action :invalid presumably
            # makes the sampler recompute its next event (NOTE(review): confirm in ZZB).
            state = ZZB.set_action(state, :invalid)
            v = state[1][2][2] # velocity
            # Keep the whitened velocity fixed while the metric is replaced.
            PDMats.whiten!(M, v)
            # NOTE(review): if this PDMats version's PDiagMat also caches `inv_diag`,
            # updating only `diag` leaves that cache stale. Confirm.
            @. M.diag = 1 / m
            PDMats.unwhiten!(M, v)
        end
        ϕ = iterate(sampler, state)
        if j > nstop
            nstop += tick_every
            next!(prg)
        end
    end
    ismissing(prg) || ProgressMeter.finish!(prg)
    return tv, (; uT=state[1], acc=state[3][1], total=state[3][2], bound=state[4].c)
end