TuringLang / AbstractMCMC.jl

Abstract types and interfaces for Markov chain Monte Carlo methods
https://turinglang.org/AbstractMCMC.jl
MIT License

More useful logging when sampling multiple chains #82

Open sethaxen opened 3 years ago

sethaxen commented 3 years ago

If I sample multiple chains, the progress bar shown in the terminal only updates when a chain finishes sampling. This perhaps makes sense if I am quickly sampling many chains in parallel, but it is essentially useless for the common case of a user sampling 2-8 chains of medium to long runtimes in parallel, since as a user I have no way to estimate whether the chains will finish sampling in 4 minutes or 4 days.

Currently one works around this by sampling a single chain "long enough" to guess how long it will take to finish, killing the sampler, and then restarting with multiple chains, but it would be better if multiple progress bars were shown, one for each chain, perhaps with a global progress bar showing how many chains have finished. On Slack a while ago, someone proposed that if a user sampled more chains than progress bars would fit in the terminal, then it might be better to use something like the animation Julia uses when compiling packages.

Do the limitations to doing something like this lie in this package or in one of its dependencies?

devmotion commented 3 years ago

This is a longstanding issue and should be improved. The main reason the progress of individual samples is currently not shown is that it is a bit challenging to do right: users are expected to implement only the sample method for a single chain, and the multi-chain methods should then work automatically. I.e., we always want to call the single-chain method in the multi-chain methods.
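A minimal sketch of this constraint, using only Julia's standard library. The names `sample_chain` and `sample_chains` are hypothetical stand-ins, not AbstractMCMC functions. Because the multi-chain driver only ever calls the single-chain method, the only event it can observe is that a chain has finished:

```julia
# Hypothetical stand-in for a user-defined single-chain sampler.
# `sample_chain` is an illustrative name, not part of AbstractMCMC.
function sample_chain(nsamples::Integer)
    draws = Vector{Float64}(undef, nsamples)
    for i in 1:nsamples
        draws[i] = randn()  # placeholder for one MCMC transition
    end
    return draws
end

# Hypothetical multi-chain driver. It treats the single-chain method as a
# black box, so progress can only advance when a whole chain completes.
function sample_chains(nsamples::Integer, nchains::Integer)
    nfinished = Threads.Atomic{Int}(0)
    tasks = map(1:nchains) do _
        Threads.@spawn begin
            chain = sample_chain(nsamples)
            Threads.atomic_add!(nfinished, 1)  # the only observable progress event
            chain
        end
    end
    chains = fetch.(tasks)
    return chains, nfinished[]
end

chains, nfinished = sample_chains(1_000, 4)
```

With 2-8 long-running chains, a counter like `nfinished` stays at zero for most of the run, which matches the behaviour described in the original report.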

Some explanation of the progress logging in general (copied from Slack):

I assume that Pluto uses a special logger or handles logging messages in a different way than what is done in the REPL. Maybe logging messages are dropped completely? We use the frontend-agnostic ProgressLogging.jl for progress logging (also used e.g. by SciML). It uses the logging framework in Julia and works in the REPL (with TerminalLoggers), Juno, VSCode (both use ProgressLogging for progress bars anyway) and IJulia. We also make sure that there's a logger that can handle the messages, and otherwise add a local logger for the progress logs that is optimized for the environment the code is run in (IIRC for IJulia and Windows we use ConsoleProgressMonitor instead of TerminalLogger).
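To illustrate what "frontend-agnostic" means here, the sketch below uses only the `Logging` standard library. It assumes, for illustration, that progress is carried as an ordinary log record with a `progress` key-value pair at a level just below `Info`. `ProgressCollector` and `fake_sampling` are hypothetical names, and the exact record format ProgressLogging.jl emits may differ:

```julia
using Logging

# Hypothetical logger that records progress fractions instead of drawing a bar.
# A real frontend (terminal, editor, notebook) would render them instead.
struct ProgressCollector <: AbstractLogger
    fractions::Vector{Float64}
end

Logging.min_enabled_level(::ProgressCollector) = Logging.BelowMinLevel
Logging.shouldlog(::ProgressCollector, level, _module, group, id) = true
Logging.catch_exceptions(::ProgressCollector) = false

function Logging.handle_message(logger::ProgressCollector, level, message, _module,
                                group, id, file, line; kwargs...)
    # Assumption: a progress record is a log message carrying a `progress` value.
    p = get(kwargs, :progress, nothing)
    p isa Real && push!(logger.fractions, p)
    return nothing
end

# Stand-in for a sampling loop that reports its progress through logging only.
function fake_sampling(nsamples::Integer)
    for i in 1:nsamples
        frac = i / nsamples
        @logmsg LogLevel(-1) "Sampling" progress=frac
    end
end

collector = ProgressCollector(Float64[])
with_logger(collector) do
    fake_sampling(4)
end
```

The sampling code never refers to a terminal or a widget. Whichever logger is active decides how, or whether, the progress is displayed, which is also why a frontend that drops these records shows nothing.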

devmotion commented 3 years ago

Related: https://github.com/JuliaLogging/ProgressLogging.jl/issues/27, https://github.com/JuliaLogging/ProgressLogging.jl/issues/33, https://github.com/JuliaLogging/ProgressLogging.jl/issues/37, https://github.com/JuliaLogging/ProgressLogging.jl/issues/38

kaandocal commented 3 years ago

From what I understand the parallel mcmcsample implementations call the serial mcmcsample in each thread/process, so maybe one could hook each sampler to a callback which updates the corresponding progress bar...
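A rough sketch of that idea, using only the standard library. The function names and the one-argument callback signature are hypothetical and do not reflect AbstractMCMC's actual interface. Each chain gets its own counter, which a frontend could render as a separate progress bar:

```julia
# Hypothetical serial sampler that accepts an optional per-iteration callback.
# The name and callback signature are illustrative only.
function sample_serial(nsamples::Integer; callback = nothing)
    draws = Vector{Float64}(undef, nsamples)
    for i in 1:nsamples
        draws[i] = randn()  # placeholder for one MCMC transition
        callback === nothing || callback(i)
    end
    return draws
end

# Hypothetical parallel driver that injects one callback per chain.
function sample_parallel(nsamples::Integer, nchains::Integer)
    # One atomic counter per chain, written from that chain's task.
    progress = [Threads.Atomic{Int}(0) for _ in 1:nchains]
    tasks = map(1:nchains) do chain
        Threads.@spawn sample_serial(nsamples; callback = i -> (progress[chain][] = i))
    end
    chains = fetch.(tasks)
    return chains, [p[] / nsamples for p in progress]
end

chains, fractions = sample_parallel(100, 3)
```

This only works if the serial method accepts and actually invokes the injected callback.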

devmotion commented 3 years ago

I thought about injecting a callback, and maybe it's the best solution. But there's a major problem: we don't call mcmcsample but sample, since users/developers don't have to use the default fallback but can write their own sample function (and then sampling multiple chains should just work in exactly the same way). Thus we would have to mess around with possibly user-provided callbacks, which might affect their dispatches and would require that users propagate the callbacks and don't modify them in their sample procedure. And they might not even support callbacks in a custom implementation 🤷‍♂️

kaandocal commented 3 years ago

I see! It might be a good idea to provide a generic threaded/distributed dispatching mechanism (extending the mcmcsample implementation) so that a user basically only has to write a serial sample method. One could either always require a callback keyword argument, or require it only when the default dispatcher (which would use callbacks to update progress bars) is run with progress=true. I'm trying my hand at a demo right now to see how this could work...

devmotion commented 3 years ago

That's exactly what's done currently, but maybe you mean something else? You just have to write a sample method for a single chain or implement the next layer interface if you want to use the mcmcsample fallback, and then you'll get sampling of multiple chains (with multiple threads etc) for free.

kaandocal commented 3 years ago

Seems like I misinterpreted your comment then. You're right, the distributed mcmcsample implementations call StatsBase.sample.