cscherrer / Soss.jl

Probabilistic programming via source rewriting
https://cscherrer.github.io/Soss.jl/stable/
MIT License

Marginalizing discrete parameters #107

Open sethaxen opened 4 years ago

sethaxen commented 4 years ago

It might be useful to have a model transformation, marginalize_discrete, that marginalizes out discrete parameters and control flow where possible, turning a model with latent discrete parameters into one with only continuous parameters. Such models can typically be sampled more efficiently, and they can also be sampled with HMC. Some examples:

julia> mcat = @model λ begin
           N = length(λ)
           μ ~ Normal(0, 1) |> iid(N)
           σ ~ HalfNormal(1) |> iid(N)
           n ~ Categorical(λ)
           y ~ Normal(μ[n], σ[n])
       end;

julia> mmix = marginalize_discrete(mcat)
@model λ begin
        N = length(λ)
        σ ~ HalfNormal(1) |> iid(N)
        μ ~ Normal(0, 1) |> iid(N)
        y ~ Mixture([Normal(μ[n], σ[n]) for n = eachindex(μ, σ)], λ)
    end
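The Mixture line in mmix stands for the log-density log p(y | μ, σ, λ) = log Σₙ λₙ Normal(y | μₙ, σₙ), which is best evaluated with log-sum-exp. A minimal plain-Julia sketch under that reading; mixture_logpdf and its helpers are hypothetical names (not Soss or Distributions API), and the Normal log-density is written out by hand so no packages are needed:

```julia
# Log-density of Normal(μ, σ) at y, written out by hand (no packages needed).
normal_logpdf(y, μ, σ) = -log(σ) - 0.5 * log(2π) - 0.5 * ((y - μ) / σ)^2

# Numerically stable log(sum(exp.(xs))).
function logsumexp(xs)
    m = maximum(xs)
    return m + log(sum(exp.(xs .- m)))
end

# log p(y | μ, σ, λ) with n ~ Categorical(λ) summed out, i.e. the density
# that the Mixture line in mmix is meant to represent.
function mixture_logpdf(y, μ, σ, λ)
    terms = [log(λ[n]) + normal_logpdf(y, μ[n], σ[n]) for n in eachindex(λ, μ, σ)]
    return logsumexp(terms)
end

λ = [0.2, 0.5, 0.3]
μ = [-1.0, 0.0, 2.0]
σ = [0.5, 1.0, 2.0]
lp = mixture_logpdf(0.3, μ, σ, λ)
```

Because n no longer appears, lp is a smooth function of μ and σ, which is what makes the marginalized model usable with HMC.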

julia> mbern = @model p begin
           n ~ Bernoulli(p)
           y ~ if n == 0
               Normal(0, 1)
           else
               Cauchy(0, 1)
           end
       end;

julia> mmix2 = marginalize_discrete(mbern)
@model p begin
        y ~ Mixture([Normal(0, 1), Cauchy(0, 1)], [1 - p, p])
    end

julia> mif = @model begin
           x ~ Normal(0, 1)
           y ~ if x > 2
               Normal(0, 1)
           else
               Cauchy(0, 1)
           end
       end;

julia> mmix3 = marginalize_discrete(mif)
@model begin
        q = 1 - cdf(Normal(0, 1), 2)  # P(x > 2)
        y ~ Mixture([Normal(0, 1), Cauchy(0, 1)], [q, 1 - q])
    end
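The mixing weight in mmix3 is the prior probability of taking the x > 2 branch, so, writing Φ for the standard normal CDF, the marginal density of y works out to:

$$
q = P(x > 2) = 1 - \Phi(2) \approx 0.0228, \qquad
p(y) = q\,\mathcal{N}(y \mid 0, 1) + (1 - q)\,\mathrm{Cauchy}(y \mid 0, 1).
$$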

These examples aren't super amazing; a user could transform these models by hand without much effort. However, if we could work out a simple set of rules for marginalizing more complicated classes of models with discrete parameters and control flow, then we could automate things like the marginalization of the change-point model in the Stan User's Guide. In cases like that, though, the result is not simply a reparameterization of the original model: that transformation rewrites the model in terms of accumulated log probabilities.
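For reference, a rough plain-Julia sketch of what "accumulating log probabilities" means for the Stan User's Guide change-point model: Poisson counts D[t] with rate e before the change point s and rate l from s onward, with s uniform on 1:T. All function names are hypothetical and the Poisson log-pmf is written out by hand. This double loop is O(T²); a cumulative-sum trick can make it linear in T.

```julia
# Hand-written helpers so the sketch runs without packages.
logfactorial(k::Integer) = k == 0 ? 0.0 : sum(log, 1:k)
poisson_logpmf(k::Integer, rate) = k * log(rate) - rate - logfactorial(k)

function logsumexp(xs)
    m = maximum(xs)
    return m + log(sum(exp.(xs .- m)))
end

# lp[s] = log p(s) + log p(D | s, e, l) for each candidate change point s,
# with a uniform prior on s in 1:T, rate e for t < s and rate l for t >= s.
function changepoint_lp(D, e, l)
    T = length(D)
    lp = fill(-log(T), T)
    for s in 1:T, t in 1:T
        lp[s] += poisson_logpmf(D[t], t < s ? e : l)
    end
    return lp
end

# log p(D | e, l): the discrete s has been summed out, so this is a smooth
# function of the continuous rates e and l.
changepoint_loglik(D, e, l) = logsumexp(changepoint_lp(D, e, l))

# Exact posterior over s given the rates, normalized with the same log-sum-exp.
function changepoint_posterior(D, e, l)
    lp = changepoint_lp(D, e, l)
    return exp.(lp .- logsumexp(lp))
end
```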

sethaxen commented 3 years ago

SlicStan does this for Stan, and I imagine it may not be that hard for Soss to support automatic marginalization for the same classes of models. See https://arxiv.org/pdf/2010.11887.pdf.

sethaxen commented 3 years ago

I suspect many users of PPLs are unaware that if they marginalize out their discrete parameters z to sample their continuous parameters θ, they can still recover draws of z as though they had sampled it with MCMC, using p(z | θ, y) = p(y | z, θ) p(z | θ) / p(y | θ). Given the user's original joint model and the marginalized model, each of these terms is known, so it may not be too difficult to define a function that augments each draw from the marginalized posterior with exact draws of the marginalized discrete parameters.
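Applied to the mcat example above, with z = n and θ = (μ, σ), every term is available in closed form; p(n | θ) = λₙ does not depend on θ, and the denominator is the mixture density in mmix:

$$
p(n = k \mid \mu, \sigma, y) = \frac{\lambda_k\,\mathcal{N}(y \mid \mu_k, \sigma_k)}{\sum_{j=1}^{N} \lambda_j\,\mathcal{N}(y \mid \mu_j, \sigma_j)}, \qquad k = 1, \dots, N.
$$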

sethaxen commented 3 years ago

Something like this procedure would, I think, do it. I believe it only requires that z have finite support; then the marginals always take the form of a mixture model, and the conditional always takes the form of a categorical distribution.

# 0. user-provided joint model
π(z,θ₁,θ₂,θ₃,y)
    θ₁ ~ π(θ₁)
    z ~ π(z|θ₁)
    θ₂ ~ π(θ₂|θ₁)
    θ₃ ~ π(θ₃|z,θ₁)
    y ~ π(y|z,θ₁,θ₂,θ₃)

# 1. extract prior of z
    π(z|θ₁)

# 2. generate code for the entire Markov blanket of z, conditional on z (i.e. discarding z's own dependencies)
    π(y,θ₂,θ₃|z,θ₁)

# 3. generate the mixture model of 2. for all z in the (finite) support of π(z|θ₁)
π(y,θ₂,θ₃|θ₁)
    (y,θ₂,θ₃) ~ Mixture(
        [π(y,θ₂,θ₃|zᵢ,θ₁) for zᵢ in support(π(z|θ₁))],
        [π(zᵢ|θ₁) for zᵢ in support(π(z|θ₁))],
    )

# 4. automatically marginalized model: replace the dependents of z in its Markov blanket with their generated mixture model
π(θ₁,θ₂,θ₃,y)
    θ₁ ~ π(θ₁)
    (y,θ₂,θ₃) ~ π(y,θ₂,θ₃|θ₁)

# 5. condition π(θ₁,θ₂,θ₃,y) on y to draw from π(θ₁,θ₂,θ₃|y) via MCMC

# 6. Given π(z|θ₁) from 1., π(y,θ₂,θ₃|z,θ₁) from 2., and π(y,θ₂,θ₃|θ₁) from 3., draw exact samples of z
π(z|θ₁,θ₂,θ₃,y)
    i ~ Categorical([π(zᵢ|θ₁)π(y,θ₂,θ₃|zᵢ,θ₁)/π(y,θ₂,θ₃|θ₁) for zᵢ in support(π(z|θ₁))])
    z = support(π(z|θ₁))[i]

# 7. merge draws from π(θ₁,θ₂,θ₃|y) and π(z|θ₁,θ₂,θ₃,y) to get draws from joint π(z,θ₁,θ₂,θ₃|y)
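Steps 3 and 6 reduce to a log-sum-exp and a normalized categorical over the finite support of z. A minimal sketch in plain Julia, assuming the user's model can supply the two log-densities as closures over the current θ and y; marginal_logdensity, conditional_probs, draw_z and the mcat_* functions are hypothetical names, not Soss API. The usage part treats the mcat example with one observation y and fixed values of (μ, σ) standing in for a posterior draw:

```julia
using Random

function logsumexp(xs)
    m = maximum(xs)
    return m + log(sum(exp.(xs .- m)))
end

# Hypothetical helpers, not Soss API. `logprior(z)` stands for log π(z|θ₁) and
# `logrest(z)` for log π(y,θ₂,θ₃|z,θ₁), both evaluated at the current θ and y.

# Step 3: log π(y,θ₂,θ₃|θ₁), summing the mixture over the finite support of z.
marginal_logdensity(logprior, logrest, support) =
    logsumexp([logprior(z) + logrest(z) for z in support])

# Step 6: π(z|θ₁,θ₂,θ₃,y) as a probability vector over `support`; subtracting
# the logsumexp is the same as dividing by the step-3 marginal.
function conditional_probs(logprior, logrest, support)
    lw = [logprior(z) + logrest(z) for z in support]
    return exp.(lw .- logsumexp(lw))
end

# Draw one exact sample of z by inverting the categorical CDF.
function draw_z(rng::AbstractRNG, logprior, logrest, support)
    p = conditional_probs(logprior, logrest, support)
    u = rand(rng)
    i = findfirst(>(u), cumsum(p))
    return collect(support)[something(i, length(p))]  # fallback guards round-off
end

# Usage: the mcat model with n ∈ 1:3 marginalized out.
normal_logpdf(y, μ, σ) = -log(σ) - 0.5 * log(2π) - 0.5 * ((y - μ) / σ)^2
λ, μ, σ, y = [0.2, 0.5, 0.3], [-1.0, 0.0, 2.0], [0.5, 1.0, 2.0], 1.5
mcat_logprior(n) = log(λ[n])
mcat_logrest(n) = normal_logpdf(y, μ[n], σ[n])

rng = MersenneTwister(1)
n_draw = draw_z(rng, mcat_logprior, mcat_logrest, 1:3)
```

Calling draw_z once per posterior draw of (μ, σ) and pairing the results with those draws corresponds to step 7.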
cscherrer commented 3 years ago

Thanks @sethaxen, the details here are really helpful. I guess to this point the marginalization would be an explicit sum? In some cases the algebraic form of this can be simplified, but maybe that's something to worry about later, or potentially to do symbolically. OTOH it could be worth trying to set things up in a way that makes it easy to add methods for marginalization strategies as we find them. It may also be worth anticipating probabilistic circuits / sum-product networks, which can often be marginalized efficiently.

Maybe we should start with a small set of concrete examples. This weekend I'll mostly be working on MeasureTheory to prep for Kusti's workshop on Monday. Adding tests, fixing up MvNormal, etc. But let's get back to this soon :)

cscherrer commented 3 years ago

Oh right, you already gave some concrete examples :)

sethaxen commented 3 years ago

> Thanks @sethaxen, the details here are really helpful.

FWIW, the details are mine. I haven't fully processed the SlicStan paper yet; it's possible they do something...slicker.

> I guess to this point the marginalization would be an explicit sum? In some cases the algebraic form of this can be simplified, but maybe that's something to worry about later, or potentially to do symbolically. OTOH it could be worth trying to set things up in a way that makes it easy to add methods for marginalization strategies as we find them.

Yes, I had the same thought: if z is continuous, the result is a continuous mixture, which in some cases is known in closed form. The challenge here would probably be defining and applying rules for marginalization (and for the inverse we do above to draw samples of z, whatever that should be called).
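As an illustration of the continuous case, the textbook closed-form example is a normal with a gamma-distributed precision (gamma in the shape-rate parameterization), which marginalizes to a Student-t:

$$
\tau \sim \mathrm{Gamma}\!\left(\tfrac{\nu}{2}, \tfrac{\nu}{2}\right), \quad
y \mid \tau \sim \mathcal{N}\!\left(\mu, \tau^{-1}\right)
\;\Longrightarrow\;
p(y) = \int_0^\infty \mathcal{N}\!\left(y \mid \mu, \tau^{-1}\right) \mathrm{Gamma}\!\left(\tau \mid \tfrac{\nu}{2}, \tfrac{\nu}{2}\right) d\tau = t_\nu(y \mid \mu, 1).
$$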

What I do like about the discrete case given here is that the entire family of discrete distributions with finite support can be handled with a single approach, and it has concrete benefits: by the Rao-Blackwell theorem, the resulting posterior estimates are at least as good (and potentially much better). Plus, one can use efficient, adaptive gradient-based samplers, so this can only make things better. I also suspect we can simplify with fewer restrictions in Soss than in other PPLs, because models are first-class.
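To make the Rao-Blackwell point concrete: given posterior draws θ^(1), ..., θ^(S) from the marginalized model, a posterior probability for z can be estimated by averaging the exact conditional probabilities from step 6 rather than indicators of sampled z^(s). Each term p(z = k | θ^(s), y) is the conditional expectation of the indicator given θ^(s), so for independent draws its variance is no larger:

$$
\hat{P}(z = k \mid y) = \frac{1}{S} \sum_{s=1}^{S} p\left(z = k \mid \theta^{(s)}, y\right)
\quad \text{instead of} \quad
\frac{1}{S} \sum_{s=1}^{S} \mathbf{1}\left[z^{(s)} = k\right].
$$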

> Maybe we should start with a small set of concrete examples. This weekend I'll mostly be working on MeasureTheory to prep for Kusti's workshop on Monday. Adding tests, fixing up MvNormal, etc. But let's get back to this soon :)

Sounds good!