cscherrer / Soss.jl

Probabilistic programming via source rewriting
https://cscherrer.github.io/Soss.jl/stable/
MIT License

Some proposed changes to advancedHMC and dynamicHMC #114

Open · sethaxen opened 4 years ago

sethaxen commented 4 years ago

While reviewing #103, I saw a few things we can improve to standardize advancedHMC and dynamicHMC.

I'll demo with the example in README.md:

julia> using Soss, Random

julia> Random.seed!(3);

julia> m = @model X begin
           β ~ Normal() |> iid(size(X,2))
           y ~ For(eachrow(X)) do x
               Normal(x' * β, 1)
           end
       end;

julia> X = randn(6, 2);

julia> truth = rand(m(X=X));

Now sample the posteriors with DynamicHMC.jl and AdvancedHMC.jl:

julia> post = dynamicHMC(m(X=truth.X), (y=truth.y,));

julia> post2 = advancedHMC(m(X=truth.X), (y=truth.y,));
┌ Warning: `StanHMCAdaptor(n_adapts, pc, ssa)` is deprecated, use `initialize!(StanHMCAdaptor(pc, ssa), n_adapts)` instead.
│   caller = ip:0x0
└ @ Core :-1
Sampling 100%|███████████████████████████████| Time: 0:00:01
  iterations:                    1000
  n_steps:                       1
  is_accept:                     true
  acceptance_rate:               7.284347643885779e-6
  log_density:                   -9.534993040048903
  hamiltonian_energy:            14.668903155484923
  hamiltonian_energy_error:      0.0
  max_hamiltonian_energy_error:  11.829782670180457
  tree_depth:                    1
  numerical_error:               false
  step_size:                     1.5438261928453059
  nom_step_size:                 1.5438261928453059
  is_adapt:                      true
  mass_matrix:                   DiagEuclideanMetric([0.06428332836611608, 0.224 ...])
┌ Info: Finished 1000 sampling steps for 1 chains in 1.917936052 (s)
│   h = Hamiltonian(metric=DiagEuclideanMetric([0.06428332836611608, 0.224 ...]))
│   τ = NUTS{MultinomialTS,Generalised}(integrator=Leapfrog(ϵ=0.853), max_depth=10, Δ_max=1000.0)
│   EBFMI_est = 0.8309789700510506
└   average_acceptance_rate = 0.7907287736038069

Note the deprecation warning for StanHMCAdaptor.
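Per the warning's own suggestion, the fix is mechanical. A minimal sketch, where pc and ssa stand for the preconditioner and step-size adaptor arguments in the current call (as named in the warning itself):

# Deprecated call as currently written:
#     adaptor = StanHMCAdaptor(n_adapts, pc, ssa)
# Replacement per the deprecation message:
adaptor = initialize!(StanHMCAdaptor(pc, ssa), n_adapts)

Now we examine the returned types: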

julia> typeof(post)
Array{NamedTuple{(:β,),Tuple{Array{Float64,1}}},1}

julia> post[1:5]
5-element Array{NamedTuple{(:β,),Tuple{Array{Float64,1}}},1}:
 (β = [0.39734466526796725, 0.9980730158207407],)
 (β = [0.45764947899215946, 0.7167597584561338],)
 (β = [0.44310994230948264, 1.1912646425584188],)
 (β = [0.5804128915563641, 0.9052672328462603],) 
 (β = [0.4911904823400375, 1.3336636471494507],) 

julia> typeof(post2)
Tuple{Array{Array{Float64,1},1},Array{NamedTuple,1}}

julia> post2[1][1:5]
5-element Array{Array{Float64,1},1}:
 [-0.7609622756513167, -0.5346927348411568]
 [-0.7609622756513167, -0.5346927348411568]
 [0.6422579162002842, -0.1618249568942679] 
 [0.5959701333920362, 0.7209161785397256]  
 [0.21076555144098075, 0.9986031838617293] 

julia> post2[2][1:5]
5-element Array{NamedTuple,1}:
 (n_steps = 2, is_accept = true, acceptance_rate = 0.0, log_density = -30.39670440081333, hamiltonian_energy = 30.650753453316053, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 1053.5798626683636, tree_depth = 1, numerical_error = true, step_size = 0.8, nom_step_size = 0.8, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -30.39670440081333, hamiltonian_energy = 30.49119697647495, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 225971.00015050083, tree_depth = 0, numerical_error = true, step_size = 1.8680518327273066, nom_step_size = 1.8680518327273066, is_adapt = true)
 (n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -10.978218893780332, hamiltonian_energy = 27.782397896193544, hamiltonian_energy_error = -2.6945284892033143, max_hamiltonian_energy_error = -2.6945284892033143, tree_depth = 2, numerical_error = false, step_size = 0.18418867766839725, nom_step_size = 0.18418867766839725, is_adapt = true)
 (n_steps = 7, is_accept = true, acceptance_rate = 0.9877831216290067, log_density = -9.260148748631764, hamiltonian_energy = 11.258014826759604, hamiltonian_energy_error = -0.06469301067649447, max_hamiltonian_energy_error = -0.06469301067649447, tree_depth = 3, numerical_error = false, step_size = 0.19183312443074413, nom_step_size = 0.19183312443074413, is_adapt = true)
 (n_steps = 7, is_accept = true, acceptance_rate = 0.8119717633173144, log_density = -10.150000774214176, hamiltonian_energy = 11.121690046268357, hamiltonian_energy_error = 0.20150601028848492, max_hamiltonian_energy_error = 0.2635036075200077, tree_depth = 2, numerical_error = false, step_size = 0.250565415289986, nom_step_size = 0.250565415289986, is_adapt = true)

So dynamicHMC gives us named tuples but no stats, while advancedHMC gives us a tuple whose first element is an array of unnamed parameter vectors and whose second element is the per-sample stats.

A few proposed changes:

  1. For advancedHMC, we should fix the deprecation.
  2. I think we should adopt the same return type for the sampled variables as dynamicHMC returns (i.e. an array of named tuples).
  3. I think returning the stats should be optional. To keep the call type stable, we can use either a different function name or a positional argument (see the sketch after this list).
  4. dynamicHMC could adopt a similar format, also returning all available stats as named tuples when the user requests them.
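Here is a hypothetical sketch of (2) and (3); none of these names are current Soss API. tonamedtuple is a made-up helper that attaches the model's variable names to a raw draw, and the Val argument keeps the return type inferable:

# Hypothetical: wraps the current advancedHMC, which returns (samples, stats).
# `tonamedtuple` is an illustrative helper, not an existing function.
function advancedHMC2(m, data, ::Val{return_stats} = Val(false)) where {return_stats}
    samples, stats = advancedHMC(m, data)
    post = [tonamedtuple(m, θ) for θ in samples]  # e.g. (β = [...],) per draw
    return return_stats ? (post, stats) : post
end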
cscherrer commented 4 years ago

Yeah, this is really tricky. Long-term, I think most MCMC should be iterable. I don't really understand the ArviZ setup yet (Colin Carroll pointed to them on Twitter), but something like this seems sensible to me:

result = (
    samples = ...,
    diagnostics = (
        per_chain = ...,
        per_sample = ...
    )
)

Here samples and per_sample are iterators, while per_chain is "global" (not tied to any particular sample). I'm assuming diagnostics will change (maybe a lot) depending on the sampler.
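For concreteness, a made-up instantiation of that shape (all values invented; in practice samples and per_sample would come from the sampler):

# Illustrative only: lazy generators stand in for sampler iterators.
result = (
    samples = ((β = randn(2),) for _ in 1:1000),
    diagnostics = (
        per_chain = (ebfmi = 0.83,),  # one value per chain
        per_sample = ((step_size = 0.25, tree_depth = 3) for _ in 1:1000)
    )
)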

The problem with this is getting other libraries into this form.

sethaxen commented 4 years ago

> Yeah, this is really tricky. Long-term, I think most MCMC should be iterable.

Yes, I don't currently see much of an issue with this in Julia. One can always wrap an iterator with a method that allocates an array and fills it by taking a fixed number of draws from the iterator.
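A minimal sketch of that wrapper, assuming the sampler is an iterator that yields one draw per iteration:

using Base.Iterators: take

# Materialize n draws from a (possibly infinite) sampler iterator.
collect_samples(itr, n) = collect(take(itr, n))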

> I don't really understand the ArviZ setup yet (Colin Carroll pointed to them on Twitter), but something like this seems sensible to me:

That looks sensible to me. diagnostics should probably be something more general; ArviZ uses sample_stats, which I like for its generality. In most cases, though, per_sample and samples should be synchronized. How would that work here?
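For instance, if the two really are synchronized, a consumer could pair them up like this (a sketch against the structure proposed above):

for (draw, stat) in zip(result.samples, result.diagnostics.per_sample)
    # draw is a NamedTuple of parameter values, e.g. (β = [...],)
    # stat is a NamedTuple of per-sample statistics
end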

How would warm-up (e.g. for HMC) fit into this scheme? Would samples only be returned after warm-up?

What format would samples be? Named tuples? Arrays? Or something else?