Closed mschauer closed 2 years ago
This is how it would work with your iterator, @cscherrer
# Diagonal metric with unit entries, one per coordinate.
M = PDMats.PDiagMat(ones(d))

# Arguments are positional; each is annotated on the line above it.
Z = BouncyParticle(
    # graphical structure
    missing,
    # MAP estimate (unused)
    missing,
    # momentum refreshment rate, which is also the sample saving rate
    rate,
    # momentum correlation: change the momentum only gradually in the
    # refreshment / momentum update
    1 - 1 / n,
    # metric
    M,
    # legacy argument
    missing,
)
and then `collect_sampler` gets changed into:
"""
    collect_sampler2(t, sampler, n; adapt_mass=true, progress=true,
                     progress_stops=20, ra_offset=0)

Step through `sampler` with `iterate` and store `t(val[2])` for every iterate
value `val` in the container returned by `chainvec(x1, n)`, where
`x1 = t(sampler.u0[2][1])`. Entries `2:n` are written by the loop; entry `1` is
left as `chainvec` created it.

# Arguments
- `t`: function applied to the initial point and to `val[2]` of every iterate.
- `sampler`: iterable sampler exposing `u0` and `F.U` (the metric).
- `n`: total number of entries in the returned container; at most `n - 1`
  iterations are performed.

# Keywords
- `adapt_mass`: when `true`, keep a running average `m` (started at
  `1 ./ M.diag`) of `state[2] .^ 2` and overwrite `M.diag` with `1 ./ m` after
  every step. The velocity is whitened with the old metric and unwhitened with
  the new one. `M` is `sampler.F.U` and is mutated in place.
- `progress`: show a `ProgressMeter` progress bar.
- `progress_stops`: number of progress-bar stops (one is reserved for cleanup).
- `ra_offset`: offset added to the denominator of the running average.

# Returns
`(tv, info)` where `tv` is the filled container and `info` is the named tuple
`(; uT, acc, total, bound)` read from the last iteration state.

# Throws
- `ArgumentError` if not a single iteration was performed (`n <= 1` or the
  sampler produced no element), because there is then no final state to report.
"""
function collect_sampler2(
    t, sampler, n; adapt_mass=true, progress=true, progress_stops=20, ra_offset=0
)
    prg = progress ? Progress(progress_stops, 1) : missing
    stops = ismissing(prg) ? 0 : max(prg.n - 1, 0) # allow one stop for cleanup
    # With `stops == 0` (no progress bar) `nstop` is not finite, so the
    # `j > nstop` test below never fires and `next!` is never called on `missing`.
    nstop = n / stops

    x1 = t(sampler.u0[2][1])
    tv = chainvec(x1, n)

    ϕ = iterate(sampler)
    j = 1
    local state
    M = sampler.F.U
    if adapt_mass
        m = 1 ./ M.diag
    end

    while ϕ !== nothing && j < n
        j += 1
        val, state = ϕ
        tv[j] = t(val[2])
        if adapt_mass
            # running average shifted by offset
            @. m = m + (state[2]^2 - m) / (ra_offset + j - 1)
            # NOTE(review): presumably marks the pending event as stale because
            # the metric is about to change — confirm against ZZB.set_action.
            state = ZZB.set_action(state, :invalid)
            v = state[1][2][2] # get velocity
            # Move the velocity from the old metric to the updated one.
            PDMats.whiten!(M, v)
            @. M.diag = 1 / m
            PDMats.unwhiten!(M, v)
        end
        ϕ = iterate(sampler, state)
        if j > nstop
            nstop += n / stops
            next!(prg)
        end
    end

    ismissing(prg) || ProgressMeter.finish!(prg)

    # `state` is only assigned inside the loop; fail with a clear message
    # instead of an UndefVarError when the loop body never ran.
    if !@isdefined(state)
        throw(ArgumentError(
            "collect_sampler2 performed no iterations (n = $n); " *
            "need n > 1 and a sampler that yields at least one element",
        ))
    end

    return tv, (; uT=state[1], acc=state[3][1], total=state[3][2], bound=state[4].c)
end
Addresses #113. Supersedes #112.