Closed cpfiffer closed 4 years ago
Merging #38 into master will increase coverage by 0.02%. The diff coverage is 100.00%.
```diff
@@            Coverage Diff             @@
##           master      #38      +/-   ##
==========================================
+ Coverage   97.52%   97.54%   +0.02%
==========================================
  Files           5        5
  Lines         121      122       +1
==========================================
+ Hits          118      119       +1
  Misses          3        3
```

Impacted Files | Coverage Δ |
---|---|
src/sample.jl | 100.00% <100.00%> (ø) |
Powered by Codecov. Last update bd41004...8f01238.
I think this is not correct, and actually I don't think #37 can be fixed by us (and hence I think it's the user's responsibility to ensure the number of chains is at least as large as the number of threads). The problem is that the scheduler is non-deterministic and we don't know which threads will be used, and in particular it's not guaranteed that the first nchains threads will be used. However, we use the ids of the threads to select the rngs, models, and sampler, so this will fail randomly.
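To illustrate the failure mode described here (a hypothetical sketch, not code from this PR): if the per-chain resources are stored in vectors of length `nchains` but indexed by `Threads.threadid()`, any thread whose id is greater than `nchains` triggers an out-of-bounds access.

```julia
using Random

nchains = 2  # fewer chains than available threads
rngs = [MersenneTwister() for _ in 1:nchains]

Threads.@threads for i in 1:nchains
    id = Threads.threadid()  # not guaranteed to lie in 1:nchains
    rng = rngs[id]           # BoundsError whenever id > nchains
    # ... sample chain i with rng ...
end
```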
Can we just move those copy calls inside the loop?
That would be inefficient in the (I guess more common and intended) case where nchains is larger than Threads.nthreads(), since the model, sampler, and random number generator would be copied in every iteration instead of just once per thread.
IMO it is the user's responsibility to ensure that nchains is larger than Threads.nthreads() if she wants to avoid unneeded copies.
That would be inefficient ...
Good point. I think we can check which case we are in (threads > chains or the other way around) and switch between what I proposed and what is currently implemented based on that.
IMO it is the user's responsibility to ensure ...
I don't think that's possible in a lot of cases. I ran into this issue in a case where I use an optimisation library that needs 20 threads, and inside its loop I also want to do parallel sampling with 5 chains. I found that when I increased my number of threads the Turing part became slow, and tracking it down led to the issue I created yesterday.
Furthermore, another reason I think we should fix it is nested multi-threading - what if I want to run 5 parallel chains while also using multi-threading in the model call to parallelize over observations? The user would then have to use more threads than the number of chains.
I ran into this issue in a case where I use an optimisation library that needs 20 threads, and inside its loop I also want to do parallel sampling with 5 chains. I found that when I increased my number of threads the Turing part became slow, and tracking it down led to the issue I created yesterday.
For this use case I guess it is more efficient to just not use the default parallel sampling implementation in AbstractMCMC. It is trivial to spawn nchains calls of sample in your implementation and combine the results with chainscat(fetch(task1), fetch(task2), ...). Then you could just copy the model, the sampler, and the RNGs outside of your optimization loop nchains times, which should let you avoid the copying overhead almost completely. So I'm still not convinced that we should handle this use case in the default implementation in AbstractMCMC.
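A minimal sketch of that suggestion (the names `mymodel`, `myspl`, `seed`, `N`, and `nchains` are assumed to be defined by the user; `chainscat` comes from MCMCChains, and the use of `Threads.@spawn` is one possible way to run the tasks):

```julia
using Random, MCMCChains

# Copy the model, sampler, and RNGs once, outside the optimization loop.
rngs = [MersenneTwister(seed + i) for i in 1:nchains]
models = [deepcopy(mymodel) for _ in 1:nchains]
samplers = [deepcopy(myspl) for _ in 1:nchains]

# Inside the loop: spawn one sampling task per chain ...
tasks = [Threads.@spawn sample(rngs[i], models[i], samplers[i], N)
         for i in 1:nchains]

# ... and combine the results.
chain = chainscat(fetch.(tasks)...)
```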
I think it's not correct that we can't fix this. We should be allocating everything on the number of chains regardless of how many threads there are, and indexing on the chain number. Our current implementation is incorrect in that we use thread ID instead of chain number.
This loop, for example, should be changed from
```julia
Threads.@threads for i in 1:nchains
    # Obtain the ID of the current thread.
    id = Threads.threadid()

    # Seed the thread-specific random number generator with the pre-made seed.
    subrng = rngs[id]
    Random.seed!(subrng, seeds[i])

    # Sample a chain and save it to the vector.
    chains[i] = StatsBase.sample(subrng, models[id], samplers[id], N;
                                 progress = false, kwargs...)

    # Update the progress bar.
    progress && put!(channel, true)
end
```
to
```julia
Threads.@threads for i in 1:nchains
    # Seed the chain-specific random number generator with the pre-made seed.
    subrng = rngs[i]
    Random.seed!(subrng, seeds[i])

    # Sample a chain and save it to the vector.
    chains[i] = StatsBase.sample(subrng, models[i], samplers[i], N;
                                 progress = false, kwargs...)

    # Update the progress bar.
    progress && put!(channel, true)
end
```
Changing this would make the code completely independent of the number of threads (which we can't control) and predetermine everything conditional on the number of chains (which is a fixed value known ahead of time). The thread scheduler will then handle nchains < nthreads and vice versa with zero extra thought from either us or the user.
The current implementation is correct but optimized for the case where Threads.nthreads() < nchains (see https://github.com/TuringLang/AbstractMCMC.jl/pull/38#issuecomment-629632591). Your alternative (which is what @xukai92 would like us to use if nchains < Threads.nthreads()) is inefficient in this case since many more copies of sampler, model, and rng are created than needed. I still think it is reasonable to assume Threads.nthreads() < nchains by default, but we could think about the alternative if that assumption is not satisfied.
But as explained above, IMO users probably should never use the default implementation for parallel sampling inside of loops since it will create a lot of unneeded copies (regardless of the algorithm in AbstractMCMC).
BTW having one copy (of the RNG or some other mutable structure) per thread is actually what the Julia documentation suggests.
I'm not really concerned with making too many copies of sampler, model, and rng, since they have relatively small footprints. Anyone who is doing multithreading is probably fine with making the memory/parallelism tradeoff, since that is an explicit one you make when doing any kind of parallelism.
I think it's just not worth the complexity to change behavior on the number of threads in relation to the number of chains. Just provide all the resources ahead of time and let the scheduler work it out.
My impression was that @xukai92 experienced a slowdown due to additional copies that were created but not needed, and therefore opened the issue. Hence it seems completely switching to the alternative would lead to a slowdown as soon as nchains > Threads.nthreads().
Alright, fair enough. I'll close this for now.
Hence it seems completely switching to the alternative would lead to a slowdown as soon as nchains > Threads.nthreads()
What I proposed in https://github.com/TuringLang/AbstractMCMC.jl/pull/38#issuecomment-629650848 was that we can simply check which case we are actually in, 1) nchains > Threads.nthreads() or 2) Threads.nthreads() > nchains, and apply the correct way of assigning things. I'm still not sure why we don't want to do this given that it's an easy change on our side and it can save the user some trouble.
Okay, changed to use min(nchains, Threads.nthreads()).
That still introduces the same bug if nchains < Threads.nthreads(). As discussed above, if we want to change something about it, we have to handle nchains < Threads.nthreads() and nchains >= Threads.nthreads() in different ways - in the former case the copy should be moved inside the loop, whereas in the latter case it should be done outside of the loop once for all threads, for efficiency reasons. I would prefer not to break the current state since it works efficiently for all cases where nchains >= Threads.nthreads(). I'm not strictly against covering the other case as well, I'm just wondering if it's worth the increased code complexity. At least in the use case mentioned above there are better fixes that avoid this problem.
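The two-case strategy being discussed could look roughly like this (a hypothetical sketch, not the PR's code; `model`, `sampler`, `chains`, `nchains`, `N`, and per-copy RNGs are assumed to be set up as in the existing implementation):

```julia
if nchains >= Threads.nthreads()
    # Few threads, many chains: copy once per thread and index by thread id,
    # as the current implementation does.
    ncopies = Threads.nthreads()
    models = [deepcopy(model) for _ in 1:ncopies]
    samplers = [deepcopy(sampler) for _ in 1:ncopies]
    Threads.@threads for i in 1:nchains
        id = Threads.threadid()
        chains[i] = sample(rngs[id], models[id], samplers[id], N)
    end
else
    # More threads than chains: copy inside the loop, once per chain,
    # and index everything by the chain number.
    Threads.@threads for i in 1:nchains
        chains[i] = sample(rngs[i], deepcopy(model), deepcopy(sampler), N)
    end
end
```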
I'm not strictly against covering the other case as well, I'm just wondering if it's worth the increased code complexity. At least in the use case mentioned above there are better fixes that avoid this problem.
If we don't want to make the change, we should at least throw a warning here and guide the user to write their own loop. Also, currently when nchains < Threads.nthreads(), the progress bar says we are running parallel chains with Threads.nthreads() threads, which is also not true and confusing.
But again, I think it's good for us to take care of this. Plus, writing the customized loop is not trivial when I also want to implement the progress meter correctly :sweat_smile:
But again, I think it's good for us to take care of this. Plus, writing the customized loop is not trivial when I also want to implement the progress meter correctly
I agree, with a progress bar it becomes a bit more complicated. I didn't think that you would care about it, in my experience it can slow down sampling quite a bit so I assumed you disabled it.
So, I dug a bit deeper into the Julia internals, and it seems the current undocumented internal behaviour of @threads is actually to split the loop into batches of equal size among all threads, starting with thread 1. Since this is not part of the documentation, I guess one is not supposed to rely on this behaviour and it might change at any time. However, I think we could still go with the min(nchains, Threads.nthreads()) workaround for now since it is the simplest solution and seems to work with the current Julia version at least.
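The workaround amounts to sizing the resource vectors by whichever of the two counts is smaller (again a sketch, and one that relies on the undocumented batching behaviour of @threads described above):

```julia
using Random

ncopies = min(nchains, Threads.nthreads())
rngs = [MersenneTwister() for _ in 1:ncopies]
models = [deepcopy(model) for _ in 1:ncopies]
samplers = [deepcopy(sampler) for _ in 1:ncopies]

Threads.@threads for i in 1:nchains
    # With the current batched scheduling, threadid() stays within 1:ncopies
    # when nchains <= Threads.nthreads() - but this is not a documented
    # guarantee and may break in future Julia versions.
    id = Threads.threadid()
    chains[i] = sample(rngs[id], models[id], samplers[id], N)
end
```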
I just suggest handling the nchains < Threads.nthreads() case as well.
Fixes #37.