TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License

Missing implementation of `Bijectors.bijector` for `arraydist` distributions. #294

Open soldasim opened 10 months ago

soldasim commented 10 months ago

It appears that `Bijectors.bijector` is not implemented for some distributions of type `VectorOfMultivariate` and `MatrixOfUnivariate`.

Example:

using Turing

function script()
    dist = arraydist(
        fill(
            Product(fill(Normal(), 2)),  # does not work
            # MvLogNormal(zeros(2), ones(2)),  # does not work
            # MvNormal(zeros(2), ones(2)),  # works
            2
        )
    )

    @model function model()
        x ~ dist
    end

    iters = 10
    chain = sample(model(), NUTS(10, 0.65), iters)
    @show chain
end

julia> script()
ERROR: MethodError: no method matching bijector(::DistributionsAD.VectorOfMultivariate{Continuous, Product{Continuous, Normal{Float64}, Vector{Normal{Float64}}}, Vector{Product{Continuous, Normal{Float64}, Vector{Normal{Float64}}}}})

Closest candidates are:
  bijector(::Union{Kolmogorov, BetaPrime, Chi, Chisq, Erlang, Exponential, FDist, Frechet, Gamma, InverseGamma, InverseGaussian, LogNormal, NoncentralChisq, NoncentralF, Rayleigh, Weibull})
   @ Bijectors C:\Users\sheld\.julia\packages\Bijectors\QhObI\src\transformed_distribution.jl:102
  bijector(::Union{Arcsine, Beta, Biweight, Cosine, Epanechnikov, NoncentralBeta})
   @ Bijectors C:\Users\sheld\.julia\packages\Bijectors\QhObI\src\transformed_distribution.jl:113
  bijector(::Union{Levy, Pareto})
   @ Bijectors C:\Users\sheld\.julia\packages\Bijectors\QhObI\src\transformed_distribution.jl:116
  ...

I am not confident enough in my understanding of bijectors to make this into a PR.

EDIT: Added MvLogNormal to the example.

soldasim commented 10 months ago

Additionally, the bijector for MvLogNormal seems to be broken.

When running the example above with MvLogNormal, I get the following:

julia> script()
ERROR: UndefVarError: `Log` not defined
Stacktrace:
  [1] bijector(d::DistributionsAD.VectorOfMultivariate{Continuous, MvLogNormal{Float64, PDMats.PDiagMat{Float64, Vector{Float64}}, Vector{Float64}}, Vector{MvLogNormal{Float64, PDMats.PDiagMat{Float64, Vector{Float64}}, Vector{Float64}}}})
    @ BijectorsDistributionsADExt C:\Users\sheld\.julia\packages\Bijectors\QhObI\ext\BijectorsDistributionsADExt.jl:64
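
Both failures originate in the call to `bijector` itself, so they can probably be checked without a model or a sampler. The following is a minimal sketch, assuming that `Turing.Bijectors` is reachable as a submodule binding and that the element distributions are constructed as in the example above. The names `candidates` and `results` are made up for illustration.

```julia
using Turing  # provides arraydist and the distribution types used in the example above

# Element distributions taken from the example above.
candidates = [
    "Product"     => Product(fill(Normal(), 2)),
    "MvLogNormal" => MvLogNormal(zeros(2), ones(2)),
    "MvNormal"    => MvNormal(zeros(2), ones(2)),
]

# Call `bijector` directly on each arraydist and record either "ok"
# or the type of the exception that was thrown.
results = map(candidates) do (name, elem)
    dist = arraydist(fill(elem, 2))
    outcome = try
        Turing.Bijectors.bijector(dist)  # assumption: Bijectors is accessible via Turing
        "ok"
    catch err
        string(typeof(err))
    end
    name => outcome
end

foreach(println, results)
```

On the versions reported in this issue, one would expect a `MethodError` for `Product`, an `UndefVarError` for `MvLogNormal`, and no error for `MvNormal`. Other versions may behave differently.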