TuringLang / AbstractMCMC.jl

Abstract types and interfaces for Markov chain Monte Carlo methods
https://turinglang.org/AbstractMCMC.jl
MIT License

Use nchains instead of Threads.nthreads() #38

Closed cpfiffer closed 4 years ago

cpfiffer commented 4 years ago

Fixes #37.

codecov[bot] commented 4 years ago

Codecov Report

Merging #38 into master will increase coverage by 0.02%. The diff coverage is 100.00%.


@@            Coverage Diff             @@
##           master      #38      +/-   ##
==========================================
+ Coverage   97.52%   97.54%   +0.02%     
==========================================
  Files           5        5              
  Lines         121      122       +1     
==========================================
+ Hits          118      119       +1     
  Misses          3        3              
Impacted Files    Coverage Δ
src/sample.jl     100.00% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update bd41004...8f01238.

devmotion commented 4 years ago

I think this is not correct, and actually I don't think #37 can be fixed by us (hence I think it's the user's responsibility to ensure that the number of chains is at least as large as the number of threads). The problem is that the scheduler is non-deterministic: we don't know which threads will be used, and in particular it's not guaranteed that the first nchains threads will be used. However, we use the thread ids to select the rngs, models, and samplers, so this will fail randomly.
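The failure mode described here can be sketched in isolation. The helper below is hypothetical (it is not part of AbstractMCMC): it allocates `ncopies` RNGs, runs `nchains` iterations under `Threads.@threads`, and reports every thread id that would index past the end of the per-thread vector. It is only meant to show why allocating `nchains` copies while still indexing by `Threads.threadid()` is unsafe.

```julia
using Random

# Hypothetical helper (not part of AbstractMCMC): allocate `ncopies` RNGs, run
# `nchains` iterations under `Threads.@threads`, and collect every thread id
# that would index past the end of the per-thread vector.
function out_of_range_thread_ids(nchains::Integer, ncopies::Integer)
    rngs = [Random.MersenneTwister(i) for i in 1:ncopies]
    offending = zeros(Int, nchains)  # 0 marks an iteration whose lookup is safe
    Threads.@threads for i in 1:nchains
        id = Threads.threadid()
        if id > length(rngs)
            offending[i] = id  # `rngs[id]` would throw a BoundsError here
        end
    end
    return filter(!iszero, offending)
end

# With only `nchains` copies, the result is empty only if the scheduler happens
# to run the loop on threads 1:nchains, which is not guaranteed.
println(out_of_range_thread_ids(2, 2))
```

Whether the printed vector is empty depends on how many threads Julia was started with and on which threads the scheduler picks, which is exactly the non-determinism at issue.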

xukai92 commented 4 years ago

Can we just move those copy calls inside the loop?

devmotion commented 4 years ago

That would be inefficient in the (I guess more common and intended) case where nchains is larger than Threads.nthreads(), since the model, sampler, and random number generator would be copied in every iteration instead of just once per thread.

devmotion commented 4 years ago

IMO it is the user's responsibility to ensure that nchains is larger than Threads.nthreads() if she wants to avoid unneeded copies.

xukai92 commented 4 years ago

That would be inefficient ...

Good point. I think we can check which case we are in (threads > chains or the other way around) and, depending on the case, do either what I proposed or what is currently implemented.

IMO it is the user's responsibility to ensure ...

I don't think that's possible in a lot of cases. I ran into this issue in a case where I use an optimisation library that needs 20 threads and, inside the loop, I also want to do parallel sampling with 5 chains. I found that when I increased the number of threads the Turing part became slow, which led me to the issue I created yesterday.

Furthermore, another reason I think we should fix it is nested multi-threading: what if I want to run 5 parallel chains and also use multi-threading in the model call to parallelize over observations? The user will have to use more threads than the number of chains.

devmotion commented 4 years ago

I ran into this issue in a case where I use an optimisation library that needs 20 threads and, inside the loop, I also want to do parallel sampling with 5 chains. I found that when I increased the number of threads the Turing part became slow, which led me to the issue I created yesterday.

For this use case I guess it is more efficient to just not use the default parallel sampling implementation in AbstractMCMC. It is trivial to spawn nchains calls of sample in your implementation and combine the results with chainscat(fetch(task1), fetch(task2), ...). Then you could just copy the model, the sampler, and the RNGs outside of your optimization loop nchains times, which should let you avoid the copying overhead almost completely.

So I'm still not convinced that we should handle this use case in the default implementation in AbstractMCMC.
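A minimal sketch of the spawn-and-combine approach suggested above, under stated assumptions: `run_chain` is a hypothetical stand-in for a single-chain `sample` call, and `reduce(hcat, ...)` stands in for `chainscat`, so that the example runs without AbstractMCMC or a model. Only `Threads.@spawn` and `fetch` are the real building blocks here.

```julia
using Random

# Hypothetical stand-in for one single-chain `sample(rng, model, sampler, N)` call.
run_chain(rng::Random.AbstractRNG, N::Integer) = randn(rng, N)

function spawn_chains(rngs::Vector{<:Random.AbstractRNG}, N::Integer)
    # One task per chain; each task owns exactly one RNG.
    tasks = map(rngs) do rng
        Threads.@spawn run_chain(rng, N)
    end
    # Stand-in for `chainscat(fetch(task1), fetch(task2), ...)`: one column per chain.
    return reduce(hcat, map(fetch, tasks))
end

# The per-chain state is created once, outside of any outer optimization loop.
rngs = [Random.MersenneTwister(i) for i in 1:5]
samples = spawn_chains(rngs, 100)
println(size(samples))
```

Because the RNGs are indexed by chain rather than by thread, this sketch does not depend on how many threads are available. Reusing `rngs` across iterations of an outer loop avoids recreating them each time.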

cpfiffer commented 4 years ago

I think it's not correct that we can't fix this. We should be allocating everything on the number of chains regardless of how many threads there are, and indexing on the chain number. Our current implementation is incorrect in that we use thread ID instead of chain number.

This loop, for example, should be changed from

                Threads.@threads for i in 1:nchains
                    # Obtain the ID of the current thread.
                    id = Threads.threadid()

                    # Seed the thread-specific random number generator with the pre-made seed.
                    subrng = rngs[id]
                    Random.seed!(subrng, seeds[i])

                    # Sample a chain and save it to the vector.
                    chains[i] = StatsBase.sample(subrng, models[id], samplers[id], N;
                                                 progress = false, kwargs...)

                    # Update the progress bar.
                    progress && put!(channel, true)
                end

to

                Threads.@threads for i in 1:nchains
                    # Seed the chain-specific random number generator with the pre-made seed.
                    subrng = rngs[i]
                    Random.seed!(subrng, seeds[i])

                    # Sample a chain and save it to the vector.
                    chains[i] = StatsBase.sample(subrng, models[i], samplers[i], N;
                                                 progress = false, kwargs...)

                    # Update the progress bar.
                    progress && put!(channel, true)
                end

Changing this would make the code completely independent of the number of threads (which we can't control) and predetermine everything conditional on the number of chains (which is a fixed value known ahead of time). The thread scheduler will then handle nchains < nthreads and vice versa with zero extra thought from either us or the user.

devmotion commented 4 years ago

The current implementation is correct but optimized for the case where Threads.nthreads() < nchains (see https://github.com/TuringLang/AbstractMCMC.jl/pull/38#issuecomment-629632591). Your alternative (which is what @xukai92 would like us to use if nchains < Threads.nthreads()) is inefficient in this case since many more copies of sampler, model, and rng are created than needed. I still think it is reasonable to assume Threads.nthreads() < nchains by default, but we could think about the alternative if that assumption is not satisfied.

But as explained above, IMO users probably should never use the default implementation for parallel sampling inside of loops since it will create a lot of unneeded copies (regardless of the algorithm in AbstractMCMC).

devmotion commented 4 years ago

BTW having one copy (of the RNG or some other mutable structure) per thread is actually what the Julia documentation suggests.

cpfiffer commented 4 years ago

I'm not really concerned with making too many copies of sampler, model, and rng, since they have relatively small footprints. Anyone who is doing multithreading is probably fine with making the memory/parallelism tradeoff, since that is an explicit one you make when doing any kind of parallelism.

I think it's just not worth the complexity to change behavior on the number of threads in relation to the number of chains. Just provide all the resources ahead of time and let the scheduler work it out.

devmotion commented 4 years ago

My impression was that @xukai92 experienced a slowdown due to the additional copies that were created but not needed, and therefore opened the issue. Hence it seems completely switching to the alternative would lead to a slowdown as soon as nchains > Threads.nthreads().

cpfiffer commented 4 years ago

Alright, fair enough. I'll close this for now.

xukai92 commented 4 years ago

Hence it seems completely switching to the alternative would lead to a slowdown as soon as nchains > Threads.nthreads().

What I proposed in https://github.com/TuringLang/AbstractMCMC.jl/pull/38#issuecomment-629650848 was that we can simply check which case 1) nchains > Threads.nthreads() or 2) Threads.nthreads() > nchains we are actually in, and apply the correct way of assigning things.

I'm still not sure why we don't want to do this given that it's an easy change from our side and it can save the user some trouble.

cpfiffer commented 4 years ago

Okay, changed to use min(nchains, Threads.nthreads()).

devmotion commented 4 years ago

That still introduces the same bug if nchains < Threads.nthreads(). As discussed above, if we want to change something about it, we have to handle nchains < Threads.nthreads() and nchains >= Threads.nthreads() in a different way - in the former case copy should be moved inside the loop whereas in the latter case it should be done outside of the loop for all threads once, for efficiency reasons.

I would prefer not to break the current state since that works efficiently for all cases where nchains >= Threads.nthreads().

I'm not strictly against covering the other case as well, I'm just wondering if it's worth the increased code complexity. At least in the use case mentioned above there are better fixes that avoid this problem.
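The two-case handling described above could look roughly like the following sketch. Everything here is an assumption for illustration: `draw_chain` is a hypothetical stand-in for `StatsBase.sample`, only the RNG is treated as mutable state (the model and sampler copies are omitted), and `nslots` exists only because on newer Julia versions `Threads.threadid()` can exceed `Threads.nthreads()`.

```julia
using Random

# Hypothetical stand-in for `StatsBase.sample(rng, model, sampler, N)`.
draw_chain(rng::Random.AbstractRNG, N::Integer) = randn(rng, N)

# Upper bound on `Threads.threadid()`. `Threads.maxthreadid` only exists on
# newer Julia versions, so fall back to `Threads.nthreads()` elsewhere.
nslots() = isdefined(Threads, :maxthreadid) ? Threads.maxthreadid() : Threads.nthreads()

function sample_chains(seed::Integer, N::Integer, nchains::Integer;
                       copy_per_chain::Bool = nchains < Threads.nthreads())
    seeds = rand(Random.MersenneTwister(seed), UInt, nchains)
    chains = Vector{Vector{Float64}}(undef, nchains)
    if copy_per_chain
        # nchains < nthreads: create the mutable state inside the loop, once per chain.
        Threads.@threads for i in 1:nchains
            chains[i] = draw_chain(Random.MersenneTwister(seeds[i]), N)
        end
    else
        # nchains >= nthreads: create the state once per thread, outside the loop.
        rngs = [Random.MersenneTwister() for _ in 1:nslots()]
        Threads.@threads for i in 1:nchains
            # Assumes the loop body does not yield, so the task stays on this thread.
            subrng = rngs[Threads.threadid()]
            Random.seed!(subrng, seeds[i])
            chains[i] = draw_chain(subrng, N)
        end
    end
    return chains
end

chains = sample_chains(42, 10, 4)
println(length(chains))
```

Since every chain is reseeded from its own pre-made seed, both branches should produce the same draws for a given `seed`. The branch choice affects only how many RNG objects are created, which is the complexity-versus-efficiency trade-off being weighed in this thread.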

xukai92 commented 4 years ago

I'm not strictly against covering the other case as well, I'm just wondering if it's worth the increased code complexity. At least in the use case mentioned above there are better fixes that avoid this problem.

If we don't want to make the change, we should at least throw a warning here and guide the user to write their own loop. Also, currently when nchains < Threads.nthreads(), the progress bar says it is running parallel chains with Threads.nthreads() threads, which is not true and confusing.

But again, I think it's good for us to take care of this. Plus, writing the customized loop is not trivial when I also want to implement the progress meter correctly :sweat_smile:
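The warning suggested above might be sketched as follows. The function name, the warning text, and the return value are all hypothetical and not taken from AbstractMCMC. The function returns the number of threads that can actually be kept busy, which is the number a progress message would need to report.

```julia
# Hypothetical check (name and message are illustrative only). Warns when there
# are fewer chains than threads and returns how many threads can be kept busy.
function effective_workers(nchains::Integer, nthreads::Integer = Threads.nthreads())
    if nchains < nthreads
        @warn "Number of chains ($nchains) is less than number of threads ($nthreads); " *
              "consider writing a custom sampling loop."
    end
    return min(nchains, nthreads)
end

println("Running chains on ", effective_workers(5, 20), " threads")
```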

devmotion commented 4 years ago

But again, I think it's good for us to take care of this. Plus, writing the customized loop is not trivial when I also want to implement the progress meter correctly

I agree, with a progress bar it becomes a bit more complicated. I didn't think that you would care about it; in my experience it can slow down sampling quite a bit, so I assumed you disabled it.

devmotion commented 4 years ago

So, I dug a bit deeper into the Julia internals and it seems the current undocumented internal behaviour of @threads is actually to split the loop into batches of equal size among all threads, starting with thread 1. Since this is not part of the documentation, I guess one is supposed to not rely on this behaviour and it might change at any time. However, I think we could still go with the min(nchains, Threads.nthreads()) workaround for now since it is the simplest solution and seems to work with the current Julia version at least.
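One way to observe the iteration-to-thread assignment empirically is to record `Threads.threadid()` for each iteration, as in the sketch below. The helper name is hypothetical. The printed mapping reflects only the Julia version and thread count it is run with, and later Julia versions may schedule the loop differently, so it should not be read as a guarantee.

```julia
# Hypothetical helper: record which thread id executes each iteration of a
# `Threads.@threads` loop over 1:n.
function iteration_thread_ids(n::Integer)
    ids = zeros(Int, n)
    Threads.@threads for i in 1:n
        ids[i] = Threads.threadid()
    end
    return ids
end

ids = iteration_thread_ids(8)
println("nthreads = ", Threads.nthreads(), ", thread id per iteration = ", ids)
```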

I just suggest to