TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License
200 stars 33 forks source link

Nested transformation with Shift does not work for Matrix output #191

Open mgmverburg opened 3 years ago

mgmverburg commented 3 years ago

Basically, what I want to achieve is, for example, a logit-transformed outcome, but I want to allow covariates to have different effects and hence shift the mean of each entry of the matrix. I already had a way to do this, but it involved an arraydist with a for-loop, and when optimizing I noticed that this approach actually caused a slowdown (among other things, I believe it introduced type instabilities, as shown by code_warntype). Therefore, I wanted to find a way to do this in one shot. Most things work, but for some reason this specific use case does not, even though it feels like in theory it should.

To make it perhaps slightly weirder, the shift bijector does work with 2D data like in model_1 in the code below, which throws no error. However, when adding a layer like a logit transform to wrap around it, then it throws an error that I listed below the code.

# Packages used by the reproduction below. A single `using` statement takes a
# comma-separated list of modules; repeating the `using` keyword inside the
# list is a parse error.
using Bijectors, Turing, LinearAlgebra, Random

# Simulated observations: an M×N matrix of draws from LogitNormal(0, 1).
M, N = 8, 20
output = rand(LogitNormal(0, 1), M, N)

# Model 1: `output` is modelled as N copies of an M-dimensional normal,
# shifted by a latent matrix `z` through a single `Bijectors.Shift` layer.
# The issue reports that this variant samples without error.
@model function test_1(output, M, N)
    # M-dimensional normal with zero mean and identity covariance.
    mvn = MvNormal(zeros(M), LinearAlgebra.I)
    # Latent shifts with a Normal(0, 1) prior, filled over dimensions (M, N).
    z ~ filldist(Normal(0, 1), M, N)
    # Matrix-variate base distribution shifted by `z`.
    output ~ transformed(filldist(mvn, N), Bijectors.Shift(z))
end

# Instantiate the first model on the simulated data and draw 10 samples with
# NUTS (0.65 is NUTS's first argument — presumably the target acceptance
# rate; confirm against the Turing docs).
model_1 = test_1(output, M, N)
chain_1 = sample(model_1, NUTS(0.65), 10)

# Model 2: same as the single-Shift model, but the shifted distribution is
# wrapped in a second transformation `b` (an inverse Logit bijector).
# The issue reports that this nested variant fails with
# `MethodError: no method matching _logabsdetjac_shift(::Array{Float64,2}, ::Array{Float64,2}, ::Val{2})`.
@model function test_2(output, M, N)
    # M-dimensional normal with zero mean and identity covariance.
    mvn = MvNormal(zeros(M), LinearAlgebra.I)
    # Inverse of a 2-dimensional Logit bijector constructed with (0.0, 1.0)
    # — presumably the lower/upper bounds; confirm against Bijectors docs.
    b = inv(Bijectors.Logit{2}(0.0, 1.0))
    # Latent shifts with a Normal(0, 1) prior, filled over dimensions (M, N).
    z ~ filldist(Normal(0, 1), M, N)
    # Nested transformation: Shift by `z` first, then apply `b` on top.
    output ~ transformed(transformed(filldist(mvn, N), Bijectors.Shift(z)), b)
end

# Instantiate the second model and attempt to draw 10 NUTS samples; per the
# stack trace in this report, the `sample` call is where the MethodError
# surfaces.
model_2 = test_2(output, M, N)
chain_2 = sample(model_2, NUTS(0.65), 10)
Error message ERROR: MethodError: no method matching _logabsdetjac_shift(::Array{Float64,2}, ::Array{Float64,2}, ::Val{2}) Closest candidates are: _logabsdetjac_shift(::T1, ::AbstractArray{T2,2}, ::Val{2}) where {T1<:Union{Real, AbstractArray{T,1} where T}, T2<:Real} at /root/.julia/packages/Bijectors/LmARY/src/bijectors/shift.jl:36 _logabsdetjac_shift(::Union{Tracker.TrackedArray{var"#s25",1,A} where A where var"#s25"<:Real, Tracker.TrackedReal}, ::AbstractArray{var"#s24",2} where var"#s24"<:Real, ::Val{1}) at /root/.julia/packages/Bijectors/LmARY/src/compat/tracker.jl:80 _logabsdetjac_shift(::T1, ::AbstractArray{T2,2}, ::Val{1}) where {T1<:Union{Real, AbstractArray{T,1} where T}, T2<:Real} at /root/.julia/packages/Bijectors/LmARY/src/bijectors/shift.jl:35 ... Stacktrace: [1] logabsdetjac(::Bijectors.Shift{Array{Float64,2},2}, ::Array{Float64,2}) at /root/.julia/packages/Bijectors/LmARY/src/bijectors/shift.jl:30 [2] logpdf_with_trans(::Bijectors.TransformedDistribution{DistributionsAD.VectorOfMultivariate{Continuous,MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},FillArrays.Fill{MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},1,Tuple{Base.OneTo{Int64}}}},Bijectors.Shift{Array{Float64,2},2},Matrixvariate}, ::Array{Float64,2}, ::Bool) at /root/.julia/packages/Bijectors/LmARY/src/Bijectors.jl:132 [3] _logpdf(::Bijectors.TransformedDistribution{Bijectors.TransformedDistribution{DistributionsAD.VectorOfMultivariate{Continuous,MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},FillArrays.Fill{MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},1,Tuple{Base.OneTo{Int64}}}},Bijectors.Shift{Array{Float64,2},2},Matrixvariate},Inverse{Bijectors.Logit{2,Float64},2},Matrixvariate}, ::Array{Float64,2}) at /root/.julia/packages/Bijectors/LmARY/src/transformed_distribution.jl:124 [4] 
logpdf(::Bijectors.TransformedDistribution{Bijectors.TransformedDistribution{DistributionsAD.VectorOfMultivariate{Continuous,MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},FillArrays.Fill{MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},1,Tuple{Base.OneTo{Int64}}}},Bijectors.Shift{Array{Float64,2},2},Matrixvariate},Inverse{Bijectors.Logit{2,Float64},2},Matrixvariate}, ::Array{Float64,2}) at /root/.julia/packages/Distributions/Xrm9e/src/matrixvariates.jl:164 [5] loglikelihood(::Bijectors.TransformedDistribution{Bijectors.TransformedDistribution{DistributionsAD.VectorOfMultivariate{Continuous,MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},FillArrays.Fill{MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},1,Tuple{Base.OneTo{Int64}}}},Bijectors.Shift{Array{Float64,2},2},Matrixvariate},Inverse{Bijectors.Logit{2,Float64},2},Matrixvariate}, ::Array{Float64,2}) at /root/.julia/packages/Distributions/Xrm9e/src/matrixvariates.jl:227 [6] observe(::DynamicPPL.SampleFromUniform, ::Bijectors.TransformedDistribution{Bijectors.TransformedDistribution{DistributionsAD.VectorOfMultivariate{Continuous,MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},FillArrays.Fill{MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},1,Tuple{Base.OneTo{Int64}}}},Bijectors.Shift{Array{Float64,2},2},Matrixvariate},Inverse{Bijectors.Logit{2,Float64},2},Matrixvariate}, ::Array{Float64,2}, ::DynamicPPL.ThreadSafeVarInfo{DynamicPPL.VarInfo{DynamicPPL.Metadata{Dict{AbstractPPL.VarName,Int64},Array{Distribution,1},Array{AbstractPPL.VarName,1},Array{Real,1},Array{Set{DynamicPPL.Selector},1}},Float64},Array{Base.RefValue{Float64},1}}) at /root/.julia/packages/DynamicPPL/wCsuo/src/context_implementations.jl:148 [7] _tilde(::DynamicPPL.SampleFromUniform, 
::Bijectors.TransformedDistribution{Bijectors.TransformedDistribution{DistributionsAD.VectorOfMultivariate{Continuous,MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},FillArrays.Fill{MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},1,Tuple{Base.OneTo{Int64}}}},Bijectors.Shift{Array{Float64,2},2},Matrixvariate},Inverse{Bijectors.Logit{2,Float64},2},Matrixvariate}, ::Array{Float64,2}, ::DynamicPPL.ThreadSafeVarInfo{DynamicPPL.VarInfo{DynamicPPL.Metadata{Dict{AbstractPPL.VarName,Int64},Array{Distribution,1},Array{AbstractPPL.VarName,1},Array{Real,1},Array{Set{DynamicPPL.Selector},1}},Float64},Array{Base.RefValue{Float64},1}}) at /root/.julia/packages/DynamicPPL/wCsuo/src/context_implementations.jl:112 [8] tilde(::DynamicPPL.DefaultContext, ::DynamicPPL.SampleFromUniform, ::Bijectors.TransformedDistribution{Bijectors.TransformedDistribution{DistributionsAD.VectorOfMultivariate{Continuous,MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},FillArrays.Fill{MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},1,Tuple{Base.OneTo{Int64}}}},Bijectors.Shift{Array{Float64,2},2},Matrixvariate},Inverse{Bijectors.Logit{2,Float64},2},Matrixvariate}, ::Array{Float64,2}, ::DynamicPPL.ThreadSafeVarInfo{DynamicPPL.VarInfo{DynamicPPL.Metadata{Dict{AbstractPPL.VarName,Int64},Array{Distribution,1},Array{AbstractPPL.VarName,1},Array{Real,1},Array{Set{DynamicPPL.Selector},1}},Float64},Array{Base.RefValue{Float64},1}}) at /root/.julia/packages/DynamicPPL/wCsuo/src/context_implementations.jl:68 [9] tilde_observe(::DynamicPPL.DefaultContext, ::DynamicPPL.SampleFromUniform, ::Bijectors.TransformedDistribution{Bijectors.TransformedDistribution{DistributionsAD.VectorOfMultivariate{Continuous,MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},FillArrays.Fill{MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},1,Tuple{Base.OneTo{Int64}}}},Bijectors.Shift{Array{Float64,2},2},Matrixvariate},Inverse{Bijectors.Logit{2,Float64},2},Matrixvariate}, 
::Array{Float64,2}, ::AbstractPPL.VarName{:output,Tuple{}}, ::Tuple{}, ::DynamicPPL.ThreadSafeVarInfo{DynamicPPL.VarInfo{DynamicPPL.Metadata{Dict{AbstractPPL.VarName,Int64},Array{Distribution,1},Array{AbstractPPL.VarName,1},Array{Real,1},Array{Set{DynamicPPL.Selector},1}},Float64},Array{Base.RefValue{Float64},1}}) at /root/.julia/packages/DynamicPPL/wCsuo/src/context_implementations.jl:93 [10] #3 at ./REPL[9]:5 [inlined] [11] (::var"#3#4")(::Random._GLOBAL_RNG, ::DynamicPPL.Model{var"#3#4",(:output, :M, :N),(),(),Tuple{Array{Float64,2},Int64,Int64},Tuple{}}, ::DynamicPPL.ThreadSafeVarInfo{DynamicPPL.VarInfo{DynamicPPL.Metadata{Dict{AbstractPPL.VarName,Int64},Array{Distribution,1},Array{AbstractPPL.VarName,1},Array{Real,1},Array{Set{DynamicPPL.Selector},1}},Float64},Array{Base.RefValue{Float64},1}}, ::DynamicPPL.SampleFromUniform, ::DynamicPPL.DefaultContext, ::Array{Float64,2}, ::Int64, ::Int64) at ./none:0 [12] macro expansion at /root/.julia/packages/DynamicPPL/wCsuo/src/model.jl:0 [inlined] [13] _evaluate(::Random._GLOBAL_RNG, ::DynamicPPL.Model{var"#3#4",(:output, :M, :N),(),(),Tuple{Array{Float64,2},Int64,Int64},Tuple{}}, ::DynamicPPL.ThreadSafeVarInfo{DynamicPPL.VarInfo{DynamicPPL.Metadata{Dict{AbstractPPL.VarName,Int64},Array{Distribution,1},Array{AbstractPPL.VarName,1},Array{Real,1},Array{Set{DynamicPPL.Selector},1}},Float64},Array{Base.RefValue{Float64},1}}, ::DynamicPPL.SampleFromUniform, ::DynamicPPL.DefaultContext) at /root/.julia/packages/DynamicPPL/wCsuo/src/model.jl:150 [14] evaluate_threadsafe(::Random._GLOBAL_RNG, ::DynamicPPL.Model{var"#3#4",(:output, :M, :N),(),(),Tuple{Array{Float64,2},Int64,Int64},Tuple{}}, ::DynamicPPL.VarInfo{DynamicPPL.Metadata{Dict{AbstractPPL.VarName,Int64},Array{Distribution,1},Array{AbstractPPL.VarName,1},Array{Real,1},Array{Set{DynamicPPL.Selector},1}},Float64}, ::DynamicPPL.SampleFromUniform, ::DynamicPPL.DefaultContext) at /root/.julia/packages/DynamicPPL/wCsuo/src/model.jl:140 [15] 
(::DynamicPPL.Model{var"#3#4",(:output, :M, :N),(),(),Tuple{Array{Float64,2},Int64,Int64},Tuple{}})(::Random._GLOBAL_RNG, ::DynamicPPL.VarInfo{DynamicPPL.Metadata{Dict{AbstractPPL.VarName,Int64},Array{Distribution,1},Array{AbstractPPL.VarName,1},Array{Real,1},Array{Set{DynamicPPL.Selector},1}},Float64}, ::DynamicPPL.SampleFromUniform, ::DynamicPPL.DefaultContext) at /root/.julia/packages/DynamicPPL/wCsuo/src/model.jl:94 [16] DynamicPPL.VarInfo(::Random._GLOBAL_RNG, ::DynamicPPL.Model{var"#3#4",(:output, :M, :N),(),(),Tuple{Array{Float64,2},Int64,Int64},Tuple{}}, ::DynamicPPL.SampleFromUniform, ::DynamicPPL.DefaultContext) at /root/.julia/packages/DynamicPPL/wCsuo/src/varinfo.jl:132 [17] DynamicPPL.VarInfo(::Random._GLOBAL_RNG, ::DynamicPPL.Model{var"#3#4",(:output, :M, :N),(),(),Tuple{Array{Float64,2},Int64,Int64},Tuple{}}, ::DynamicPPL.SampleFromUniform) at /root/.julia/packages/DynamicPPL/wCsuo/src/varinfo.jl:131 [18] step(::Random._GLOBAL_RNG, ::DynamicPPL.Model{var"#3#4",(:output, :M, :N),(),(),Tuple{Array{Float64,2},Int64,Int64},Tuple{}}, ::DynamicPPL.Sampler{NUTS{Turing.Core.ForwardDiffAD{40},(),AdvancedHMC.DiagEuclideanMetric}}; resume_from::Nothing, kwargs::Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:nadapts,),Tuple{Int64}}}) at /root/.julia/packages/DynamicPPL/wCsuo/src/sampler.jl:69 [19] macro expansion at /root/.julia/packages/AbstractMCMC/ByHEr/src/sample.jl:123 [inlined] [20] macro expansion at /root/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined] [21] (::AbstractMCMC.var"#21#22"{Bool,String,Nothing,Int64,Int64,Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:nadapts,),Tuple{Int64}}},Random._GLOBAL_RNG,DynamicPPL.Model{var"#3#4",(:output, :M, :N),(),(),Tuple{Array{Float64,2},Int64,Int64},Tuple{}},DynamicPPL.Sampler{NUTS{Turing.Core.ForwardDiffAD{40},(),AdvancedHMC.DiagEuclideanMetric}},Int64,Int64})() at /root/.julia/packages/AbstractMCMC/ByHEr/src/logging.jl:11 [22] with_logstate(::Function, 
::Any) at ./logging.jl:408 [23] with_logger(::Function, ::LoggingExtras.TeeLogger{Tuple{LoggingExtras.EarlyFilteredLogger{TerminalLoggers.TerminalLogger,AbstractMCMC.var"#1#3"{Module}},LoggingExtras.EarlyFilteredLogger{Logging.ConsoleLogger,AbstractMCMC.var"#2#4"{Module}}}}) at ./logging.jl:514 [24] with_progresslogger(::Function, ::Module, ::Logging.ConsoleLogger) at /root/.julia/packages/AbstractMCMC/ByHEr/src/logging.jl:34 [25] macro expansion at /root/.julia/packages/AbstractMCMC/ByHEr/src/logging.jl:10 [inlined] [26] mcmcsample(::Random._GLOBAL_RNG, ::DynamicPPL.Model{var"#3#4",(:output, :M, :N),(),(),Tuple{Array{Float64,2},Int64,Int64},Tuple{}}, ::DynamicPPL.Sampler{NUTS{Turing.Core.ForwardDiffAD{40},(),AdvancedHMC.DiagEuclideanMetric}}, ::Int64; progress::Bool, progressname::String, callback::Nothing, discard_initial::Int64, thinning::Int64, chain_type::Type{T} where T, kwargs::Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:nadapts,),Tuple{Int64}}}) at /root/.julia/packages/AbstractMCMC/ByHEr/src/sample.jl:114 [27] sample(::Random._GLOBAL_RNG, ::DynamicPPL.Model{var"#3#4",(:output, :M, :N),(),(),Tuple{Array{Float64,2},Int64,Int64},Tuple{}}, ::DynamicPPL.Sampler{NUTS{Turing.Core.ForwardDiffAD{40},(),AdvancedHMC.DiagEuclideanMetric}}, ::Int64; chain_type::Type{T} where T, resume_from::Nothing, progress::Bool, nadapts::Int64, discard_adapt::Bool, discard_initial::Int64, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /root/.julia/packages/Turing/TbEZL/src/inference/hmc.jl:133 [28] sample at /root/.julia/packages/Turing/TbEZL/src/inference/hmc.jl:116 [inlined] [29] #sample#2 at /root/.julia/packages/Turing/TbEZL/src/inference/Inference.jl:142 [inlined] [30] sample at /root/.julia/packages/Turing/TbEZL/src/inference/Inference.jl:142 [inlined] [31] #sample#1 at /root/.julia/packages/Turing/TbEZL/src/inference/Inference.jl:132 [inlined] [32] sample(::DynamicPPL.Model{var"#3#4",(:output, :M, 
:N),(),(),Tuple{Array{Float64,2},Int64,Int64},Tuple{}}, ::NUTS{Turing.Core.ForwardDiffAD{40},(),AdvancedHMC.DiagEuclideanMetric}, ::Int64) at /root/.julia/packages/Turing/TbEZL/src/inference/Inference.jl:132 [33] top-level scope at REPL[11]:1

So I was able to fix this (for my specific case that I encountered an error with) by simply adding: Bijectors._logabsdetjac_shift(a::T1, x::AbstractMatrix{T2}, ::Val{2}) where {T1<:Union{Real, AbstractMatrix}, T2<:Real} = zero(T2)

But I am not sure if that is the best/cleanest fix for the package as a whole, or whether this covers just 1 use-case again.

Bijectors version 0.9.7, Turing 0.16.0

torfjelde commented 3 years ago

Ah yes this is a missing definition. But your solution is correct :+1: Once #183 has gone through, these things shouldn't happen.