sethaxen opened 3 years ago
On second look, the current version does not even support one of the allowed signatures, namely, when rng is a vector of AbstractRNGs: rand(rng) would then select a random RNG from the vector, which cannot be multiplied by an integer.
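A quick way to see the problem (variable names here are just for illustration):

```julia
using Random

rngs = MersenneTwister.(1:3)  # a vector of RNGs

# rand(A::AbstractArray) draws a random *element* of A, so this
# returns one of the MersenneTwisters, not a random number:
r = rand(rngs)
r isa MersenneTwister  # true

# ...and an RNG cannot take part in arithmetic:
# 2 * r  # ERROR: MethodError: no method matching *(::Int64, ::MersenneTwister)
```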
This is probably not best practice, but I overloaded the signature so that I can use rand(rng) with a vector of RNGs at https://github.com/TuringLang/AdvancedHMC.jl/blob/master/src/utilities.jl#L5. The intention is that whenever rng and chains are properly matched, most of the code is unchanged.
I think this definitely needs to be revisited.
Yeah, that's type piracy, and it unnecessarily allocates. Just moving the call to rand into the broadcast does what you want, e.g.
julia> using Random
julia> x, y = randn(10), randn(10);
julia> rngs = MersenneTwister.(1:10);
julia> typeof(rngs)
Array{MersenneTwister,1}
julia> x .+ rand.(rngs[1]) ./ y # one RNG is fine
10-element Array{Float64,1}:
0.6484329512210842
-0.8456626952265401
-0.5519045080957454
1.1699050798703037
0.23873584522965022
8.839777756732783
1.9767957493016919
0.6796803545103542
0.8876298219171639
-1.6656461148395274
julia> x .+ rand.(rngs) ./ y # compatible number of RNGs is fine
10-element Array{Float64,1}:
0.7572236290032679
-1.0410338544741133
0.41222283801769466
1.7712141238868608
0.5290394278928535
4.433214617560498
1.7686380486229119
0.3905818340153596
0.5186372843084085
-1.367575972395338
julia> rng = MersenneTwister(42);
julia> z1 = x .+ y .* (2 .* rand.(rng) .+ 1);
julia> rng = MersenneTwister(42);
julia> z2 = x .+ y .* (2 .* rand(rng, 10) .+ 1);
julia> z1 == z2 # broadcasting does the same as preallocating the random numbers
true
julia> x .+ rand.(rngs[1:3]) ./ y # incompatible number of RNGs errors
ERROR: DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 3 and 10")
Stacktrace:
[1] _bcs1 at ./broadcast.jl:501 [inlined]
[2] _bcs at ./broadcast.jl:495 [inlined]
[3] broadcast_shape at ./broadcast.jl:489 [inlined]
[4] combine_axes at ./broadcast.jl:484 [inlined]
[5] _axes at ./broadcast.jl:209 [inlined]
[6] axes at ./broadcast.jl:207 [inlined]
[7] combine_axes at ./broadcast.jl:484 [inlined]
[8] instantiate at ./broadcast.jl:266 [inlined]
[9] materialize(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(+),Tuple{Array{Float64,1},Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(/),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(rand),Tuple{Array{MersenneTwister,1}}},Array{Float64,1}}}}}) at ./broadcast.jl:837
[10] top-level scope at REPL[10]:1
I'll open a PR.
Although perhaps I'm misunderstanding what a vector of RNGs means here. I had interpreted it as a separate RNG per chain, but with rand_coupled it seems the intention is for the RNGs to be coupled during sampling? Not sure how this is used. I wonder if it would make more sense to provide something like
struct CoupledRNG{T<:AbstractVector{<:AbstractRNG}} <: AbstractRNG
    rngs::T
end
and then provide the necessary overloads so that the methods that call rand don't need to be aware of the coupling; it just happens. That could get tricky on the GPU, though, where not every RNG works.
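A minimal sketch of what such a wrapper might look like (hypothetical, not part of AdvancedHMC; this assumes "coupled" means every wrapped RNG advances in lockstep while a single shared value is returned):

```julia
using Random

# Hypothetical wrapper: a single AbstractRNG backed by several RNGs.
struct CoupledRNG{T<:AbstractVector{<:AbstractRNG}} <: AbstractRNG
    rngs::T
end

function Random.rand(c::CoupledRNG)
    xs = map(rand, c.rngs)  # advance *every* wrapped RNG
    return first(xs)        # one shared sample for all chains
end

# Identically seeded RNGs stay synced because all of them advance:
c = CoupledRNG([MersenneTwister(42), MersenneTwister(42)])
rand(c)
rand(c.rngs[1]) == rand(c.rngs[2])  # true
```

With this, code that draws a shared sample can just call rand on the wrapper without knowing about the coupling.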
Yes, a vector of RNGs is a separate RNG per chain.
rand_coupled is a separate concern and is only used here: https://github.com/TuringLang/AdvancedHMC.jl/blob/692e646f66307e4d35627b57904e7fe04bb35234/src/trajectory.jl#L302
Basically, in vectorized mode, there are also cases in which you want a single sample shared across all chains. You could use one of the RNGs from your vector, but then the RNGs would no longer be "synced". For example, if you pass a vector of identical RNGs, you would expect all chains to be the same; if you draw from only one RNG in the vector, you lose that synchronization. This happens a lot in coupled MCMC implementations.
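A small illustration of the syncing issue (assuming identically seeded RNGs represent coupled chains):

```julia
using Random

rngs = [MersenneTwister(1), MersenneTwister(1)]  # identical seeds: coupled chains

# Per-chain draws keep the RNGs synced, since both advance together:
rand.(rngs)
rand(rngs[1]) == rand(rngs[2])  # true

# Drawing a "shared" sample from only the first RNG desyncs them:
shared = rand(rngs[1])          # only rngs[1] advances
rand(rngs[1]) == rand(rngs[2])  # false: the streams have diverged
```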
Currently, when multiple chains are sampled in parallel with StaticTrajectory, the different chains may (if I understand correctly) have their own nominal step sizes and jitters, yet the same random number is used for all chains: https://github.com/TuringLang/AdvancedHMC.jl/blob/91837c65a20706c9cde7206b3a2e0eeeadc6d184/src/integrator.jl#L170
This couples the jitters across chains and probably should be changed to: