TuringLang / AbstractMCMC.jl

Abstract types and interfaces for Markov chain Monte Carlo methods
https://turinglang.org/AbstractMCMC.jl
MIT License

Support `init_params` in ensemble methods #94

Closed: devmotion closed this pull request 2 years ago.

devmotion commented 2 years ago

Implements the suggestion in https://github.com/TuringLang/AbstractMCMC.jl/issues/92#issuecomment-1016745828.

With this PR, samplers that support `init_params` for single-chain sampling (such as AdvancedMH, EllipticalSliceSampling, or Turing) get support for `init_params` with ensemble methods for free. For example, with this PR and the latest release of EllipticalSliceSampling (no changes needed):

julia> using EllipticalSliceSampling, Distributions

julia> prior = Normal(0, 1);

julia> loglik(x) = logpdf(Normal(x, 0.5), 1.0);

julia> sample(ESSModel(prior, loglik), ESS(), 3; progress=false, init_params=0.5) # `init_params` supported for single chain sampling
3-element Vector{Float64}:
 0.5
 0.5569861002529894
 0.5590132914774115

julia> sample(ESSModel(prior, loglik), ESS(), MCMCSerial(), 3, 3; progress=false, init_params=[0.5, 0.4, 0.2])
3-element Vector{Vector{Float64}}:
 [0.5, 1.2604650145668184, 0.5342938313761859]
 [0.4, 0.6766414322931542, 0.6669188733609087]
 [0.2, 0.1991658951820099, 1.620050207287682]

julia> sample(ESSModel(prior, loglik), ESS(), MCMCSerial(), 3, 3; progress=false, init_params=Iterators.repeated(0.0))
3-element Vector{Vector{Float64}}:
 [0.0, -0.23546962848230765, 0.28173464269157983]
 [0.0, 0.8450041861878778, 1.2969836926133524]
 [0.0, 0.2803330725356131, 1.5178894526699893]

julia> sample(ESSModel(prior, loglik), ESS(), MCMCThreads(), 3, 3; progress=false, init_params=[0.5, 0.4, 0.2])
3-element Vector{Vector{Float64}}:
 [0.5, 0.683876765057849, 0.6000166327171864]
 [0.4, 0.18201765370396472, 0.2882589877099947]
 [0.2, 0.9548385848899827, 0.08087007696252635]

julia> sample(ESSModel(prior, loglik), ESS(), MCMCThreads(), 3, 3; progress=false, init_params=Iterators.repeated(0.0))
3-element Vector{Vector{Float64}}:
 [0.0, 0.2830832021306832, -0.11155114222247048]
 [0.0, 0.22734708270317408, 0.908700798723544]
 [0.0, 0.5444555918079804, -0.07055391114874537]

julia> sample(ESSModel(prior, loglik), ESS(), MCMCDistributed(), 3, 3; progress=false, init_params=[0.5, 0.4, 0.2])
3-element Vector{Vector{Float64}}:
 [0.5, 1.3010083566517667, 1.2849239490378277]
 [0.4, 0.5001641558540113, 0.4983091461364374]
 [0.2, 0.3023383923085534, 0.9800014107135103]

julia> sample(ESSModel(prior, loglik), ESS(), MCMCDistributed(), 3, 3; progress=false, init_params=Iterators.repeated(0.0))
3-element Vector{Vector{Float64}}:
 [0.0, -0.01643158241658191, 0.20824023548102533]
 [0.0, 0.46749889505515185, 0.3357075197741447]
 [0.0, 0.6078789038880229, 0.9077006973215425]

Samplers that do not support init_params are not affected by this PR.
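The forwarding mechanism described above can be sketched in a few lines of Julia. This is a hypothetical illustration, not the code this PR adds to `src/sample.jl`: the random-walk `sample_chain` stands in for a sampler's single-chain `sample` method, and `sample_chain`, `per_chain_inits`, `sample_ensemble_serial` and `sample_ensemble_threads` are made-up names. The sketch mirrors the behaviour shown in the transcript. A vector supplies one initial value per chain. An infinite iterator such as `Iterators.repeated(0.0)` is truncated to the number of chains. When `init_params` is omitted, the keyword is not forwarded at all.

```julia
using Random

# Hypothetical sketch, not AbstractMCMC's implementation: it only illustrates
# how an ensemble method can reuse a sampler's single-chain `init_params`
# support. All function names below are made up for illustration.

# Toy single-chain sampler: a Gaussian random walk of length `n` that starts
# at `init_params` (or at 0.0 when no initial value is given).
function sample_chain(rng::AbstractRNG, n::Integer; init_params=nothing)
    chain = Vector{Float64}(undef, n)
    chain[1] = init_params === nothing ? 0.0 : init_params
    for i in 2:n
        chain[i] = chain[i - 1] + 0.1 * randn(rng)
    end
    return chain
end

# Materialize exactly one initial value per chain. `Iterators.take` makes this
# work for vectors as well as for infinite iterators such as
# `Iterators.repeated(0.0)`.
function per_chain_inits(init_params, nchains::Integer)
    init_params === nothing && return nothing
    inits = collect(Iterators.take(init_params, nchains))
    length(inits) == nchains ||
        throw(ArgumentError("`init_params` must provide one value per chain"))
    return inits
end

# Serial ensemble: the keyword is forwarded only if the user supplied it, so a
# sampler without `init_params` support receives the same calls as before.
function sample_ensemble_serial(rng::AbstractRNG, n::Integer, nchains::Integer;
                                init_params=nothing)
    inits = per_chain_inits(init_params, nchains)
    return map(1:nchains) do i
        if inits === nothing
            sample_chain(rng, n)
        else
            sample_chain(rng, n; init_params=inits[i])
        end
    end
end

# Threaded ensemble: initial values are collected before the loop starts,
# and each chain gets its own RNG seeded from `rng`.
function sample_ensemble_threads(rng::AbstractRNG, n::Integer, nchains::Integer;
                                 init_params=nothing)
    inits = per_chain_inits(init_params, nchains)
    seeds = rand(rng, UInt, nchains)
    chains = Vector{Vector{Float64}}(undef, nchains)
    Threads.@threads for i in 1:nchains
        chain_rng = MersenneTwister(seeds[i])
        if inits === nothing
            chains[i] = sample_chain(chain_rng, n)
        else
            chains[i] = sample_chain(chain_rng, n; init_params=inits[i])
        end
    end
    return chains
end
```

Collecting the initial values up front, rather than pulling from the iterator inside each task, keeps the threaded loop free of shared iterator state and reports a too-short `init_params` once, before any chain starts.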

codecov[bot] commented 2 years ago

Codecov Report

Merging #94 (019bb59) into master (4994a79) will decrease coverage by 0.25%. The diff coverage is 95.00%.


@@            Coverage Diff             @@
##           master      #94      +/-   ##
==========================================
- Coverage   97.74%   97.48%   -0.26%     
==========================================
  Files           7        7              
  Lines         222      239      +17     
==========================================
+ Hits          217      233      +16     
- Misses          5        6       +1     

Impacted Files          Coverage Δ
src/AbstractMCMC.jl     100.00% <ø> (ø)
src/sample.jl           97.66% <95.00%> (-0.40%) ↓


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data.
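The two slightly different coverage deltas in the report (0.25% in the summary sentence, -0.26% in the diff table) can be reconciled from the line counts, assuming Codecov's default of rounding coverage down to two decimal places. This repository's Codecov configuration is not shown here, so that rounding mode is an assumption:

```latex
\frac{217}{222} \approx 97.748\,\%, \qquad
\frac{233}{239} \approx 97.490\,\%, \qquad
97.748\,\% - 97.490\,\% \approx 0.258 \text{ percentage points}
```

Rounding each figure down to two decimals gives 97.74%, 97.48% and a delta of 0.25 points, whereas subtracting the already-rounded coverages gives 97.74 - 97.48 = 0.26, the value shown in the diff table.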

devmotion commented 2 years ago

The MCMCChains test errors are unrelated to this PR and already known.

devmotion commented 2 years ago

https://github.com/TuringLang/AbstractMCMC.jl/pull/99 should be merged first since it is non-breaking.