Open gzagatti opened 3 years ago
Thanks for letting me know about this. The stack trace from `xform` includes this line:
[2] xform(d::Distributions.Dirichlet{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Float64}, _data::NamedTuple{(), Tuple{}})
@ Soss ~/git/Soss.jl/src/primitives/xform.jl:79
Following that takes you to
function xform(d, _data::NamedTuple)
    if hasmethod(support, (typeof(d),))
        return asTransform(support(d))
    end
    error("Not implemented:\nxform($d)")
end
The problem here is that `Distributions.Dirichlet` has no `support` method, so it falls through and throws the error. So you're right that the fix is to add this method.
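As a rough illustration of that fallback pattern, here is a self-contained sketch in plain Julia. The names `support_of`, `HasSupport`, `NoSupport`, and `xform_sketch` are made up for this example and are not part of Soss or Distributions; the returned tuple only stands in for a real transform.

```julia
# Hypothetical stand-ins for two distribution types.
struct HasSupport end   # has a support method
struct NoSupport end    # has no support method, like Dirichlet in this issue

# Only one of the two types gets a method.
support_of(::HasSupport) = (0.0, Inf)

# Same shape as the xform fallback: use the support if a method exists,
# otherwise fall through to an error.
function xform_sketch(d)
    if hasmethod(support_of, (typeof(d),))
        return (:transform_for, support_of(d))
    end
    error("Not implemented: xform_sketch($(typeof(d)))")
end

println(xform_sketch(HasSupport()))
```

Calling `xform_sketch(NoSupport())` raises the "Not implemented" error, which mirrors what happens with the Dirichlet here. Adding a method for the missing type is what makes the error go away.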
`xform` is kind of a legacy name. It was originally going to be `transform`, but that name was already taken in TransformVariables.jl. Since this was built, I've realized it should have just been called `as`, since it has the same functionality as that function from TransformVariables. Docs on that are here:
https://tamaspapp.eu/TransformVariables.jl/dev/#The-as-constructor-and-aggregations
Also, in the current setup, the `_data` argument is only used when you have nested models. I'll be cleaning up the dispatch patterns for this, but for a quick fix let's just go with it.
Anyway, the missing method is
Soss.xform(d::Dists.Dirichlet, _data::NamedTuple) = TransformVariables.UnitSimplex(length(d.alpha))
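To give a feel for what a unit-simplex transform does, here is a minimal sketch of one common construction (stick-breaking) in plain Julia. This is an assumption for illustration only: the function names are hypothetical, and the parameterization TransformVariables actually uses may differ.

```julia
logistic(t) = 1 / (1 + exp(-t))

# Map K-1 unconstrained reals to a point on the K-simplex
# (positive entries that sum to one).
function simplex_from_reals(y::AbstractVector{<:Real})
    K = length(y) + 1
    x = zeros(float(eltype(y)), K)
    remaining = 1.0                  # length of stick not yet assigned
    for k in 1:K-1
        x[k] = remaining * logistic(y[k])
        remaining -= x[k]
    end
    x[K] = remaining                 # last entry takes what is left
    return x
end

# Inverse map, showing that the construction is a bijection.
function reals_from_simplex(x::AbstractVector{<:Real})
    K = length(x)
    y = zeros(K - 1)
    remaining = 1.0
    for k in 1:K-1
        frac = x[k] / remaining
        y[k] = log(frac / (1 - frac))  # logit, the inverse of logistic
        remaining -= x[k]
    end
    return y
end

w = simplex_from_reals([0.3, -1.2, 2.0])
println(w, "  sum = ", sum(w))
```

A sampler such as HMC can then move freely in the unconstrained coordinates while the model only ever sees valid weights.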
But this doesn't fix everything, because you still have
z ~ For(N) do _ Distributions.Categorical(w) end
This is discrete, so there's no way to set up a bijection to the reals. This is not Soss-specific, you'd have the same issue in Turing or Stan with HMC.
We'll be adding ways to make this easier in MeasureTheory, but for now in Distributions, say you have
julia> μ = rand(Normal() |> iid(3))
3-element Vector{Float64}:
1.0510484386874308
-0.8007745046155319
0.48629964893183536
Then you can do `using FillArrays, MappedArrays` and then
julia> paramvec = mappedarray(μ) do μj begin (Fill(μj, 2), 1.) end end
3-element mappedarray(var"#7#8"(), ::Vector{Float64}) with eltype Tuple{Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Float64}:
(Fill(1.0510484386874308, 2), 1.0)
(Fill(-0.8007745046155319, 2), 1.0)
(Fill(0.48629964893183536, 2), 1.0)
These are the mixture components, which you can combine like
julia> Dists.MixtureModel(Dists.MvNormal, paramvec)
MixtureModel{Distributions.MvNormal}(K = 3)
components[1] (prior = 0.3333): Distributions.MvNormal{Float64, PDMats.ScalMat{Float64}, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}(
dim: 2
μ: Fill(1.0510484386874308, 2)
Σ: [1.0 0.0; 0.0 1.0]
)
components[2] (prior = 0.3333): Distributions.MvNormal{Float64, PDMats.ScalMat{Float64}, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}(
dim: 2
μ: Fill(-0.8007745046155319, 2)
Σ: [1.0 0.0; 0.0 1.0]
)
components[3] (prior = 0.3333): Distributions.MvNormal{Float64, PDMats.ScalMat{Float64}, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}(
dim: 2
μ: Fill(0.48629964893183536, 2)
Σ: [1.0 0.0; 0.0 1.0]
)
This uses equal weights, but you can change that by adding another parameter.
Anyway, this works but it's not pretty. We're working on making it much easier to stay in MeasureTheory for all of this.
using MeasureTheory
using Soss
using SampleChainsDynamicHMC
import Distributions
using FillArrays
using LinearAlgebra
m = @model N begin
    σ0 ~ Lebesgue(ℝ)
    μ0 ~ Lebesgue(ℝ)
    α ~ Lebesgue(ℝ₊)
    K = 2
    μ ~ Normal(μ0, σ0) |> iid(K)
    w ~ Distributions.Dirichlet(K, abs(α))
    xdist = Dists.MixtureModel(Dists.Normal, μ, w)
    x ~ Dists.MatrixReshaped(Dists.Product(Fill(xdist, K*N)), K, N)
end
using TransformVariables
const TV = TransformVariables
Soss.xform(d::Dists.Dirichlet, _data::NamedTuple) = TV.UnitSimplex(length(d.alpha))
prior_data = predict(m(N=30), (N=30, σ0=1., μ0=0., α=1.))
# data generation with assumption on μ and w
predx = predictive(m, :μ, :w)
data = predict(m(N=30), (μ=[-3.5, 0.0], w=[0.5, 0.5]))
# estimating the posterior
posterior = m(N=30)|(x=data.x,)
sample(posterior, dynamichmc())
@cscherrer many thanks for the detailed answer.
I made some further explorations on my own and had a few issues:
I went through the transform documentation. I do not quite understand the implementation. For instance, with `asℝ`, I do not quite get why they are using the exponential distribution.
For some reason, `Fill` inside the proposed model replicates a single draw from the mixture distribution `K*N` times rather than performing `K*N` draws. When I plotted the samples from `prior_data`, I basically got two points. If I switch to Base `fill`, the model works as expected. The problem with the fixed model is that it does not ensure that both components are the same for each row.
I attempted a different variation of the model, as follows. Since the Distributions.jl package does not have a definition for the `Product` of a multivariate, I implemented a basic version to get the work done:
struct Product{
    S<:Distributions.ValueSupport,
    T<:Distributions.MultivariateDistribution{S},
    V<:AbstractVector{T},
} <: Distributions.MultivariateDistribution{S}
    v::V
    function Product(v::V) where
        V<:AbstractVector{T} where
        T<:Distributions.MultivariateDistribution{S} where
        S<:Distributions.ValueSupport
        return new{S, T, V}(v)
    end
end
Base.length(d::Product) = length(d.v)
function Base.eltype(::Type{<:Product{S, T}}) where {S<:Distributions.ValueSupport, T<:Distributions.MultivariateDistribution{S}}
    return eltype(T)
end
function _rand!(rng::Distributions.AbstractRNG, d::Product, x::AbstractVector)
    broadcast!(dn->rand(rng, dn), x, d.v)
end
function _logpdf(d::Product, x::AbstractVector)
    sum(n -> logpdf(d.v[n], x[n]), 1:length(d))
end
function Distributions.rand(rng::Distributions.AbstractRNG, s::Product)
    _rand!(rng, s, Vector{Vector{eltype(s)}}(undef, length(s)))
end
I then redefined the model as follows:
components = mappedarray(μ) do μk begin (Fill(μk, 2), 1.) end end
mixture = Dists.MixtureModel(Dists.MvNormal, components, w)
x ~ Product(fill(mixture, N))
Unfortunately, I still have problems with running the chain. It complains that the method `_logpdf` is not implemented even though I did implement it.
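One possible cause, offered as a guess rather than a confirmed diagnosis: the definitions above are written as `_logpdf` and `_rand!` without a `Distributions.` prefix. In Julia, an unqualified definition creates a new function in the current module instead of adding a method to the library's function, unless the name was explicitly imported. The sketch below reproduces that effect with a hypothetical module `Lib`; none of these names come from Distributions.

```julia
module Lib
# Internal hook that users are expected to extend for their own types.
_score(x) = error("_score not implemented for $(typeof(x))")
# Public entry point; it always calls Lib._score.
score(x) = _score(x)
end

struct MyDist end

# Defines a brand-new function Main._score; Lib.score never sees it.
_score(::MyDist) = 1.0

# Extends the library's own function, so Lib.score can dispatch to it.
Lib._score(::MyDist) = 2.0

println(Lib.score(MyDist()))
```

Without the `Lib._score` line, `Lib.score(MyDist())` would hit the "not implemented" fallback even though a function called `_score` exists in the calling module. If the same thing is happening here, qualifying the definitions (for example `Distributions._logpdf`) would be the thing to try.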
I could not use the function `sample`. Issue #293 was raised with a similar problem.
I am not quite sure what you meant by:
> This is discrete, so there's no way to set up a bijection to the reals. This is not Soss-specific, you'd have the same issue in Turing or Stan with HMC.
Do you have plans to add additional examples to the docs sometime soon? I know the project is developing quite fast at the moment. Please do let me know if there is a need for help. I am currently learning about these models and it would be good practice to write a few examples.
> 1. I went through the transform documentation. I do not quite understand the implementation. For instance with `asℝ`, I do not quite get why they are using the exponential distribution.
Sorry, I don't understand. Can you point me to a line?
> 2. for some reason `Fill` inside of the proposed model tends to replicate a single draw from the mixed distribution `K*N` rather than performing `K*N` draws. When I plotted the samples from `prior_data` I basically got two points. If I switch to base `fill` the model works as expected. The problem with the fixed model is that it does not ensure that both components are the same for each row.
Ah ok, I think the issue here is that there's no way to tell FillArrays that sampling is nondeterministic. Ideally Distributions would account for this, but it seems they don't. I guess this is the point of `filldist` in DistributionsAD. And come to think of it, you probably need that anyway since it will make gradients much more efficient for Distributions.
In MeasureTheory we'll have all of this built in. Currently we don't have any custom AD, but the implementations are also much simpler, so AD should have an easier time of it. We'll be adding more optimized methods as we go.
> 3. I attempted a different variation of the model as follows. Since the `Distributions.jl` package does not have a definition for the [`Product` of a multivariate](https://github.com/JuliaStats/Distributions.jl/blob/59df675409a7e2490e4a45edd32c0267df435c55/src/multivariate/product.jl), I implemented a basic version to get the work done:
>
> ```julia
> struct Product{
>     S<:Distributions.ValueSupport,
>     T<:Distributions.MultivariateDistribution{S},
>     V<:AbstractVector{T},
> } <: Distributions.MultivariateDistribution{S}
>     v::V
>     function Product(v::V) where
>         V<:AbstractVector{T} where
>         T<:Distributions.MultivariateDistribution{S} where
>         S<:Distributions.ValueSupport
>         return new{S, T, V}(v)
>     end
> end
> Base.length(d::Product) = length(d.v)
> function Base.eltype(::Type{<:Product{S, T}}) where {S<:Distributions.ValueSupport, T<:Distributions.MultivariateDistribution{S}}
>     return eltype(T)
> end
> function _rand!(rng::Distributions.AbstractRNG, d::Product, x::AbstractVector)
>     broadcast!(dn->rand(rng, dn), x, d.v)
> end
> function _logpdf(d::Product, x::AbstractVector)
>     sum(n -> logpdf(d.v[n], x[n]), 1:length(d))
> end
> function Distributions.rand(rng::Distributions.AbstractRNG, s::Product)
>     _rand!(rng, s, Vector{Vector{eltype(s)}}(undef, length(s)))
> end
> ```
>
> I then redefined the model as follows:
>
> ```julia
> components = mappedarray(μ) do μk begin (Fill(μk, 2), 1.) end end
> mixture = Dists.MixtureModel(Dists.MvNormal, components, w)
> x ~ Product(fill(mixture, N))
> ```
>
> Unfortunately, I still have problems with running the chain. It complains the method `_logpdf` is not implemented even though I did implement it.
Thanks for letting me know about this, I'll have a look and see if I can work it out.
In general, I think there are a lot of fundamental problems with Distributions, especially when it comes to PPL. Making this better is a lot of the motivation behind MeasureTheory. It's not yet a full workaround, but most of my energy this year has been directed toward this.
> 4. I could not use the function `sample`. Issue [sample(...) does not work #293](https://github.com/cscherrer/Soss.jl/issues/293) was raised with a similar problem.
Thanks for letting me know about this. What error are you getting? I'll need to be able to reproduce the problem before I can make progress on it.
> 5. I am not quite sure what you meant by:
> > This is discrete, so there's no way to set up a bijection to the reals. This is not Soss-specific, you'd have the same issue in Turing or Stan with HMC.
Hamiltonian Monte Carlo (HMC) was popularized by the Stan language. It's a great way to do inference, but it only works when the sample space is unconstrained Euclidean space.
The standard way to work around this is to marginalize out the discrete parameters, and set up a bijection between the remaining (continuous) sample space and ℝⁿ.
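For the mixture in this thread, marginalizing out the assignment `z` means summing over components: log p(x | μ, w) = log Σₖ wₖ · Normal(x | μₖ, σ). Here is a minimal sketch in plain Julia, assuming a shared, known σ. The helpers `normal_logpdf`, `logsumexp`, and `mixture_logpdf` are defined here for illustration and are not Soss or Distributions functions.

```julia
# Log density of a univariate normal.
normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Numerically stable log(sum(exp.(v))).
function logsumexp(v)
    m = maximum(v)
    return m + log(sum(exp.(v .- m)))
end

# Log density of one observation with the component label summed out,
# so no discrete variable is left for the sampler to handle.
function mixture_logpdf(x, μ, w; σ = 1.0)
    return logsumexp(log.(w) .+ normal_logpdf.(x, μ, σ))
end

# Component means and weights as in the data-generation step of this thread.
println(mixture_logpdf(0.0, [-3.5, 0.0], [0.5, 0.5]))
```

This is what `MixtureModel` does implicitly in the model above: only `μ` and `w` remain as parameters, and both can be mapped to unconstrained space.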
> Do you have plans to add additional examples to the docs sometime soon? I know the project is developing quite fast at the moment. Please do let me know if there is a need for help. I am currently learning about these models and it would be good practice to write a few examples.
This would be great!! Yes, we definitely need documentation, examples, tutorials, etc. The only limitation here is that I'm stretched in a few different directions, so it's hard to get everything done at once.
Thanks for the help again.
I have been studying the topic in more detail and I have developed a better understanding of transforms. TransformVariables.jl seems to be similar to Bijectors.jl.
I have created a gist with the Gaussian mixture example using different options that I have played around with. The gist is a Pluto notebook, so you should be able to replicate it with the exact environment I am using. The version that is not commented out runs without any issues except for the last command, which calls `sample`. It complains that this function is not defined.
As I get more familiar with PPLs, I will try to write some examples, add them to the documentation and open a PR with them.
I got so excited about the Soss.jl presentation at JuliaCon 2021 that I decided to give it a go by implementing a mixture of Gaussians model. I borrowed the example from Turing.jl for comparison purposes. Please let me know if there is anything that could be improved.
Unfortunately, I was not able to estimate the posterior distribution, as an error is raised when I call `tr = xform(posterior)`. Apparently, `xform` is not defined for the Dirichlet distribution.
As I understand, not all distributions have been implemented directly in Soss.jl. If that is something within my capabilities, I could try to implement it. However, I have no clue what `xform` is doing and I have not been able to find a lot of documentation on it.
In any case, thanks for putting the package together.