cscherrer / Soss.jl

Probabilistic programming via source rewriting
https://cscherrer.github.io/Soss.jl/stable/
MIT License

Implementation of Gaussian mixture model fails when sampling the posterior #296

Open · gzagatti opened this issue 3 years ago

gzagatti commented 3 years ago

I got quite excited about the Soss.jl presentation at JuliaCon 2021, so I decided to give it a go by implementing a Gaussian mixture model. I borrowed the example from Turing.jl for comparison purposes. Please let me know if there is anything that could be improved.

```julia
using Plots
using MeasureTheory
using Soss
using SampleChainsDynamicHMC
import Distributions

# model definition
m = @model N begin
    # parameters
    σ0 ~ Lebesgue(ℝ)
    μ0 ~ Lebesgue(ℝ)
    α ~ Lebesgue(ℝ₊)
    K = 2
    # random variables
    μ ~ For(K) do _
        Normal(μ0, σ0)
    end
    w ~ Distributions.Dirichlet(K, α)
    z ~ For(N) do _
        Distributions.Categorical(w)
    end
    x ~ For(z) do zi
        Distributions.MvNormal([μ[zi], μ[zi]], 1.0)
    end
end

# data generation without any assumptions
generative = predictive(m, :σ0, :μ0, :α)
prior_data = rand(generative(N=30, σ0=1., μ0=0., α=1.))
scatter(
    [xi[1] for xi in prior_data.x],
    [xi[2] for xi in prior_data.x],
    legend = false,
    color = prior_data.z,
    title = "synthetic data"
)
scatter!([prior_data.μ[1]], [prior_data.μ[1]], color=:yellow)
scatter!([prior_data.μ[2]], [prior_data.μ[2]], color=:yellow)

# data generation with assumptions on μ and w
predx = predictive(m, :μ, :w)
data = rand(predx(N=30, μ=[-3.5, 0.0], w=[0.5, 0.5]))
scatter(
    [xi[1] for xi in data.x],
    [xi[2] for xi in data.x],
    legend = false,
    color = data.z,
    title = "synthetic data"
)
scatter!([-3.5], [-3.5], color=:yellow)
scatter!([0.0], [0.0], color=:yellow)

# estimating the posterior
posterior = m(N=30) | (x=data.x,)
l(x) = logdensity(posterior, x)
tr = xform(posterior) # raises a "Not implemented" error
chain = newchain(3, DynamicHMC, l, tr)
sample!(chain, 100)
```

Unfortunately, I was not able to estimate the posterior distribution, as an error is raised when I call `tr = xform(posterior)`. Apparently, `xform` is not defined for the Dirichlet distribution.

As I understand it, not all distributions have been implemented directly in Soss.jl. If it is within my capabilities, I could try to implement it. However, I have no clue what `xform` does, and I have not been able to find much documentation on it.

In any case, thanks for putting the package together.

cscherrer commented 3 years ago

Thanks for letting me know about this. The stack trace from `xform` includes this line:

```
 [2] xform(d::Distributions.Dirichlet{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Float64}, _data::NamedTuple{(), Tuple{}})
   @ Soss ~/git/Soss.jl/src/primitives/xform.jl:79
```

Following that takes you to

```julia
function xform(d, _data::NamedTuple)
    if hasmethod(support, (typeof(d),))
        return asTransform(support(d))
    end

    error("Not implemented:\nxform($d)")
end
```

The problem here is that `Distributions.Dirichlet` has no `support` method, so it falls through and throws the error. So you're right that the fix is to add this method.

`xform` is kind of a legacy name. It was originally going to be `transform`, but that name was already taken in TransformVariables.jl. Since building this, I've realized it should probably have just been called `as`, since it has the same functionality as that function from TransformVariables. Docs on that are here: https://tamaspapp.eu/TransformVariables.jl/dev/#The-as-constructor-and-aggregations
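For intuition, here's a minimal sketch of what these transforms do, using TransformVariables directly (nothing Soss-specific; the values are illustrative):

```julia
using TransformVariables

# a map from unconstrained ℝ² to a named tuple with one unconstrained
# and one strictly positive component; this is exactly what HMC needs
t = as((μ = asℝ, σ = asℝ₊))
x = randn(dimension(t))   # dimension(t) == 2
transform(t, x)           # e.g. (μ = 0.47, σ = 1.83), always with σ > 0
```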

Also, in the current setup, the _data argument is only used when you have nested models. I'll be cleaning up the dispatch patterns for this, but for a quick fix let's just go with it.

Anyway, the missing method is

```julia
Soss.xform(d::Dists.Dirichlet, _data::NamedTuple) = TransformVariables.UnitSimplex(length(d.alpha))
```
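As a quick sanity check of what that transform does (a sketch using TransformVariables directly):

```julia
using TransformVariables

t = TransformVariables.UnitSimplex(2)   # maps ℝ¹ onto the 2-simplex
w = transform(t, randn(dimension(t)))   # dimension(t) == 1
sum(w)                                  # 1.0, and all entries are positive
```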

But this doesn't fix everything, because you still have

```julia
z ~ For(N) do _
    Distributions.Categorical(w)
end
```

This is discrete, so there's no way to set up a bijection to the reals. This is not Soss-specific, you'd have the same issue in Turing or Stan with HMC.
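The usual workaround is to sum `z` out analytically, so the log-density only involves continuous parameters. For a single observation, that looks roughly like this (a sketch, not the Soss API; I'm assuming StatsFuns for a stable `logsumexp`):

```julia
import Distributions
using StatsFuns: logsumexp

# log p(x | w, μ) with the discrete label z marginalized out:
# log Σₖ wₖ Normal(x | μₖ, 1)
function marginal_logpdf(x, w, μ)
    logsumexp([log(w[k]) + Distributions.logpdf(Distributions.Normal(μ[k], 1.0), x)
               for k in eachindex(w)])
end
```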

We'll be adding ways to make this easier in MeasureTheory, but for now in Distributions, say you have

```julia
julia> μ = rand(Normal() |> iid(3))
3-element Vector{Float64}:
  1.0510484386874308
 -0.8007745046155319
  0.48629964893183536
```

Then you can do `using FillArrays, MappedArrays` and then

```julia
julia> paramvec = mappedarray(μ) do μj begin (Fill(μj, 2), 1.) end end
3-element mappedarray(var"#7#8"(), ::Vector{Float64}) with eltype Tuple{Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Float64}:
 (Fill(1.0510484386874308, 2), 1.0)
 (Fill(-0.8007745046155319, 2), 1.0)
 (Fill(0.48629964893183536, 2), 1.0)
```

These are the mixture components, which you can combine like

```julia
julia> Dists.MixtureModel(Dists.MvNormal, paramvec)
MixtureModel{Distributions.MvNormal}(K = 3)
components[1] (prior = 0.3333): Distributions.MvNormal{Float64, PDMats.ScalMat{Float64}, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}(
dim: 2
μ: Fill(1.0510484386874308, 2)
Σ: [1.0 0.0; 0.0 1.0]
)

components[2] (prior = 0.3333): Distributions.MvNormal{Float64, PDMats.ScalMat{Float64}, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}(
dim: 2
μ: Fill(-0.8007745046155319, 2)
Σ: [1.0 0.0; 0.0 1.0]
)

components[3] (prior = 0.3333): Distributions.MvNormal{Float64, PDMats.ScalMat{Float64}, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}(
dim: 2
μ: Fill(0.48629964893183536, 2)
Σ: [1.0 0.0; 0.0 1.0]
)
```

This uses equal weights, but you can change that by adding another parameter.
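For example (the weight vector here is hypothetical; it just has to sum to 1):

```julia
w = [0.2, 0.3, 0.5]  # hypothetical mixture weights, summing to 1
Dists.MixtureModel(Dists.MvNormal, paramvec, w)
```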

Anyway, this works but it's not pretty. We're working on making it much easier to stay in MeasureTheory for all of this.

```julia
using MeasureTheory
using Soss
using SampleChainsDynamicHMC
import Distributions
const Dists = Distributions
using FillArrays
using LinearAlgebra

m = @model N begin
    σ0 ~ Lebesgue(ℝ)
    μ0 ~ Lebesgue(ℝ)
    α ~ Lebesgue(ℝ₊)
    K = 2
    μ ~ Normal(μ0, σ0) |> iid(K)
    w ~ Dists.Dirichlet(K, abs(α))
    xdist = Dists.MixtureModel(Dists.Normal, μ, w)
    x ~ Dists.MatrixReshaped(Dists.Product(Fill(xdist, K*N)), K, N)
end

using TransformVariables
const TV = TransformVariables
Soss.xform(d::Dists.Dirichlet, _data::NamedTuple) = TV.UnitSimplex(length(d.alpha))

prior_data = predict(m(N=30), (N=30, σ0=1., μ0=0., α=1.))

# data generation with assumptions on μ and w
predx = predictive(m, :μ, :w)
data = predict(m(N=30), (μ=[-3.5, 0.0], w=[0.5, 0.5]))

# estimating the posterior
posterior = m(N=30) | (x=data.x,)

sample(posterior, dynamichmc())
```

gzagatti commented 3 years ago

@cscherrer many thanks for the detailed answer.

I did some further exploration on my own and ran into a few issues:

  1. I went through the transform documentation. I do not quite understand the implementation. For instance with `asℝ`, I do not quite get why they are using the exponential distribution.

  2. For some reason, `Fill` inside the proposed model tends to replicate a single draw from the mixture distribution `K*N` times rather than performing `K*N` draws. When I plotted the samples from `prior_data`, I basically got two points. If I switch to base `fill`, the model works as expected. The problem with the fixed model is that it does not ensure that both components are the same for each row.

  3. I attempted a different variation of the model, as follows. Since the `Distributions.jl` package does not have a definition for the `Product` of a multivariate distribution, I implemented a basic version to get the work done:

    struct Product{
            S<:Distributions.ValueSupport,
            T<:Distributions.MultivariateDistribution{S},
            V<:AbstractVector{T},
           } <: Distributions.MultivariateDistribution{S}
        v::V
        function Product(v::V) where
            V<:AbstractVector{T} where
            T<:Distributions.MultivariateDistribution{S} where
            S<:Distributions.ValueSupport
            return new{S, T, V}(v)
        end
    end
    
    Base.length(d::Product) = length(d.v)
    function Base.eltype(::Type{<:Product{S, T}}) where {S<:Distributions.ValueSupport, T<:Distributions.MultivariateDistribution{S}}
        return eltype(T)
    end
    
    function _rand!(rng::Distributions.AbstractRNG, d::Product, x::AbstractVector)
        broadcast!(dn->rand(rng, dn), x, d.v)
    end
    
    function _logpdf(d::Product, x::AbstractVector)
        sum(n -> logpdf(d.v[n], x[n]), 1:length(d))
    end
    
    function Distributions.rand(rng::Distributions.AbstractRNG, s::Product)
        _rand!(rng, s, Vector{Vector{eltype(s)}}(undef, length(s)))
    end

    I then redefined the model as follows:

      components = mappedarray(μ) do μk begin (Fill(μk, 2), 1.) end end
      mixture = Dists.MixtureModel(Dists.MvNormal, components, w)
      x ~ Product(fill(mixture, N))

    Unfortunately, I still have problems with running the chain. It complains that the method `_logpdf` is not implemented, even though I did implement it.

  4. I could not use the function `sample`. Issue #293 was raised with a similar problem.

  5. I am not quite sure what you meant by:

    > This is discrete, so there's no way to set up a bijection to the reals. This is not Soss-specific, you'd have the same issue in Turing or Stan with HMC.

Do you have plans to add additional examples to the docs sometime soon? I know the project is developing quite fast at the moment. Please do let me know if there is a need for help. I am currently learning about these models and it would be good practice to write a few examples.

cscherrer commented 3 years ago

> @cscherrer many thanks for the detailed answer.
>
> I did some further exploration on my own and ran into a few issues:
>
> 1. I went through the transform documentation. I do not quite understand the implementation. For instance with `asℝ`, I do not quite get why they are using the exponential distribution.

Sorry, I don't understand. Can you point me to a line?

> 2. For some reason, `Fill` inside the proposed model tends to replicate a single draw from the mixture distribution `K*N` times rather than performing `K*N` draws. When I plotted the samples from `prior_data`, I basically got two points. If I switch to base `fill`, the model works as expected. The problem with the fixed model is that it does not ensure that both components are the same for each row.

Ah ok, I think the issue here is that there's no way to tell FillArrays that sampling is nondeterministic. Ideally Distributions would account for this, but it seems they don't. I guess this is the point of `filldist` in DistributionsAD. And come to think of it, you probably need that anyway, since it will make gradients much more efficient for Distributions.
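If you want to stay in Distributions for now, something like this should give you genuinely iid draws (a sketch; `filldist` is from DistributionsAD, and I haven't benchmarked it here):

```julia
using Distributions
using DistributionsAD   # provides filldist

d = filldist(Normal(0.0, 1.0), 3)  # product of 3 iid Normals
rand(d)                            # three independent draws, not one value repeated
```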

In MeasureTheory we'll have all of this built in. Currently we don't have any custom AD, but the implementations are also much simpler, so AD should have an easier time of it. We'll be adding more optimized methods as we go.

> 3. I attempted a different variation of the model, as follows. Since the `Distributions.jl` package does not have a definition for the [`Product` of a multivariate](https://github.com/JuliaStats/Distributions.jl/blob/59df675409a7e2490e4a45edd32c0267df435c55/src/multivariate/product.jl), I implemented a basic version to get the work done:
>    ```julia
>    struct Product{
>           S<:Distributions.ValueSupport,
>           T<:Distributions.MultivariateDistribution{S},
>           V<:AbstractVector{T},
>           } <: Distributions.MultivariateDistribution{S}
>        v::V
>        function Product(v::V) where
>           V<:AbstractVector{T} where
>           T<:Distributions.MultivariateDistribution{S} where
>           S<:Distributions.ValueSupport
>           return new{S, T, V}(v)
>        end
>    end
>
>    Base.length(d::Product) = length(d.v)
>    function Base.eltype(::Type{<:Product{S, T}}) where {S<:Distributions.ValueSupport, T<:Distributions.MultivariateDistribution{S}}
>        return eltype(T)
>    end
>
>    function _rand!(rng::Distributions.AbstractRNG, d::Product, x::AbstractVector)
>        broadcast!(dn->rand(rng, dn), x, d.v)
>    end
>
>    function _logpdf(d::Product, x::AbstractVector)
>        sum(n -> logpdf(d.v[n], x[n]), 1:length(d))
>    end
>
>    function Distributions.rand(rng::Distributions.AbstractRNG, s::Product)
>        _rand!(rng, s, Vector{Vector{eltype(s)}}(undef, length(s)))
>    end
>    ```
>
>    I then redefined the model as follows:
>    ```julia
>    components = mappedarray(μ) do μk begin (Fill(μk, 2), 1.) end end
>    mixture = Dists.MixtureModel(Dists.MvNormal, components, w)
>    x ~ Product(fill(mixture, N))
>    ```
>
>    Unfortunately, I still have problems with running the chain. It complains that the method `_logpdf` is not implemented, even though I did implement it.

Thanks for letting me know about this, I'll have a look and see if I can work it out.

In general, I think there are a lot of fundamental problems with Distributions, especially when it comes to PPL. Making this better is a lot of the motivation behind MeasureTheory. It's not yet a full replacement, but most of my energy this year has been directed toward it.

> 4. I could not use the function `sample`. Issue [sample(...) does not work #293](https://github.com/cscherrer/Soss.jl/issues/293) was raised with a similar problem.

Thanks for letting me know about this. What error are you getting? I'll need to be able to reproduce the problem before I can make progress on it.

> 5. I am not quite sure what you meant by:
>    > This is discrete, so there's no way to set up a bijection to the reals. This is not Soss-specific, you'd have the same issue in Turing or Stan with HMC.

Hamiltonian Monte Carlo (HMC) was popularized by the Stan language. It's a great way to do inference, but it only works when the sample space is unconstrained Euclidean space.

The standard way to work around this is to marginalize out the discrete parameters, and set up a bijection between the sample space and ℝⁿ.
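Concretely, for the mixture above, summing out `z` leaves a likelihood in the continuous parameters only:

$$p(x \mid w, \mu) = \sum_{k=1}^{K} w_k \, \mathcal{N}(x \mid \mu_k, \sigma^2)$$

which is smooth in `w` and `μ`, so HMC can handle it once `w` is mapped off the simplex.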

> Do you have plans to add additional examples to the docs sometime soon? I know the project is developing quite fast at the moment. Please do let me know if there is a need for help. I am currently learning about these models and it would be good practice to write a few examples.

This would be great!! Yes, we definitely need documentation, examples, tutorials, etc. The only limitation here is that I'm stretched in a few different directions, so it's hard to get everything done at once.

gzagatti commented 3 years ago

Thanks for the help again.

I have been studying the topic in more detail and have developed a better understanding of `transform`. TransformVariables.jl seems to be similar to Bijectors.jl.
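For instance, both handle the simplex, just in opposite directions; here is a small sketch of the parallel as I understand the two APIs (so treat it as my reading, not gospel):

```julia
using TransformVariables, Bijectors
import Distributions

# TransformVariables goes from unconstrained reals into the constrained space
t = TransformVariables.UnitSimplex(3)
w = transform(t, randn(dimension(t)))   # a point on the 3-simplex

# Bijectors goes the other way, from the constrained space to the reals
b = Bijectors.bijector(Distributions.Dirichlet(3, 1.0))
y = b(w)                                # unconstrained representation of w
```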

I have created a gist with the Gaussian mixture example using the different options I have played around with. The gist is a Pluto notebook, so you should be able to replicate it with the exact environment I am using. The version that is not commented out runs without any issues, except for the last command, which calls `sample`. It complains that this function is not defined.

As I get more familiar with PPLs, I will try to write some examples, add them to the documentation and open a PR with them.