TuringLang / AdvancedHMC.jl

Robust, modular and efficient implementation of advanced Hamiltonian Monte Carlo algorithms
https://turinglang.org/AdvancedHMC.jl/
MIT License

Use different jittered step sizes for vectorized HMC #222

Open sethaxen opened 3 years ago

sethaxen commented 3 years ago

Currently, when multiple chains are sampled in parallel with StaticTrajectory, the chains may (if I understand correctly) each have their own nominal step size and jitter, yet the same random number is used for all chains: https://github.com/TuringLang/AdvancedHMC.jl/blob/91837c65a20706c9cde7206b3a2e0eeeadc6d184/src/integrator.jl#L170

This couples the jitters across chains and probably should be changed to:

     ϵ = lf.ϵ0 .* (1 .+ lf.jitter .* (2 .* rand.(rng) .- 1)) 
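A minimal sketch of the difference, assuming three chains and plain vectors ϵ0 and jitter standing in for the fields lf.ϵ0 and lf.jitter (coupled_stepsize and perchain_stepsize are hypothetical names, not AdvancedHMC functions). With a single scalar rand(rng), every chain is scaled by the same random factor. With rand.(rng) inside the fused broadcast, each chain gets its own draw, whether rng is one RNG or one RNG per chain:

```julia
using Random

# Hypothetical stand-ins for lf.ϵ0 and lf.jitter: one nominal step size and
# one relative jitter per chain (three chains here).
ϵ0 = [0.1, 0.2, 0.3]
jitter = [0.5, 0.5, 0.5]

# Current pattern: rand(rng) is evaluated once, so every chain is scaled by
# the same random factor.
coupled_stepsize(rng, ϵ0, jitter) = ϵ0 .* (1 .+ jitter .* (2 * rand(rng) - 1))

# Proposed pattern: rand.(rng) is part of the fused broadcast, so rand is called
# once per chain; rng may be a single AbstractRNG or a vector with one RNG per chain.
perchain_stepsize(rng, ϵ0, jitter) = ϵ0 .* (1 .+ jitter .* (2 .* rand.(rng) .- 1))

ϵ_coupled = coupled_stepsize(MersenneTwister(1), ϵ0, jitter)
ϵ_one_rng = perchain_stepsize(MersenneTwister(1), ϵ0, jitter)
ϵ_rng_per_chain = perchain_stepsize(MersenneTwister.(1:3), ϵ0, jitter)

# Relative factor ϵ ./ ϵ0 applied to each chain.
println(ϵ_coupled ./ ϵ0)        # one shared factor (equal up to rounding)
println(ϵ_one_rng ./ ϵ0)        # one factor per chain
println(ϵ_rng_per_chain ./ ϵ0)  # one factor per chain
```

Because the broadcast is fused, the single-RNG case draws one number per chain without materializing an intermediate vector of random numbers.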
sethaxen commented 3 years ago

On second look, the current version does not even support one of the allowed signatures, namely when rng is a vector of AbstractRNGs: rand(rng) would then pick a random element of the vector, i.e. an RNG, which cannot be multiplied by an integer.
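A small sketch of that failure in plain Julia, assuming no package has added an extra rand method for vectors of RNGs:

```julia
using Random

rngs = MersenneTwister.(1:3)   # one RNG per chain

# Without any extra method, rand on a vector returns a randomly chosen element
# (chosen using the default global RNG), so the result is an RNG, not a number.
picked = rand(rngs)
println(typeof(picked))        # MersenneTwister

# Arithmetic on it fails because no method of * accepts an RNG.
try
    2 * picked - 1
catch err
    println(typeof(err))       # MethodError
end
```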

xukai92 commented 3 years ago

I'm probably not following best practice here, but I overloaded rand so that rand(rng) works for a vector of RNGs at https://github.com/TuringLang/AdvancedHMC.jl/blob/master/src/utilities.jl#L5. The intention is that whenever rng and the chains are properly matched, most of the code stays unchanged.

I think this definitely needs to be revisited.

sethaxen commented 3 years ago

Yeah, that's type piracy, and it allocates unnecessarily. Just moving the call to rand into the broadcast does what you want, e.g.:

julia> using Random

julia> x, y = randn(10), randn(10);

julia> rngs = MersenneTwister.(1:10);

julia> typeof(rngs)
Array{MersenneTwister,1}

julia> x .+ rand.(rngs[1]) ./ y # one RNG is fine
10-element Array{Float64,1}:
  0.6484329512210842
 -0.8456626952265401
 -0.5519045080957454
  1.1699050798703037
  0.23873584522965022
  8.839777756732783
  1.9767957493016919
  0.6796803545103542
  0.8876298219171639
 -1.6656461148395274

julia> x .+ rand.(rngs) ./ y # compatible number of RNGs is fine
10-element Array{Float64,1}:
  0.7572236290032679
 -1.0410338544741133
  0.41222283801769466
  1.7712141238868608
  0.5290394278928535
  4.433214617560498
  1.7686380486229119
  0.3905818340153596
  0.5186372843084085
 -1.367575972395338

julia> rng = MersenneTwister(42);

julia> z1 = x .+ y .* (2 .* rand.(rng) .+ 1);

julia> rng = MersenneTwister(42);

julia> z2 = x .+ y .* (2 .* rand(rng, 10) .+ 1);

julia> z1 == z2 # broadcasting does the same as preallocating the random numbers
true

julia> x .+ rand.(rngs[1:3]) ./ y  # incompatible number of RNGs errors
ERROR: DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 3 and 10")
Stacktrace:
 [1] _bcs1 at ./broadcast.jl:501 [inlined]
 [2] _bcs at ./broadcast.jl:495 [inlined]
 [3] broadcast_shape at ./broadcast.jl:489 [inlined]
 [4] combine_axes at ./broadcast.jl:484 [inlined]
 [5] _axes at ./broadcast.jl:209 [inlined]
 [6] axes at ./broadcast.jl:207 [inlined]
 [7] combine_axes at ./broadcast.jl:484 [inlined]
 [8] instantiate at ./broadcast.jl:266 [inlined]
 [9] materialize(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(+),Tuple{Array{Float64,1},Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(/),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(rand),Tuple{Array{MersenneTwister,1}}},Array{Float64,1}}}}}) at ./broadcast.jl:837
 [10] top-level scope at REPL[10]:1

I'll open a PR.
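The transcript checks z1 == z2 for a single RNG. A similar sketch for a vector of RNGs, assuming one RNG per element, compares the fused broadcast with an explicit loop over freshly seeded copies of the same RNGs; in this sketch element i uses rngs[i] in both cases:

```julia
using Random

x, y = randn(3), randn(3)

# Fused broadcast: element i draws from rngs[i]; no intermediate vector of draws.
rngs = MersenneTwister.(1:3)
z_broadcast = x .+ rand.(rngs) ./ y

# Explicit loop over identically seeded copies of the same RNGs.
rngs = MersenneTwister.(1:3)
z_loop = [x[i] + rand(rngs[i]) / y[i] for i in eachindex(x)]

println(z_broadcast == z_loop)  # true
```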

sethaxen commented 3 years ago

Although perhaps I'm misunderstanding what a vector of RNGs means here. I had interpreted it as a separate RNG per chain, but with rand_coupled it seems the intention is for the RNGs to be coupled during sampling? Not sure how this is used. I wonder if it would make more sense to provide something like

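using Random: AbstractRNG

# Wraps one RNG per chain so it can be passed where a single AbstractRNG is expected
struct CoupledRNG{T<:AbstractVector{<:AbstractRNG}} <: AbstractRNG
    rngs::T
end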

and then provide the necessary overloads so that methods that call rand don't need to be aware of the coupling; it just happens. That could get tricky on the GPU, though, where not every RNG works.
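One way such overloads might look, as a hypothetical sketch rather than an AdvancedHMC API: rand on a CoupledRNG draws once from every wrapped RNG and returns one value per chain, so a generic function such as the hypothetical jittered below works unchanged with either a single RNG or a CoupledRNG. The struct is repeated so the block runs on its own. Returning a vector from rand(rng) breaks the usual AbstractRNG expectation that rand(rng) returns a single number, so other code that receives a CoupledRNG could misbehave.

```julia
using Random

# Hypothetical wrapper: one RNG per chain behind a single AbstractRNG.
struct CoupledRNG{T<:AbstractVector{<:AbstractRNG}} <: AbstractRNG
    rngs::T
end

# Hypothetical overloads: each call draws once from every chain's RNG and
# returns one value per chain.
Random.rand(c::CoupledRNG) = rand.(c.rngs)
Random.randn(c::CoupledRNG) = randn.(c.rngs)

# Generic code written against "some rng" does not need to know about the coupling.
jittered(rng, ϵ0, jitter) = ϵ0 .* (1 .+ jitter .* (2 .* rand(rng) .- 1))

c = CoupledRNG(MersenneTwister.(1:3))
ϵ = jittered(c, 0.1, 0.5)                          # one step size per chain
ϵ_single = jittered(MersenneTwister(1), 0.1, 0.5)  # a single step size
```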

xukai92 commented 3 years ago

Yes, a vector of RNGs is a separate RNG per chain.

rand_coupled is a separate concern and is only used here: https://github.com/TuringLang/AdvancedHMC.jl/blob/692e646f66307e4d35627b57904e7fe04bb35234/src/trajectory.jl#L302

Basically, in vectorized mode there are also some cases where you want a single sample shared by all chains. You could simply use one of the RNGs in your vector, but then the RNGs are no longer "synced". For example, if you pass a vector of identically seeded RNGs, you would expect all chains to be identical; if you draw from only one RNG in the vector, you lose that synchronization. This happens a lot in coupled MCMC implementations.
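A minimal sketch of the synchronization issue, assuming identically seeded RNGs; shared_draw is a hypothetical helper in the spirit of rand_coupled, not its actual implementation. It advances every RNG once and returns the first value, so later per-chain draws stay in lockstep, whereas drawing from only the first RNG lets it run ahead of the others:

```julia
using Random

# Hypothetical helper: one draw shared by all chains, while every RNG is
# advanced by the same amount so the RNGs stay in lockstep.
function shared_draw(rngs::AbstractVector{<:AbstractRNG})
    draws = rand.(rngs)   # advance every RNG exactly once
    return first(draws)   # all chains use the first RNG's value
end

# All chains start from identical RNG states.
synced = [MersenneTwister(7) for _ in 1:3]
u = shared_draw(synced)
next_synced = rand.(synced)   # all three values are equal

# Drawing only from the first RNG lets it run ahead of the others.
naive = [MersenneTwister(7) for _ in 1:3]
u_naive = rand(naive[1])
next_naive = rand.(naive)     # the first value now differs from the other two
```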