cscherrer / Soss.jl

Probabilistic programming via source rewriting
https://cscherrer.github.io/Soss.jl/stable/
MIT License

AdvancedMH.jl #91

Open cscherrer opened 4 years ago

cscherrer commented 4 years ago

https://github.com/TuringLang/AdvancedMH.jl

cscherrer commented 4 years ago

This now works:

julia> using Soss

julia> using AdvancedMH

julia> import StatsBase: sample

julia> m = @model n begin
           p ~ Uniform()
           x ~ Bernoulli(p) |> iid(n)
       end;

julia> x = rand(100) .< 0.2;

julia> function sample(dist::Soss.JointDistribution, obs, spl::MetropolisHastings, n; kwargs...)
           logjoint(proposal) = logpdf(dist, merge(proposal,obs))
           return sample(AdvancedMH.DensityModel(logjoint), spl, n; kwargs...)
       end
sample (generic function with 26 methods)

julia> proposal = (p = StaticProposal(Uniform()),)
(p = StaticProposal{Uniform{Float64}}(Uniform{Float64}(a=0.0, b=1.0)),)

julia> post = sample(m(n=100), (x=x,), MetropolisHastings(proposal),100; chain_type=Vector{NamedTuple});

julia> particles(post)
(param_1 = 0.259 ± 0.096, lp = -54.1 ± 10.0)

Questions/concerns:

Any thoughts @DilumAluthge ?
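To make the `logjoint` closure above concrete, here is a minimal sketch in plain Julia with no Soss or AdvancedMH dependency. The log density is hand-written for the Bernoulli model in this thread, and `logjoint_all` and `independence_mh` are hypothetical names, not part of either package. The loop is a bare independence Metropolis-Hastings sampler with a Uniform(0, 1) proposal, which is an assumption about what `StaticProposal(Uniform())` amounts to here.

```julia
using Random

# Hand-written log density for the model in this thread:
# p ~ Uniform(0, 1), x[i] ~ Bernoulli(p). Takes one NamedTuple holding
# both the parameter `p` and the data `x`.
function logjoint_all(nt)
    0 < nt.p < 1 || return -Inf          # outside the Uniform(0, 1) support
    k = count(nt.x)                      # number of successes
    return k * log(nt.p) + (length(nt.x) - k) * log1p(-nt.p)
end

# Fix the observations once; the sampler then only sees the parameters.
obs = (x = rand(MersenneTwister(1), 100) .< 0.2,)
logjoint(proposal) = logjoint_all(merge(proposal, obs))

# Independence Metropolis-Hastings with a Uniform(0, 1) proposal. The proposal
# density is constant, so it cancels in the acceptance ratio.
function independence_mh(rng, logdensity, n)
    current = (p = rand(rng),)
    lp = logdensity(current)
    draws = Vector{typeof(current)}(undef, n)
    for i in 1:n
        candidate = (p = rand(rng),)
        lp_candidate = logdensity(candidate)
        if log(rand(rng)) < lp_candidate - lp
            current, lp = candidate, lp_candidate
        end
        draws[i] = current
    end
    return draws
end

draws = independence_mh(MersenneTwister(2), logjoint, 2_000)
posterior_mean = sum(d.p for d in draws) / length(draws)
```

Because the closure captures `obs`, the sampler only ever proposes values for `p`, which is the same division of labour as in the `sample` method above.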

cpfiffer commented 4 years ago

Here's how you'd enable parallel sampling:

using Pkg; Pkg.activate(".")

using Soss

using AdvancedMH, AbstractMCMC

import StatsBase: sample
import AbstractMCMC: AbstractMCMCParallel
import Random: AbstractRNG, GLOBAL_RNG

m = @model n begin
    p ~ Uniform()
    x ~ Bernoulli(p) |> iid(n)
end;

x = rand(100) .< 0.2;

function sample(
    dist::Soss.JointDistribution, 
    obs, 
    spl::MetropolisHastings, 
    args...;
    kwargs...
)
    return sample(GLOBAL_RNG, dist, obs, spl, args...; kwargs...)
end

function sample(
    rng::AbstractRNG,
    dist::Soss.JointDistribution, 
    obs, 
    spl::MetropolisHastings, 
    pmethod::AbstractMCMCParallel, 
    n,
    n_chains; 
    kwargs...
)
    logjoint(proposal) = logpdf(dist, merge(proposal,obs))
    return sample(rng, AdvancedMH.DensityModel(logjoint), spl, pmethod, n, n_chains; kwargs...)
end

function sample(
    rng::AbstractRNG, 
    dist::Soss.JointDistribution, 
    obs, 
    spl::MetropolisHastings, 
    n; 
    kwargs...
)
    logjoint(proposal) = logpdf(dist, merge(proposal,obs))
    return sample(rng, AdvancedMH.DensityModel(logjoint), spl, n; kwargs...)
end

proposal = (p = StaticProposal(Uniform()),)

# Serial sampling
post1 = sample(m(n=100), (x=x,), MetropolisHastings(proposal),100; chain_type=Vector{NamedTuple});

# Parallel sampling
post2 = sample(m(n=100), (x=x,), MetropolisHastings(proposal), MCMCThreads(), 100, 4; chain_type=Vector{NamedTuple});

# Summary stats
p1 = particles(post1)
p2 = map(particles, post2)

I added an extra method to handle a rng in case users pass in explicit seeds or whatever.
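The rng-forwarding pattern can be tried on its own with a toy function. This is a sketch in plain Julia: `random_walk` and `random_walks` are hypothetical names, and seeding each chain with `seed + i` is an assumption made for illustration, not a description of how AbstractMCMC seeds parallel chains.

```julia
using Random

# Core method: every random draw goes through the caller's RNG.
function random_walk(rng::AbstractRNG, n::Integer)
    return cumsum(randn(rng, n))
end

# Convenience method: forward to the core method with the default RNG,
# mirroring how the first `sample` method above forwards GLOBAL_RNG.
random_walk(n::Integer) = random_walk(Random.default_rng(), n)

# Several chains, one seeded RNG per chain, filled in by whatever threads
# Julia was started with (this also runs with a single thread).
function random_walks(seed::Integer, n::Integer, n_chains::Integer)
    chains = Vector{Vector{Float64}}(undef, n_chains)
    Threads.@threads for i in 1:n_chains
        chains[i] = random_walk(MersenneTwister(seed + i), n)
    end
    return chains
end

a = random_walk(MersenneTwister(42), 5)
b = random_walk(MersenneTwister(42), 5)
a == b   # same seed, same draws
```

Keeping the RNG as the first positional argument means a user who passes an explicit seeded generator gets reproducible draws, while everyone else can ignore it.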

The param_1 thing is a bug because I used an idiotic heuristic to name parameters. I'll fix that upstream. Good catch on that one.

Also, side note -- Soss is really slick. Excellent start up time, and the speed with which you were able to plug AdvancedMH in is pretty remarkable.

cscherrer commented 4 years ago

Wow, that looks great! Thanks for the help with this.

Soss can do a lot mostly because I made such a conscious effort to be able to connect with all the great packages out there. Shoulders of giants, and all that :)

cscherrer commented 4 years ago

Aaaannndd it works! Holy crap that's cool. Thanks @cpfiffer

cpfiffer commented 4 years ago

No problem. I'll let you know when the AdvancedMH fix is out -- it should be fairly quick.

cscherrer commented 4 years ago

Since this needs new dependencies, I think it should be a new SossMH package

cpfiffer commented 4 years ago

Works for me. Do you expect the primary use case is the NamedTuple-style of specifying priors?

cpfiffer commented 4 years ago

While I'm fixing this, the connection to MCMCChains is probably not going to work super well because it requires linearized output and the corresponding parameter names. Do you have a mechanism to handle linearization somewhere?

cscherrer commented 4 years ago

> Do you expect the primary use case is the NamedTuple-style of specifying priors?

I guess you mean proposals? Hmm... I'd think mostly it would either be that or walking in R^n and transforming each step to a named tuple. We'll have the transformation statically. That's nice because we could do things like walk around in a unit interval without accidentally leaving the support. OTOH if the support is smooth we usually want HMC anyway.

> Do you have a mechanism to handle linearization somewhere?

I don't understand, what do you mean by linearization in this context?

cpfiffer commented 4 years ago

> I guess you mean proposals? Hmm... I'd think mostly it would either be that or walking in R^n and transforming each step to a named tuple. We'll have the transformation statically. That's nice because we could do things like walk around in a unit interval without accidentally leaving the support. OTOH if the support is smooth we usually want HMC anyway.

This is what I mean by "linearization", I think -- put all the parameters in a vector. It's nice that you have this because you can also use Emcee for free, I believe, as long as you can define your model function as a mapping from the R^n parameter vector to the log density.
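A small sketch of the two directions discussed here, in plain Julia: walking in R^n and mapping each step to a named tuple, and flattening a named tuple back into a vector with names. All function names (`to_namedtuple`, `linearize`, `on_reals`) are hypothetical, and the logistic map is just one common choice for the unit interval, not necessarily the transformation Soss uses.

```julia
# Map an unconstrained vector in R^1 to the named tuple the model expects,
# together with the log-Jacobian of the map.
logistic(z) = 1 / (1 + exp(-z))

function to_namedtuple(θ::AbstractVector{<:Real})
    p = logistic(θ[1])
    logjac = log(p) + log1p(-p)     # log of d logistic(z)/dz = p * (1 - p)
    return (p = p,), logjac
end

# Inverse direction: flatten a named tuple of scalars into a vector plus names.
linearize(nt::NamedTuple) = (collect(values(nt)), collect(keys(nt)))

# Wrap a log density on named tuples into one on R^n.
function on_reals(logdensity)
    return function (θ)
        nt, logjac = to_namedtuple(θ)
        return logdensity(nt) + logjac
    end
end

flat = on_reals(nt -> 0.0)   # Uniform(0, 1) prior only
flat([0.0])                  # equals log(0.25) up to rounding
```

A sampler that only understands vectors can then be handed the function returned by `on_reals`, and every step stays inside the support by construction.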

cpfiffer commented 4 years ago

Another update: AdvancedMH 0.5.3 fixes the parameter naming issue.

cscherrer commented 3 years ago

@cpfiffer

julia> using AdvancedMH

julia> import StatsBase: sample

julia> m = @model n begin
                  p ~ Uniform()
                  x ~ Bernoulli(p) |> iid(n)
              end;
WARNING: both TransformVariables and MeasureTheory export "∞"; uses of it in module Soss must be qualified

julia> x = rand(100) .< 0.2;

julia> function sample(dist::Soss.ConditionalModel, spl::MetropolisHastings, n; kwargs...)
                  logjoint(proposal) = logdensity(dist, proposal)
                  return sample(AdvancedMH.DensityModel(logjoint), spl, n; kwargs...)
              end
sample (generic function with 25 methods)

julia> proposal = (p = StaticProposal(Uniform()),)
(p = StaticProposal{Uniform{(), Tuple{}}}(Uniform()),)

julia> cm = m(n=100)| (x=x,)
ConditionalModel given
    arguments    (:n,)
    observations (:x,)
@model n begin
        p ~ Uniform()
        x ~ Bernoulli(p) |> iid(n)
    end

julia> logjoint(proposal) = logdensity(cm, proposal)
logjoint (generic function with 1 method)

julia> using Random

julia> rng = Random.GLOBAL_RNG
Random._GLOBAL_RNG()

julia> AdvancedMH.propose(rng, MetropolisHastings(proposal), AdvancedMH.DensityModel(logjoint))
ERROR: MethodError: no method matching propose(::Random._GLOBAL_RNG, ::StaticProposal{Uniform{(), Tuple{}}}, ::DensityModel{typeof(logjoint)})
Closest candidates are:
  propose(::AbstractRNG, ::MALA{var"#s23"} where var"#s23"<:AdvancedMH.Proposal, ::DensityModel, ::AdvancedMH.GradientTransition) at /home/chad/.julia/packages/AdvancedMH/Wi3Ba/src/MALA.jl:29
  propose(::AbstractRNG, ::MALA, ::Any) at /home/chad/.julia/packages/AdvancedMH/Wi3Ba/src/MALA.jl:27
  propose(::AbstractRNG, ::Ensemble, ::DensityModel) at /home/chad/.julia/packages/AdvancedMH/Wi3Ba/src/emcee.jl:44
  ...
Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/AdvancedMH/Wi3Ba/src/mh-core.jl:0 [inlined]
 [2] _propose(rng::Random._GLOBAL_RNG, proposal::NamedTuple{(:p,), Tuple{StaticProposal{Uniform{(), Tuple{}}}}}, model::DensityModel{typeof(logjoint)})
   @ AdvancedMH ~/.julia/packages/AdvancedMH/Wi3Ba/src/mh-core.jl:116
 [3] propose(rng::Random._GLOBAL_RNG, spl::MetropolisHastings{NamedTuple{(:p,), Tuple{StaticProposal{Uniform{(), Tuple{}}}}}}, model::DensityModel{typeof(logjoint)})
   @ AdvancedMH ~/.julia/packages/AdvancedMH/Wi3Ba/src/mh-core.jl:102
 [4] top-level scope
   @ REPL[12]:1

cscherrer commented 3 years ago

Same error as

julia> post = sample(cm, MetropolisHastings(proposal),100; chain_type=Vector{NamedTuple});
ERROR: MethodError: no method matching propose(::Random._GLOBAL_RNG, ::StaticProposal{Uniform{(), Tuple{}}}, ::DensityModel{var"#logjoint#6"{Soss.ConditionalModel{NamedTuple{(:n,), T} where T<:Tuple, TypeEncoding(begin
    p ~ Uniform()
    x ~ Bernoulli(p) |> iid(n)
end), TypeEncoding(Main), NamedTuple{(:n,), Tuple{Int64}}, NamedTuple{(:x,), Tuple{BitVector}}}}})

cscherrer commented 3 years ago

I think this might be because it expects Distributions.jl
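A toy reproduction of that kind of failure, in plain Julia. The two modules are hypothetical stand-ins for Distributions.jl and MeasureTheory.jl, and the sketch assumes the failing method is restricted to one package's type hierarchy, as suspected above.

```julia
# Two stand-in modules that each define their own `Uniform` type. The names
# match, but the types are unrelated. Both modules are hypothetical.
module DistsLike
    abstract type Distribution end
    struct Uniform <: Distribution end
end

module MeasuresLike
    struct Uniform end
end

# A method written against one package's abstract type...
propose(d::DistsLike.Distribution) = rand()

propose(DistsLike.Uniform())                    # works
applicable(propose, MeasuresLike.Uniform())     # false: calling it would throw a MethodError
```

Qualifying the name with the module it comes from is what selects the type that the method was written for.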

cscherrer commented 3 years ago

confirmed, it works if I instead do

proposal = (p = StaticProposal(Dists.Uniform()),)

Full code:

using Soss
using AdvancedMH
import StatsBase: sample

m = @model n begin
           p ~ Uniform()
           x ~ Bernoulli(p) |> iid(n)
       end;

x = rand(100) .< 0.2;

function sample(dist::Soss.ConditionalModel, spl::MetropolisHastings, n; kwargs...)
           logjoint(proposal) = logdensity(dist, proposal)
           return sample(AdvancedMH.DensityModel(logjoint), spl, n; kwargs...)
       end

proposal = (p = StaticProposal(Dists.Uniform()),)

post = sample(m(n=100)| (x=x,), MetropolisHastings(proposal),100; chain_type=Vector{NamedTuple});

cpfiffer commented 3 years ago

Oh, perfect!

cscherrer commented 3 years ago

Now I need to figure out bundle_samples

cscherrer commented 3 years ago

I see a few places it's used. Which would make a good template? https://github.com/TuringLang/AdvancedMH.jl/search?q=bundle

cpfiffer commented 3 years ago

If you want MCMCChains, probably this one:

https://github.com/TuringLang/AdvancedMH.jl/blob/4f45089ed1a3c57c18bd2e56ace60039ba69fc86/src/mcmcchains-connect.jl

cscherrer commented 3 years ago

How would I call sample here so it returns Chains?

cpfiffer commented 3 years ago

This should work:

using MCMCChains

post = sample(m(n=100)| (x=x,), MetropolisHastings(proposal),100; chain_type=Chains)

cscherrer commented 3 years ago

I thought so too :(

julia> post = sample(m(n=100)| (x=x,), MetropolisHastings(proposal),100; chain_type=Chains)
ERROR: MethodError: no method matching Chains(::Vector{Vector{Any}}, ::Vector{Symbol}, ::NamedTuple{(:internals,), Tuple{Vector{Symbol}}})
Closest candidates are:
  Chains(::AbstractArray{var"#s6", 3} where var"#s6"<:Union{Missing, Real}, ::AbstractVector{T} where T, ::Any...; kwargs...) at /home/chad/.julia/packages/MCMCChains/qmFXz/src/chains.jl:18
  Chains(::A, ::L, ::K, ::I) where {T, A<:(AxisArrays.AxisArray{T, 3, D, Ax} where {D, Ax}), L, K<:NamedTuple, I<:NamedTuple} at /home/chad/.julia/packages/MCMCChains/qmFXz/src/MCMCChains.jl:55
  Chains(::AbstractArray{var"#s9", 3} where var"#s9"<:Union{Missing, Real}, ::AbstractVector{Symbol}, ::Any; start, thin, evidence, info) at /home/chad/.julia/packages/MCMCChains/qmFXz/src/chains.jl:28
  ...

cscherrer commented 3 years ago

Hmm, and I think I must need some new methods somewhere so I can use a Soss.ConditionalModel instead of a DensityModel. Otherwise lots of type piracy :)

cpfiffer commented 3 years ago

Oh, I see. What's the type getting spat out by the model? Seems like it's Vector{Any}, which MCMCChains is too restrictive to handle.

cscherrer commented 3 years ago

julia> post = sample(m(n=100)| (x=x,), MetropolisHastings(proposal),100; chain_type=Vector{NamedTuple})
100-element Vector{NamedTuple{(:p, :lp), Tuple{Float64, Float64}}}:
 (p = 0.6702278381081601, lp = -96.04179815494047)
 (p = 0.5609245495559463, lp = -77.1651833115178)
 (p = 0.33464016269646035, lp = -55.175443578669906)
 (p = 0.3337120514484624, lp = -55.1236467535386)
...

cscherrer commented 3 years ago

Or

julia> post = sample(m(n=100)| (x=x,), MetropolisHastings(proposal),100)
100-element Vector{AdvancedMH.Transition{NamedTuple{(:p,), Tuple{Float64}}, Float64}}:
 AdvancedMH.Transition{NamedTuple{(:p,), Tuple{Float64}}, Float64}((p = 0.27990056314440115,), -52.68064989215473)
 AdvancedMH.Transition{NamedTuple{(:p,), Tuple{Float64}}, Float64}((p = 0.27990056314440115,), -52.68064989215473)
 AdvancedMH.Transition{NamedTuple{(:p,), Tuple{Float64}}, Float64}((p = 0.27990056314440115,), -52.68064989215473)
 AdvancedMH.Transition{NamedTuple{(:p,), Tuple{Float64}}, Float64}((p = 0.27990056314440115,), -52.68064989215473)
...

cpfiffer commented 3 years ago

Okay. It seems to be erroneously casting up to Any, odd!

cpfiffer commented 3 years ago

I'll take a look at this later tonight and get back to you.

cpfiffer commented 3 years ago

Which branch do you want me to test this on? ConditionalModel doesn't seem to be defined -- I'm assuming cs-conditional?

cscherrer commented 3 years ago

Currently dev, should be getting that to master soon

cpfiffer commented 3 years ago

Ugh, sorry, I keep running into rando problems trying to start the dev environment up. Here's my current error:

ERROR: LoadError: MethodError: no method matching ^(::Bernoulli{Float64}, ::Int64)
Closest candidates are:
  ^(::Irrational{:ℯ}, ::Integer) at mathconstants.jl:91
  ^(::Irrational{:ℯ}, ::Number) at mathconstants.jl:91
  ^(::SimplePolynomials.SimplePolynomial, ::S) where S<:Integer at /home/cameron/.julia/packages/SimplePolynomials/nePHr/src/arithmetic.jl:69
  ...
Stacktrace:
 [1] (::Soss.var"#33#34"{Int64})(::Bernoulli{Float64}) at /home/cameron/.julia/dev/Soss/src/distributions/iid.jl:3
 [2] |>(::Bernoulli{Float64}, ::Soss.var"#33#34"{Int64}) at ./operators.jl:834
 [3] macro expansion at /home/cameron/.julia/packages/GeneralizedGenerated/hIoV7/src/closure_conv.jl:121 [inlined]
 [4] _logdensity(::Type{TypeEncoding(Main)}, ::Model{NamedTuple{(:n,),T} where T<:Tuple,TypeEncoding(begin
    p ~ Uniform()
    x ~ Bernoulli(p) |> iid(n)
end),TypeEncoding(Main)}, ::NamedTuple{(:n,),Tuple{Int64}}, ::NamedTuple{(:x,),Tuple{BitArray{1}}}, ::NamedTuple{(:p,),Tuple{Float64}}) at /home/cameron/.julia/packages/GeneralizedGenerated/hIoV7/src/closure_conv.jl:121
 [5] logdensity at /home/cameron/.julia/dev/Soss/src/primitives/logdensity.jl:7 [inlined]
 [6] logjoint at /home/cameron/code/misc/soss/sample.jl:18 [inlined]
 [7] logdensity at /home/cameron/.julia/packages/AdvancedMH/Wi3Ba/src/AdvancedMH.jl:52 [inlined]
 [8] AdvancedMH.Transition(::DensityModel{var"#logjoint#6"{Soss.ConditionalModel{NamedTuple{(:n,),T} where T<:Tuple,TypeEncoding(begin
    p ~ Uniform()
    x ~ Bernoulli(p) |> iid(n)
end),TypeEncoding(Main),NamedTuple{(:n,),Tuple{Int64}},NamedTuple{(:x,),Tuple{BitArray{1}}}}}}, ::NamedTuple{(:p,),Tuple{Float64}}) at /home/cameron/.julia/packages/AdvancedMH/Wi3Ba/src/AdvancedMH.jl:49
 [9] propose at /home/cameron/.julia/packages/AdvancedMH/Wi3Ba/src/mh-core.jl:103 [inlined]

The code I'm running:

using Soss
using AdvancedMH
using MeasureTheory
using Distributions
import StatsBase: sample

m = @model n begin
    p ~ Uniform()
    x ~ Bernoulli(p) |> iid(n)
end;

x = rand(100) .< 0.2;

proposal = (p = StaticProposal(Distributions.Uniform()),)

function sample(dist::Soss.ConditionalModel, spl::MetropolisHastings, n; kwargs...)
    logjoint(proposal) = logdensity(dist, proposal)
    return sample(AdvancedMH.DensityModel(logjoint), spl, n; kwargs...)
end

post = sample(m(n=100)| (x=x,), MetropolisHastings(proposal),100; chain_type=Vector{NamedTuple});

cscherrer commented 3 years ago

That.... is very weird. Which MeasureTheory? Try master or dev

cscherrer commented 3 years ago

or main or whatever

cscherrer commented 3 years ago

Sorry for the trouble. I'm spinning too many plates, lose track sometimes

cpfiffer commented 3 years ago

Yup, updating to master for MeasureTheory worked.

> Sorry for the trouble. I'm spinning too many plates, lose track sometimes

No trouble at all! Part of the process if you ask me.

cpfiffer commented 3 years ago

It's a bug on the MCMCChains side. I'll fix it and flag a release in a bit.

cpfiffer commented 3 years ago

For the record or if you want to play with this locally, the bundle_samples function I needed was

function AbstractMCMC.bundle_samples(
    ts::Vector{<:Transition{<:NamedTuple}},
    model::DensityModel,
    sampler::MHSampler,
    state,
    chain_type::Type{Chains};
    param_names=missing,
    kwargs...
)
    # Convert to a Vector{NamedTuple} first
    nts = AbstractMCMC.bundle_samples(ts, model, sampler, state, Vector{NamedTuple}; param_names=param_names, kwargs...)

    # Get all the keys
    all_keys = unique(mapreduce(collect∘keys, vcat, nts))

    # Linearize each draw into a vector ordered by all_keys,
    # using missing for any key a draw does not have.
    trygetproperty(thing, key) = key in keys(thing) ? getproperty(thing, key) : missing
    vals = map(nt -> [trygetproperty(nt, k) for k in all_keys], nts)

    # Check if we received any parameter names.
    if ismissing(param_names)
        param_names = all_keys
    else
        # Generate new array to be thread safe.
        param_names = Symbol.(param_names)
    end

    # Bundle everything up and return a Chains struct.
    return Chains(vals, param_names, (internals = [:lp],))
end

I'll put this into review shortly.
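The linearization step in the function above can be tried in isolation. The sketch below uses plain Julia and made-up draws, and it fills a three-dimensional array whose element type matches the `Chains` constructor signatures listed in the earlier error message (`AbstractArray{<:Union{Missing, Real}, 3}`). It is an illustration of the idea, not the code from the fix.

```julia
# Draws standing in for the sampler output (values are made up).
nts = [(p = 0.67, lp = -96.0), (p = 0.56, lp = -77.2), (p = 0.33, lp = -55.2)]

# Union of all keys, in first-seen order.
all_keys = unique(mapreduce(collect ∘ keys, vcat, nts))

# Look up a key, falling back to `missing` when a draw lacks it.
trygetproperty(nt, key) = haskey(nt, key) ? getproperty(nt, key) : missing

# Fill an iterations × parameters × chains array with a concrete element type,
# rather than collecting into Vector{Vector{Any}}.
vals = Array{Union{Missing, Float64}, 3}(undef, length(nts), length(all_keys), 1)
for (i, nt) in enumerate(nts), (j, k) in enumerate(all_keys)
    vals[i, j, 1] = trygetproperty(nt, k)
end
```

Declaring the element type up front is one way to avoid the widening to `Any` that showed up in the `MethodError`.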

cpfiffer commented 3 years ago

The PR for those interested: https://github.com/TuringLang/AdvancedMH.jl/pull/49

cpfiffer commented 3 years ago

This code should now work after updating to AdvancedMH >= 0.5.7. Tested on master branches for MeasureTheory and Soss.

using Soss
using AdvancedMH
using MeasureTheory
using Distributions
using MCMCChains
import StatsBase: sample

m = @model n begin
    p ~ MeasureTheory.Uniform()
    x ~ MeasureTheory.Bernoulli(p) |> iid(n)
end;

x = rand(100) .< 0.2;

proposal = (p = StaticProposal(Distributions.Uniform()),)

function sample(dist::Soss.ConditionalModel, spl::MetropolisHastings, n; kwargs...)
    logjoint(proposal) = logdensity(dist, proposal)
    return sample(AdvancedMH.DensityModel(logjoint), spl, n; kwargs...)
end

post = sample(m(n=100)| (x=x,), MetropolisHastings(proposal),10_000; chain_type=Chains);

cscherrer commented 3 years ago

Is :lp supposed to be part of value?

julia> post.value
3-dimensional AxisArray{Float64,3,...} with axes:
    :iter, 1:1:10000
    :var, [:p, :lp]
    :chain, 1:1
And data, a 10000×2×1 Array{Float64, 3}:
[:, :, 1] =
 0.747519  -120.275
 0.747519  -120.275
 0.747519  -120.275
 0.151418   -43.9952
 0.151418   -43.9952

cpfiffer commented 3 years ago

Yes, by convention, but it's stored in a separate "section". If you run describe(post) it should only show the stats for p.
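As a rough picture of what a section is, here is a sketch in plain Julia: one array holds every column, and a name map says which columns count as parameters and which as internals. `SectionedDraws`, `section`, and `parameter_means` are hypothetical names for illustration, not the MCMCChains API. The two rows of data are taken from the `post.value` output above.

```julia
# One array holds every column; a name map assigns columns to sections.
# `SectionedDraws` is a hypothetical stand-in, not the MCMCChains.Chains type.
struct SectionedDraws
    value::Matrix{Float64}                     # iterations × variables
    names::Vector{Symbol}
    name_map::NamedTuple{(:parameters, :internals), Tuple{Vector{Symbol}, Vector{Symbol}}}
end

# Select the columns belonging to one section.
function section(d::SectionedDraws, s::Symbol)
    cols = [findfirst(==(n), d.names) for n in getfield(d.name_map, s)]
    return d.value[:, cols]
end

# Column means for the parameters section only, so `lp` is left out.
parameter_means(d::SectionedDraws) =
    vec(sum(section(d, :parameters); dims = 1)) ./ size(d.value, 1)

draws = SectionedDraws(
    [0.747519 -120.275; 0.151418 -43.9952],
    [:p, :lp],
    (parameters = [:p], internals = [:lp]),
)
```

Summaries computed through `section(draws, :parameters)` skip `lp`, even though `lp` sits in the same array as `p`.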

cscherrer commented 3 years ago

It's not showing any:

julia> describe(post)
2-element Vector{ChainDataFrame}:
 Summary Statistics (1 x 7)
 Quantiles (1 x 6)

cpfiffer commented 3 years ago

Oh crap, that's annoying. Try summarystats(post).

cscherrer commented 3 years ago

I'm not sure what you mean by "section", but this seems like it's in with the parameters. Using lp in the model breaks things :(

cscherrer commented 3 years ago

Ah ok summarystats works

cscherrer commented 3 years ago

oops

cpfiffer commented 3 years ago

"section" is the MCMCChains way of saying some values are distinct from others -- by default, everything goes into the parameters section. Internal stuff like lp and maybe HMC diagnostics go into internals, but they are all stored in the same array backend.