TuringLang / MCMCTempering.jl

Implementations of parallel tempering algorithms to augment samplers with tempering capabilities
https://turinglang.org/MCMCTempering.jl/
MIT License

Introduction of compositions and products of samplers and models #151

Closed: torfjelde closed this 1 year ago

torfjelde commented 1 year ago

As we've spoken about before, it should be possible to make things much more composable than they currently are.

Much of this functionality is useful beyond just MCMCTempering, so this package will provide a great "testing ground" for functionality that we can eventually move to something like AbstractMCMC proper.

To achieve this, we need the following:

  1. Compositions of samplers: sampler_outer ∘ sampler_inner, whose step first takes a step with sampler_inner and then, starting from the resulting state, takes a step with sampler_outer.
  2. Products of samplers and models: sampler_1 × sampler_2 targets model_1 × model_2, and its step simply runs sampler_1 on model_1 and sampler_2 on model_2 in "parallel".
  3. A state (SequentialState) which represents a sequence of states, so that we can later "unroll" these to obtain the full chain if so desired. E.g. in the case of sampler_outer ∘ sampler_inner, it's not obvious whether we only want the iterations corresponding to the full composition, or the iterations of each of the samplers.
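The step order of a composition (item 1) can be illustrated with a toy, self-contained Julia sketch. The types `ToySampler`, `Shift`, `Scale`, `ToyComposition` and the function `toystep` are hypothetical and are not MCMCTempering's API; `Shift` and `Scale` are deterministic stand-ins for samplers, chosen only so that the order in which the steps happen is visible in the result.

```julia
using Random

# Hypothetical toy "samplers"; deterministic so that the step order is observable.
abstract type ToySampler end

struct Shift <: ToySampler
    δ::Float64
end

struct Scale <: ToySampler
    c::Float64
end

# The RNG argument is unused by these deterministic stand-ins; it only mimics a step signature.
toystep(::AbstractRNG, s::Shift, x) = x + s.δ
toystep(::AbstractRNG, s::Scale, x) = s.c * x

# Hypothetical composition `outer ∘ inner` (not MCMCTempering.CompositionSampler).
struct ToyComposition{O<:ToySampler,I<:ToySampler} <: ToySampler
    outer::O
    inner::I
end

Base.:∘(outer::ToySampler, inner::ToySampler) = ToyComposition(outer, inner)

function toystep(rng::AbstractRNG, s::ToyComposition, x)
    # First take a step with `inner` ...
    x_inner = toystep(rng, s.inner, x)
    # ... then take a step with `outer`, starting from the state `inner` produced.
    return toystep(rng, s.outer, x_inner)
end

rng = Random.default_rng()
toystep(rng, Scale(2.0) ∘ Shift(1.0), 1.0)  # (1 + 1) * 2 = 4.0
toystep(rng, Shift(1.0) ∘ Scale(2.0), 1.0)  # 1 * 2 + 1 = 3.0
```

Because `ToyComposition` is itself a `ToySampler`, compositions nest, e.g. `Scale(2.0) ∘ (Shift(1.0) ∘ Shift(1.0))`.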

All in all, it's now possible to do some neat thingies.

Neat thingies

Some setup.

julia> using MCMCTempering

julia> using Distributions: Distributions

julia> using AdvancedMH, LogDensityProblems, MCMCChains

julia> using Random, LinearAlgebra

julia> Random.seed!(42);

julia> # Target of interest.
       struct Problem end

julia> LogDensityProblems.capabilities(::Type{Problem}) = LogDensityProblems.LogDensityOrder{0}()

julia> LogDensityProblems.dimension(::Problem) = 1

julia> function LogDensityProblems.logdensity(::Problem, x)
           Distributions.loglikelihood(
               Distributions.MixtureModel([
                   Distributions.Normal(0, 1),
                   Distributions.Normal(5, 1)
               ]),
               x
           )
       end

julia> model = Problem();

julia> # Make AdvancedMH.jl compatible with MCMCTempering.jl.
       MCMCTempering.getparams_and_logprob(transition::AdvancedMH.Transition) = transition.params, transition.lp

julia> function MCMCTempering.setparams_and_logprob!!(transition::AdvancedMH.Transition, params, logprob)
           return AdvancedMH.Transition(params, logprob)
       end

julia> sampler = RWMH(Distributions.MvNormal(zeros(1), I));
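The two overloads above read the parameters and log-probability out of a transition and produce a transition carrying new values. The `!!` suffix presumably follows the BangBang.jl convention: the function may mutate its argument or return a new object, so callers must always use the return value. The overload above returns a freshly constructed `AdvancedMH.Transition`. The toy sketch below uses a hypothetical immutable `ToyTransition` in place of `AdvancedMH.Transition`, with the functions defined locally rather than as `MCMCTempering` methods, to show why the return value matters.

```julia
# Hypothetical stand-in for an immutable transition type such as `AdvancedMH.Transition`.
struct ToyTransition{T,L<:Real}
    params::T
    lp::L
end

# Read the parameters and log-probability out of a transition.
getparams_and_logprob(t::ToyTransition) = t.params, t.lp

# `!!`: may mutate or return a new object; the struct is immutable, so return a new one.
setparams_and_logprob!!(::ToyTransition, params, logprob) = ToyTransition(params, logprob)

t = ToyTransition([0.5], -1.7)
t_new = setparams_and_logprob!!(t, [1.0], -2.0)
getparams_and_logprob(t)      # ([0.5], -1.7): the original is untouched
getparams_and_logprob(t_new)  # ([1.0], -2.0): always use the returned value
```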

Compositions:

julia> # Compose the two samplers to make a single one.
       sampler_composed = sampler ∘ sampler;

julia> # 1000 iterations of the composed sampler, i.e. 2 × 1000 iterations of `sampler`.
       sample(model, sampler_composed, 1000; chain_type=MCMCChains.Chains, param_names=["x"], progress=false)
Chains MCMC chain (2000×2×1 Array{Float64, 3}):

Iterations        = 1:1:2000
Number of chains  = 1
Samples per chain = 2000
parameters        = x
internals         = lp

Summary Statistics
  parameters      mean       std   naive_se      mcse       ess      rhat 
      Symbol   Float64   Float64    Float64   Float64   Float64   Float64 

           x    3.0078    2.6408     0.0590    0.3406   31.2179    1.0001

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

           x   -1.5962    0.4158    3.9498    5.2639    6.8919

julia> # If we don't want the "intermediate" iterations, i.e. we only want 1000 iterations,
       # we can create the composition as follows:
       sampler_composed = MCMCTempering.CompositionSampler(sampler, sampler, Val(false));

julia> # 1000 iterations of the composed sampler, i.e. 2 × 1000 iterations of `sampler`
       # but only the second step of every iteration is kept.
       sample(model, sampler_composed, 1000; chain_type=MCMCChains.Chains, param_names=["x"], progress=false)
Chains MCMC chain (1000×2×1 Array{Float64, 3}):

Iterations        = 1:1:1000
Number of chains  = 1
Samples per chain = 1000
parameters        = x
internals         = lp

Summary Statistics
  parameters      mean       std   naive_se      mcse       ess      rhat 
      Symbol   Float64   Float64    Float64   Float64   Float64   Float64 

           x    2.7230    2.6231     0.0830    0.3842   15.6804    1.0199

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

           x   -1.4815    0.2105    3.5516    5.0038    6.7092
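The difference between the default composition (2000 rows) and the `Val(false)` variant (1000 rows) can be mimicked with a toy sketch in plain Julia. This is not MCMCTempering's implementation: `toysample` is hypothetical, and its `saveall` keyword is assumed to play the role of the `Val(...)` flag. With `saveall=true` both the intermediate and the final state of each iteration are recorded, giving `2n` entries; with `saveall=false` only the final state is kept, giving `n` entries.

```julia
# Toy sketch: run `n` iterations of `step_outer ∘ step_inner` starting from `x0`.
# `saveall` is a hypothetical stand-in for the `Val(true)`/`Val(false)` flag.
function toysample(step_inner, step_outer, x0, n::Integer; saveall::Bool=true)
    chain = typeof(x0)[]
    x = x0
    for _ in 1:n
        x_inner = step_inner(x)
        # Keep the intermediate state only when `saveall` is true.
        saveall && push!(chain, x_inner)
        x = step_outer(x_inner)
        push!(chain, x)
    end
    return chain
end

toysample(x -> x + 1, x -> 2x, 0, 3)                 # [1, 2, 3, 6, 7, 14]
toysample(x -> x + 1, x -> 2x, 0, 3; saveall=false)  # [2, 6, 14]
```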

The same works for "repeated" samplers (useful to avoid insane compilation times):

julia> sampler_repeated = sampler^2;

julia> # 1000 iterations of the repeated sampler, i.e. 2 × 1000 iterations of `sampler`.
       sample(model, sampler_repeated, 1000; chain_type=MCMCChains.Chains, param_names=["x"], progress=false)
Chains MCMC chain (2000×2×1 Array{Float64, 3}):

Iterations        = 1:1:2000
Number of chains  = 1
Samples per chain = 2000
parameters        = x
internals         = lp

Summary Statistics
  parameters      mean       std   naive_se      mcse       ess      rhat 
      Symbol   Float64   Float64    Float64   Float64   Float64   Float64 

           x    2.5271    2.7248     0.0609    0.3574   27.1560    1.0430

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

           x   -1.8869   -0.0602    3.2003    5.0825    6.4611

julia> # If we don't want the "intermediate" iterations, i.e. we only want 1000 iterations,
       # we can create the repeated sampler as follows:
       sampler_repeated = MCMCTempering.RepeatedSampler(sampler, 2, Val(false));

julia> # 1000 iterations of the repeated sampler, i.e. 2 × 1000 iterations of `sampler`
       # but only the second step of every iteration is kept.
       sample(model, sampler_repeated, 1000; chain_type=MCMCChains.Chains, param_names=["x"], progress=false)
Chains MCMC chain (1000×2×1 Array{Float64, 3}):

Iterations        = 1:1:1000
Number of chains  = 1
Samples per chain = 1000
parameters        = x
internals         = lp

Summary Statistics
  parameters      mean       std   naive_se      mcse       ess      rhat 
      Symbol   Float64   Float64    Float64   Float64   Float64   Float64 

           x    1.9781    2.4649     0.0779    0.3164   41.3850    1.0136

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

           x   -1.6115   -0.1330    1.2264    4.4097    6.2514
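A repeated sampler behaves like `sampler ∘ sampler ∘ ... ∘ sampler`, but it can be stored as a single sampler plus a count and stepped in a loop, so (presumably the point regarding compilation times) its type does not grow with the number of repetitions. The toy sketch below illustrates that idea; `ToyKernel`, `ToyRepeated` and `toystep` are hypothetical names, not MCMCTempering's `RepeatedSampler`, and `saveall` again stands in for the `Val(...)` flag.

```julia
# Hypothetical wrapper around a step function `x -> x_new`.
struct ToyKernel{F}
    f::F
end

# Hypothetical repeated sampler: one kernel plus a repetition count.
struct ToyRepeated{K<:ToyKernel}
    kernel::K
    n::Int
end

# Mimic the `sampler^2` syntax for the toy kernel.
Base.:^(k::ToyKernel, n::Integer) = ToyRepeated(k, Int(n))

function toystep(r::ToyRepeated, x; saveall::Bool=true)
    r.n >= 1 || throw(ArgumentError("the number of repetitions must be at least 1"))
    states = typeof(x)[]
    # Step the same kernel `n` times in a loop instead of nesting compositions.
    for _ in 1:r.n
        x = r.kernel.f(x)
        push!(states, x)
    end
    # Either keep every intermediate state or only the last one.
    return saveall ? states : states[end:end]
end

toystep(ToyKernel(x -> x + 1)^3, 0)                 # [1, 2, 3]
toystep(ToyKernel(x -> x + 1)^3, 0; saveall=false)  # [3]
```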

Products:

julia> using AbstractMCMC

julia> # For models, `×` is only overloaded for `AbstractMCMC.AbstractModel`, hence the `LogDensityModel` wrappers.
       model_product = AbstractMCMC.LogDensityModel(model) × AbstractMCMC.LogDensityModel(model);

julia> sampler_product = sampler × sampler;

julia> # NOTE: There's no default conversion to `MCMCChains.Chains` for products, as it's not quite
       # clear what the default behavior should be. Maybe make multiple chains?
       sample(model_product, sampler_product, 10; progress=false)
10-element Vector{MCMCTempering.MultipleTransitions{Tuple{AdvancedMH.Transition{Vector{Float64}, Float64}, AdvancedMH.Transition{Vector{Float64}, Float64}}}}:
 MCMCTempering.MultipleTransitions{Tuple{AdvancedMH.Transition{Vector{Float64}, Float64}, AdvancedMH.Transition{Vector{Float64}, Float64}}}((AdvancedMH.Transition{Vector{Float64}, Float64}([-0.6498344983944278], -1.8232280067987499), AdvancedMH.Transition{Vector{Float64}, Float64}([0.07207544062668315], -1.6146778048128319)))
 MCMCTempering.MultipleTransitions{Tuple{AdvancedMH.Transition{Vector{Float64}, Float64}, AdvancedMH.Transition{Vector{Float64}, Float64}}}((AdvancedMH.Transition{Vector{Float64}, Float64}([0.6943350457163211], -1.8530163355237135), AdvancedMH.Transition{Vector{Float64}, Float64}([0.07207544062668315], -1.6146778048128319)))
 MCMCTempering.MultipleTransitions{Tuple{AdvancedMH.Transition{Vector{Float64}, Float64}, AdvancedMH.Transition{Vector{Float64}, Float64}}}((AdvancedMH.Transition{Vector{Float64}, Float64}([-0.4856077380136814], -1.7299928226472097), AdvancedMH.Transition{Vector{Float64}, Float64}([0.4746381721097502], -1.7246864188969444)))
 MCMCTempering.MultipleTransitions{Tuple{AdvancedMH.Transition{Vector{Float64}, Float64}, AdvancedMH.Transition{Vector{Float64}, Float64}}}((AdvancedMH.Transition{Vector{Float64}, Float64}([-0.20411329402151013], -1.632915489112412), AdvancedMH.Transition{Vector{Float64}, Float64}([0.4788031149260946], -1.7266710915779613)))
 MCMCTempering.MultipleTransitions{Tuple{AdvancedMH.Transition{Vector{Float64}, Float64}, AdvancedMH.Transition{Vector{Float64}, Float64}}}((AdvancedMH.Transition{Vector{Float64}, Float64}([-0.20411329402151013], -1.632915489112412), AdvancedMH.Transition{Vector{Float64}, Float64}([1.8298153618845627], -3.251746195861665)))
 MCMCTempering.MultipleTransitions{Tuple{AdvancedMH.Transition{Vector{Float64}, Float64}, AdvancedMH.Transition{Vector{Float64}, Float64}}}((AdvancedMH.Transition{Vector{Float64}, Float64}([-0.5678596621712501], -1.7733177938402378), AdvancedMH.Transition{Vector{Float64}, Float64}([3.4942173463128237], -2.738864747731279)))
 MCMCTempering.MultipleTransitions{Tuple{AdvancedMH.Transition{Vector{Float64}, Float64}, AdvancedMH.Transition{Vector{Float64}, Float64}}}((AdvancedMH.Transition{Vector{Float64}, Float64}([-0.5678596621712501], -1.7733177938402378), AdvancedMH.Transition{Vector{Float64}, Float64}([5.229476346004903], -1.638414227364183)))
 MCMCTempering.MultipleTransitions{Tuple{AdvancedMH.Transition{Vector{Float64}, Float64}, AdvancedMH.Transition{Vector{Float64}, Float64}}}((AdvancedMH.Transition{Vector{Float64}, Float64}([0.9966405120515761], -2.108188145179178), AdvancedMH.Transition{Vector{Float64}, Float64}([4.05331874945702], -2.0597648446613412)))
 MCMCTempering.MultipleTransitions{Tuple{AdvancedMH.Transition{Vector{Float64}, Float64}, AdvancedMH.Transition{Vector{Float64}, Float64}}}((AdvancedMH.Transition{Vector{Float64}, Float64}([0.9966405120515761], -2.108188145179178), AdvancedMH.Transition{Vector{Float64}, Float64}([3.2685916217912805], -3.0897694811127474)))
 MCMCTempering.MultipleTransitions{Tuple{AdvancedMH.Transition{Vector{Float64}, Float64}, AdvancedMH.Transition{Vector{Float64}, Float64}}}((AdvancedMH.Transition{Vector{Float64}, Float64}([-0.22495986607560436], -1.637387974327772), AdvancedMH.Transition{Vector{Float64}, Float64}([1.9087787661163962], -3.3830907462385853)))
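The product semantics (component i of the sampler runs on component i of the model, independently of the other components) can be sketched as follows. `toyrwmh_step` is a hypothetical minimal random-walk Metropolis step written for this sketch, not AdvancedMH's implementation, and `toyproduct_step` is likewise hypothetical. The components are stepped one after another here; since they don't interact, they could just as well run in parallel. Because `map` accepts tuples and vectors alike, the same function works for a tuple or a vector of components.

```julia
using Random

# Hypothetical minimal random-walk Metropolis step on a scalar state.
function toyrwmh_step(rng::AbstractRNG, logp, x::Float64; σ::Float64=1.0)
    x_prop = x + σ * randn(rng)
    # Accept with probability min(1, exp(logp(x_prop) - logp(x))).
    return log(rand(rng)) < logp(x_prop) - logp(x) ? x_prop : x
end

# Hypothetical product step: sampler i targets model i; components do not interact.
function toyproduct_step(rng::AbstractRNG, steps, logps, xs)
    length(steps) == length(logps) == length(xs) ||
        throw(DimensionMismatch("need one sampler and one model per component"))
    return map((step, logp, x) -> step(rng, logp, x), steps, logps, xs)
end

rng = Random.MersenneTwister(42)
logp1(x) = -x^2 / 2        # standard normal, up to a constant
logp2(x) = -(x - 5)^2 / 2  # normal centred at 5, up to a constant
xs = (0.0, 5.0)
for _ in 1:10
    global xs = toyproduct_step(rng, (toyrwmh_step, toyrwmh_step), (logp1, logp2), xs)
end
xs  # a tuple with one state per component
```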

And of course, combining all of these:

julia> map(MCMCTempering.getlogprob, sample(model_product, sampler_product^2 ∘ sampler_product, 10; progress=false))
10-element Vector{Tuple{Float64, Float64}}:
 (-1.6789149556358056, -1.845768140349007)
 (-1.6171440527626668, -1.6122547264896363)
 (-1.6498932983020569, -2.7923619109574487)
 (-1.6367526349992172, -2.798480925358765)
 (-2.0658684182513025, -2.335329268700216)
 (-2.027594046344223, -2.6852130757834116)
 (-1.6303365184063048, -1.617287477974672)
 (-1.6197038503748589, -2.5878347745764883)
 (-1.612259287023933, -1.7229513627531263)
 (-1.6143481830830035, -1.6188045235281523)

MultiSampler (and MultiModel) also work with an AbstractArray of components, etc., not just a tuple:

julia> # Build the product model directly from a `Vector` of models instead of using `×`.
       model_product = MCMCTempering.MultiModel([AbstractMCMC.LogDensityModel(model), AbstractMCMC.LogDensityModel(model)]);

julia> map(MCMCTempering.getlogprob, sample(model_product, sampler_product^2 ∘ sampler_product, 10; progress=false))
10-element Vector{Vector{Float64}}:
 [-2.2732873163017735, -3.5787911685418816]
 [-1.6150520210876411, -1.6654707680796632]
 [-2.0076876351390585, -1.8670978312229427]
 [-1.6467969287284003, -2.902551018827204]
 [-1.7624721235117182, -3.033980951301354]
 [-1.7530023875843586, -2.274318735634611]
 [-2.2734062358142517, -1.8661973905468405]
 [-2.4250078261264316, -1.6188624484597762]
 [-1.8370108853303015, -1.64247555341679]
 [-2.3451165022255536, -3.9145022552056123]

Drawbacks

It's worth pointing out that we don't expect users to nest these "meta"-samplers very often; in practice we'll probably see a nesting depth of 2 or 3 at most (e.g. a mixture of a composition of repeated samplers).

Next steps

This PR on its own doesn't get us all the way to simply doing sampler_outer ∘ sampler_swapper, but that should now be straightforward to implement: overload step for CompositionSampler{<:AbstractMCMC.AbstractSampler,<:TemperedSampler} so that it performs a step with sampler_outer followed by a swap step, and finally returns either a SequentialTransition and SequentialState, or just a CompositionState with the transition from sampler_outer (in the case where we don't want to further inspect the swapping statistics, etc.).

I'll do this in another PR, as it will change TemperedSampler quite a bit.
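For reference, the swap step mentioned above is the standard replica-exchange move: with inverse temperatures β_k and β_{k+1} and untempered log-density ℓ, exchanging the states x_k and x_{k+1} is accepted with probability min(1, exp((β_k - β_{k+1}) * (ℓ(x_{k+1}) - ℓ(x_k)))). The toy Julia sketch below performs a local move in every chain and then one swap attempt between a randomly chosen adjacent pair. All names (`toyrwmh_step`, `toyswap_step!`, `toytempered_step!`) are hypothetical; this is not how `TemperedSampler` or the planned `CompositionSampler` overload is implemented, and the choice of a single random adjacent pair per iteration is just one possible swap scheme.

```julia
using Random

# Hypothetical minimal random-walk Metropolis step on a scalar state.
function toyrwmh_step(rng::AbstractRNG, logp, x::Float64; σ::Float64=1.0)
    x_prop = x + σ * randn(rng)
    return log(rand(rng)) < logp(x_prop) - logp(x) ? x_prop : x
end

# Propose exchanging the states of chains `k` and `k + 1`; return whether it was accepted.
function toyswap_step!(rng::AbstractRNG, logπ, βs::AbstractVector{<:Real}, xs::AbstractVector, k::Integer)
    ℓk, ℓk1 = logπ(xs[k]), logπ(xs[k + 1])
    # Log acceptance ratio of the exchange under the joint target ∏ᵢ π(xᵢ)^βᵢ.
    logα = (βs[k] - βs[k + 1]) * (ℓk1 - ℓk)
    accepted = log(rand(rng)) < logα
    if accepted
        xs[k], xs[k + 1] = xs[k + 1], xs[k]
    end
    return accepted
end

# One toy tempered iteration: a local move in every chain, then one swap attempt.
function toytempered_step!(rng::AbstractRNG, logπ, βs, xs)
    for i in eachindex(xs)
        β = βs[i]
        xs[i] = toyrwmh_step(rng, x -> β * logπ(x), xs[i])
    end
    k = rand(rng, 1:(length(xs) - 1))
    return toyswap_step!(rng, logπ, βs, xs, k)
end

rng = Random.MersenneTwister(1)
logπ(x) = -x^2 / 2
βs = [1.0, 0.5, 0.25]
xs = [0.0, 1.0, -1.0]
n_accepted = count(_ -> toytempered_step!(rng, logπ, βs, xs), 1:1000)
```

Returning the acceptance flag from the swap is one simple way to keep the swap statistics that the paragraph above mentions wanting to inspect.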

HarrisonWilde commented 1 year ago

Great stuff! A small comment first of all regarding your write-up above: could you add a bullet for repeated samplers (sampler^2), similar to the ones for product and composed samplers? I think this would add a bit of clarity, as it's a lot of new terms.

HarrisonWilde commented 1 year ago

Just had a proper go through it and tested it out a bit; everything seems to work and looks great. Happy to merge pending the suggestions and comments above.