TuringLang / Turing.jl

Bayesian inference with probabilistic programming.
https://turinglang.org
MIT License

Progress reporting in parallel sampling #2264

Open SamuelBrand1 opened 2 weeks ago

SamuelBrand1 commented 2 weeks ago

Hi everyone,

One thing I've noticed is that progress reporting when running chains in parallel (for example using MCMCThreads()) is not informative: the progress meter only updates when a chain finishes, rather than reporting within-chain progress (as serial sampling does).

Is there any movement towards chain-by-chain progress reporting, as in Stan?

torfjelde commented 2 weeks ago

This is an issue that has come up fairly often but AFAIK no perfect solution exists.

Ref: https://github.com/TuringLang/AbstractMCMC.jl/issues/82 https://github.com/TuringLang/AbstractMCMC.jl/issues/105

There's a Discourse thread where someone seems to have come up with a "solution" (https://discourse.julialang.org/t/displaying-parallel-progress-bars/4148/8), but that was ages ago and I'm not sure if that solution still works.

Note that you can provide an arbitrary callback to sample, which is executed after every step and where you could do custom progress-keeping, but atm there's no good built-in solution unfortunately :confused:
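To make the callback idea concrete, here is a minimal sketch of custom progress-keeping that only tracks the total across chains. `TotalProgress` is a hypothetical helper, not part of Turing.jl or AbstractMCMC.jl, and the callback signature is assumed to be the one used in the example later in this thread. The loop at the end is a stand-in for parallel sampling so the sketch runs without Turing installed.

```julia
using Base.Threads: Atomic, atomic_add!, @threads

# Hypothetical helper: counts steps across all chains and prints a
# line every `every` steps. Not part of any package.
struct TotalProgress
    count::Atomic{Int}  # steps completed so far, shared by all chains
    total::Int          # expected number of callback calls overall
    every::Int          # print interval
end
TotalProgress(total::Int; every::Int=1000) =
    TotalProgress(Atomic{Int}(0), total, every)

# Assumed callback signature: (rng, model, sampler, sample, state, iteration).
function (cb::TotalProgress)(rng, model, sampler, sample, state, iteration; kwargs...)
    done = atomic_add!(cb.count, 1) + 1  # atomic_add! returns the old value
    if done % cb.every == 0 || done == cb.total
        println(stderr, "sampled $done / $(cb.total) steps")
    end
    return nothing
end

# Stand-in for parallel sampling: each "chain" calls the callback once per step.
num_samples, num_chains = 500, 4
cb = TotalProgress(num_samples * num_chains; every=500)
@threads for chain in 1:num_chains
    for i in 1:num_samples
        cb(nothing, nothing, nothing, nothing, nothing, i)
    end
end
```

With a real model this would presumably be passed as `sample(model, sampler, MCMCThreads(), num_samples, num_chains; callback=cb, progress=false)`, assuming the same callback object is forwarded to every chain. A single shared callback cannot tell the chains apart, so this gives a global count only, not chain-by-chain progress.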

SamuelBrand1 commented 2 weeks ago

Thanks for flagging this up @torfjelde ! I guess this will keep circling around :-(.

torfjelde commented 2 weeks ago

Might be possible to do something with this: https://github.com/timholy/ProgressMeter.jl/pull/157

In fact, if I use that branch + some minor changes to AbstractMCMC.jl, the following

using ProgressMeter
using Turing

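# Callable struct that advances the progress bar of one chain.
struct ProgressCallback{P}
    p::P        # shared MultipleProgress object
    index::Int  # which chain's bar this callback advances
end

function (callback::ProgressCallback)(rng, model, sampler, sample, state, iteration; kwargs...)
    # Can do more stuff here if you want.
    next!(callback.p[callback.index])
end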

@model demo() = x ~ Normal()
model = demo()

num_samples = 100_000
num_chains = 10
p = MultipleProgress(
    [Progress(num_samples; desc="Chain $i ") for i in 1:num_chains],
    Progress(num_samples * num_chains; desc="Total ")
)
callbacks = map(1:num_chains) do i
    ProgressCallback(p, i)
end
chain = sample(
    model,
    HMC(0.1, 32),
    MCMCThreads(),
    num_samples,
    num_chains,
    callback=callbacks,
    progress=false,
    thinning=10
)

results in

image

It's small so not sure if you can see it, but it creates one bar for each chain + a global progress bar.

(note that this relies on minor changes to AbstractMCMC.jl + that experimental branch of ProgressMeter.jl, which only supports the REPL, not, say, IJulia)
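For environments where the multi-bar display is not available, one rough alternative is to keep per-chain counters and render them as a plain-text line. The sketch below uses only Julia's standard library. `ChainCounters`, `CountingCallback` and `status_line` are hypothetical names, and, like the example above, it assumes one callback per chain can be handed to `sample`, which relies on the same unreleased AbstractMCMC.jl changes.

```julia
using Base.Threads: Atomic, atomic_add!, @threads
using Printf: @sprintf

# Hypothetical helper types, not part of any package mentioned in this thread.
struct ChainCounters
    done::Vector{Atomic{Int}}  # one counter per chain
    total::Int                 # expected callback calls per chain
end
ChainCounters(num_chains::Int, total::Int) =
    ChainCounters([Atomic{Int}(0) for _ in 1:num_chains], total)

# One callback per chain, mirroring the index-based design above.
struct CountingCallback
    counters::ChainCounters
    index::Int
end

function (cb::CountingCallback)(rng, model, sampler, sample, state, iteration; kwargs...)
    atomic_add!(cb.counters.done[cb.index], 1)
    return nothing
end

# Render all chains as one plain-text status line.
function status_line(c::ChainCounters)
    parts = [@sprintf("chain %d: %3d%%", i, div(100 * d[], c.total))
             for (i, d) in enumerate(c.done)]
    return join(parts, " | ")
end

# Stand-in for parallel sampling so the sketch runs without Turing.
counters = ChainCounters(2, 200)
callbacks = [CountingCallback(counters, i) for i in 1:2]
@threads for i in 1:2
    for step in 1:200
        callbacks[i](nothing, nothing, nothing, nothing, nothing, step)
    end
end
println(status_line(counters))
```

The status line still has to be printed by something while sampling runs. One option might be a `Timer` that calls `println(status_line(counters))` every few seconds, but that part is left out here.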

Miiight be worth adopting this in TuringCallbacks.jl as a bridge until there's a good solution.