TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License

Bijector for MatrixNormal #237

Open · RodrigoZepeda opened this issue 1 year ago

RodrigoZepeda commented 1 year ago

Hi! I'm trying to sample from the following Turing model, which uses a `MatrixNormal` distribution:

```julia
using Random, Turing, Bijectors
Random.seed!(123)

# Simulate data from a MatrixNormal, which the model below then estimates
U  = rand(LKJ(2, 0.5))
V  = rand(LKJ(2, 0.5))
Uₐ = rand(LKJ(2, 0.5))
Vₐ = rand(LKJ(2, 0.5))
Asample = rand(MatrixNormal(zeros(Float64, 2, 2), U, V))

# Create the model
@model function estimateA(A, U, V, Uₐ, Vₐ)
    mu ~ MatrixNormal(zeros(Float64, size(A, 1), size(A, 2)), Uₐ, Vₐ)
    A  ~ MatrixNormal(mu, U, V)
end

# Estimate!
model  = estimateA(Asample, U, V, Uₐ, Vₐ);
chains = sample(model, NUTS(), 100);
```

However, I get the following error:

```
ERROR: MethodError: no method matching bijector(::MatrixNormal{Float64, Matrix{Float64}, PDMats.PDMat{Float64, Matrix{Float64}}, PDMats.PDMat{Float64, Matrix{Float64}}})
Closest candidates are:
  bijector(::Union{Kolmogorov, BetaPrime, Chi, Chisq, Erlang, Exponential, FDist, Frechet, Gamma, InverseGamma, InverseGaussian, LogNormal, NoncentralChisq, NoncentralF, Rayleigh, Weibull}) at ~/.julia/packages/Bijectors/vUc4m/src/transformed_distribution.jl:58
  bijector(::Union{Arcsine, Beta, Biweight, Cosine, Epanechnikov, NoncentralBeta}) at ~/.julia/packages/Bijectors/vUc4m/src/transformed_distribution.jl:69
  bijector(::Union{Levy, Pareto}) at ~/.julia/packages/Bijectors/vUc4m/src/transformed_distribution.jl:72
```

The same error occurs with this minimal example:

```julia
dist = MatrixNormal(zeros(2, 2), rand(LKJ(2, 0.5)), rand(LKJ(2, 0.5)))
b = bijector(dist)
```

I'm relatively new to Turing, so my diagnosis may be wrong, but it seems that Bijectors is missing a `bijector` definition for `MatrixNormal`.

ParadaCarleton commented 1 year ago

@torfjelde

torfjelde commented 1 year ago

Ah, yes, we're missing a definition of

```julia
bijector(d::MatrixNormal) = Identity{2}()
```

I'll make a PR, but in the meantime you can add this overload yourself.
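A minimal sketch of applying that workaround in a user script, under stated assumptions. The method body `Identity{2}()` is taken from the comment above; attaching it to `Bijectors.bijector` by its qualified name is an addition here, because a bare `bijector(...) = ...` written in `Main` does not extend the Bijectors function (if `bijector` has already been used in `Main` Julia raises an error asking for an explicit import, otherwise it silently creates a separate `Main.bijector` that Turing never calls). The sketch assumes a Bijectors release contemporary with this issue that still defines `Identity{N}`; later releases reworked the bijector interface, so check the documentation for your installed version.

```julia
using Random, Turing, Bijectors
Random.seed!(123)

# MatrixNormal is supported on all real matrices, so no constraining transform
# is needed and the identity map is a valid bijector for it.
# Assumption: this Bijectors release still defines Identity{N} (as in the thread).
# The qualified name extends Bijectors.bijector rather than defining Main.bijector.
Bijectors.bijector(d::MatrixNormal) = Bijectors.Identity{2}()

# The minimal reproduction from the issue now finds a matching method
dist = MatrixNormal(zeros(2, 2), rand(LKJ(2, 0.5)), rand(LKJ(2, 0.5)))
b = bijector(dist)

# Applying the identity bijector to a draw leaves it unchanged
x = rand(dist)
println(b(x) == x)
```

With this method defined before calling `sample`, the original `estimateA` model should be usable with `NUTS()` without other changes, since the overload only tells Turing that `mu` needs no transformation to unconstrained space.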