TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License

MethodError: no method matching bijector(::MixtureModel{Multivariate, Continuous, MvNormal, Float64}) #227

Open · krishvishal opened this issue 2 years ago

krishvishal commented 2 years ago

MWE:

using Bijectors, Distributions

dist = MixtureModel(MvNormal, [(ones(2), 1), (2 .* ones(2), 1)])

x = rand(dist)
b = bijector(dist)

Error:

ERROR: MethodError: no method matching bijector(::MixtureModel{Multivariate, Continuous, MvNormal, Float64})
Closest candidates are:
  bijector(::Union{Kolmogorov, BetaPrime, Chi, Chisq, Erlang, Exponential, FDist, Frechet, Gamma, InverseGamma, InverseGaussian, LogNormal, NoncentralChisq, NoncentralF, Rayleigh, Weibull}) at /home/.julia/packages/Bijectors/LmARY/src/transformed_distribution.jl:58
  bijector(::Union{Arcsine, Beta, Biweight, Cosine, Epanechnikov, NoncentralBeta}) at /home/.julia/packages/Bijectors/LmARY/src/transformed_distribution.jl:69
  bijector(::Union{Levy, Pareto}) at /home/.julia/packages/Bijectors/LmARY/src/transformed_distribution.jl:72
  ...
Stacktrace:
 [1] top-level scope
   @ REPL[5]:1

It seems that no bijector is defined for multivariate MixtureModels. Can someone please clarify whether this is intended?

krishvishal commented 2 years ago

Since a mixture of Dirichlet distributions lives on a simplex, its bijector has to be a SimplexBijector.
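A quick way to convince yourself of the simplex claim, using only Distributions.jl: every draw from a Dirichlet mixture has non-negative entries that sum to one. The concentration parameters and weights below are arbitrary values chosen for illustration, and `on_simplex` is a hypothetical helper, not part of either package.

```julia
using Distributions, Random

# Arbitrary illustrative parameters: two Dirichlet components on the 2-simplex.
mix = MixtureModel(Dirichlet, [[1.0, 2.0, 3.0], [4.0, 1.0, 1.0]], [0.3, 0.7])

Random.seed!(1)
xs = [rand(mix) for _ in 1:1000]

# Hypothetical helper: a point is on the simplex if its entries are
# non-negative and sum to one (up to floating-point error).
on_simplex(x) = all(>=(0), x) && isapprox(sum(x), 1; atol = 1e-10)

@assert all(on_simplex, xs)
```

Because each component already maps onto the simplex, the mixture does too, which is why the Dirichlet bijector can be reused for it.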

I've defined a custom distribution with a SimplexBijector to work around this error. Similarly, one can define a custom distribution with an identity bijector for a mixture of MvNormal distributions.

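using Bijectors, Distributions, Random

# Two-component mixture of Dirichlet distributions, wrapped in a custom type
# so that a bijector can be attached to it.
struct CustomMixture <: ContinuousMultivariateDistribution
    a::Vector{Float64}        # concentration parameters of the first component
    b::Vector{Float64}        # concentration parameters of the second component
    weights::Vector{Float64}  # mixture weights (must sum to 1)
end

# Sample by delegating to the equivalent Distributions.jl MixtureModel.
function Base.rand(rng::Random.AbstractRNG, d::CustomMixture)
    return rand(rng, MixtureModel(Dirichlet, [d.a, d.b], d.weights))
end

# Log-density, also delegated to the MixtureModel. Restricting x to real
# vectors keeps this method more specific than Distributions' generic logpdf.
function Distributions.logpdf(d::CustomMixture, x::AbstractVector{<:Real})
    return logpdf(MixtureModel(Dirichlet, [d.a, d.b], d.weights), x)
end

Base.length(d::CustomMixture) = length(d.a)

# The support is the probability simplex, so reuse the simplex bijector
# (constructor as spelled in the Bijectors release used in this thread).
Bijectors.bijector(d::CustomMixture) = Bijectors.SimplexBijector{1}()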
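The MvNormal variant mentioned above could follow the same pattern. This is a sketch under assumptions: the type `CustomMvNormalMixture` and the helper `_mixture` are hypothetical names, the components use identity covariances, and returning `identity` from `bijector` assumes a recent Bijectors release (0.13 or later), where the identity map is represented by `Base.identity`. Older releases, such as the one in this thread, use a dedicated identity bijector type instead.

```julia
using Bijectors, Distributions, LinearAlgebra, Random

# Hypothetical two-component MvNormal mixture with identity covariances.
struct CustomMvNormalMixture <: ContinuousMultivariateDistribution
    μ1::Vector{Float64}       # mean of the first component
    μ2::Vector{Float64}       # mean of the second component
    weights::Vector{Float64}  # mixture weights (must sum to 1)
end

# Hypothetical helper: rebuild the equivalent Distributions.jl mixture.
_mixture(d::CustomMvNormalMixture) =
    MixtureModel([MvNormal(d.μ1, 1.0I), MvNormal(d.μ2, 1.0I)], d.weights)

Base.rand(rng::Random.AbstractRNG, d::CustomMvNormalMixture) = rand(rng, _mixture(d))

Distributions.logpdf(d::CustomMvNormalMixture, x::AbstractVector{<:Real}) =
    logpdf(_mixture(d), x)

Base.length(d::CustomMvNormalMixture) = length(d.μ1)

# Support is all of R^n, so no transformation is needed.
# Assumes Bijectors >= 0.13, where `identity` serves as the identity bijector.
Bijectors.bijector(::CustomMvNormalMixture) = identity
```

Usage mirrors the Dirichlet case, e.g. `d = CustomMvNormalMixture(ones(2), 2 .* ones(2), [0.5, 0.5])` followed by `rand(d)` and `logpdf(d, x)`.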