Open SamuelBrand1 opened 5 months ago
This is an issue that has come up fairly often but AFAIK no perfect solution exists.
Ref: https://github.com/TuringLang/AbstractMCMC.jl/issues/82 https://github.com/TuringLang/AbstractMCMC.jl/issues/105
There's a Discourse thread where someone seems to have come up with a "solution" (https://discourse.julialang.org/t/displaying-parallel-progress-bars/4148/8), but that was ages ago and I'm not sure if that solution still works.
Note that you can provide an arbitrary `callback` to `sample`, which is executed after every step and where you could do custom progress-keeping, but atm there's no good built-in solution unfortunately :confused:
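For reference, a minimal sketch of what such a callback could look like, assuming the `callback` keyword is forwarded unchanged to every chain so that one shared object sees all steps. `TotalProgress` and its fields are hypothetical names, not part of Turing or AbstractMCMC, and the loop at the bottom only stands in for the sampler so the snippet runs without Turing:

```julia
# Hypothetical helper (not a Turing/AbstractMCMC type): counts steps across
# all chains and prints a running total every `every` steps.
struct TotalProgress
    done::Threads.Atomic{Int}  # steps completed so far, summed over all chains
    total::Int                 # expected number of steps, e.g. num_samples * num_chains
    every::Int                 # print a line every `every` steps
end

TotalProgress(total::Int; every::Int=1_000) =
    TotalProgress(Threads.Atomic{Int}(0), total, every)

# Same argument list as the callback signature used elsewhere in this thread.
function (cb::TotalProgress)(rng, model, sampler, sample, state, iteration; kwargs...)
    # atomic_add! returns the previous value, so add 1 to get the updated count.
    n = Threads.atomic_add!(cb.done, 1) + 1
    if n % cb.every == 0 || n == cb.total
        println(stderr, "sampled ", n, " / ", cb.total, " steps")
    end
    return nothing
end

# Stand-in for the sampler: 4 "chains" of 250 steps each, calling the callback
# by hand with placeholder arguments.
cb = TotalProgress(4 * 250; every=250)
Threads.@threads for chain in 1:4
    for iteration in 1:250
        cb(nothing, nothing, nothing, nothing, nothing, iteration)
    end
end
```

With Turing this would presumably be passed as `sample(model, HMC(0.1, 32), MCMCThreads(), num_samples, num_chains; callback=cb, progress=false)`. The counter is atomic because several chains may call the same object at once; note that it only gives a global total, not per-chain progress.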
Thanks for flagging this up, @torfjelde! I guess this will keep circling around :-(
Might be possible to do something with this: https://github.com/timholy/ProgressMeter.jl/pull/157
In fact, if I use that branch + some minor changes to AbstractMCMC.jl, the following:

```julia
using ProgressMeter
using Turing

# Callback that advances the progress bar belonging to one chain.
struct ProgressCallback{P}
    p::P        # the shared `MultipleProgress` (from the experimental ProgressMeter.jl branch)
    index::Int  # which chain's bar this callback advances
end

function (callback::ProgressCallback)(rng, model, sampler, sample, state, iteration; kwargs...)
    # Can do more stuff here if you want.
    next!(callback.p[callback.index])
end

@model demo() = x ~ Normal()
model = demo()

num_samples = 100_000
num_chains = 10

# One bar per chain + one bar for the total.
p = MultipleProgress(
    [Progress(num_samples; desc="Chain $i ") for i in 1:num_chains],
    Progress(num_samples * num_chains; desc="Total ")
)

# One callback per chain (passing a vector here relies on the AbstractMCMC.jl changes).
callbacks = map(1:num_chains) do i
    ProgressCallback(p, i)
end

chain = sample(
    model,
    HMC(0.1, 32),
    MCMCThreads(),
    num_samples,
    num_chains;
    callback=callbacks,
    progress=false,
    thinning=10
)
```
results in the display shown in the screenshot. It's small so not sure if you can see it, but it creates one bar for each thread + a global progress bar.
(Note that this relies on minor changes to AbstractMCMC.jl + that experimental branch of ProgressMeter.jl, which only supports the REPL, not, say, IJulia.)
Miiight be worth adopting this in TuringCallbacks.jl as a bridge until there's a good solution.
Hi everyone,
One thing I've noticed is that progress reporting when running chains in parallel (for example using `MCMCThreads()`) is not informative: the progress meter only updates when a chain is finished, rather than reporting within-chain progress (as it does for serial sampling). Is there any movement towards chain-by-chain progress reporting, as in Stan?
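To make the request concrete, here is a rough sketch of plain-text, chain-by-chain reporting, loosely modelled on the per-chain iteration lines Stan prints. `ChainReporter` is a hypothetical name, and the sketch assumes each chain gets its own callback instance and that `iteration` runs from 1 to `num_samples`; neither is something the stock `sample` call is confirmed to provide. It writes to an in-memory buffer here so it runs without Turing:

```julia
using Printf

# Hypothetical per-chain reporter (not part of Turing or AbstractMCMC).
struct ChainReporter{T<:IO}
    io::T             # where to write progress lines, e.g. stderr
    chain::Int        # index of the chain this reporter belongs to
    num_samples::Int  # assumed number of iterations for this chain
    every::Int        # report every `every` iterations
end

function (r::ChainReporter)(rng, model, sampler, sample, state, iteration; kwargs...)
    if iteration % r.every == 0 || iteration == r.num_samples
        pct = 100 * iteration / r.num_samples
        # Build the whole line first and write it in a single call, so lines
        # from different chains are less likely to interleave mid-line.
        line = @sprintf("Chain %d: iteration %d / %d [%3.0f%%]\n",
                        r.chain, iteration, r.num_samples, pct)
        print(r.io, line)
    end
    return nothing
end

# Stand-in for one chain of the sampler, calling the callback by hand.
buf = IOBuffer()
reporter = ChainReporter(buf, 2, 1000, 500)
for iteration in 1:1000
    reporter(nothing, nothing, nothing, nothing, nothing, iteration)
end
output = String(take!(buf))
print(output)
```

Unlike live progress bars, line-based output like this should not depend on terminal cursor control, which may make it a reasonable fallback outside the REPL.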