tiemvanderdeure opened this issue 11 months ago
This is unfortunately an issue with ReverseDiff.jl that is non-trivial to resolve. It comes from the fact that ReverseDiff.jl does not work well with "structural" arrays, e.g. `PDMat` or `Cholesky`.
It might be possible to re-use some internal methods we are using to avoid these issues, e.g. https://github.com/TuringLang/Bijectors.jl/blob/04b79dd46eca8cea2f988348c47bd5e720a2b9a4/ext/BijectorsReverseDiffExt.jl#L222-L230
That is, the following works on my end:
```julia
julia> using Turing, LinearAlgebra, ReverseDiff

julia> Turing.setadbackend(:reversediff);

julia> Turing.setrdcache(true);

julia> @model function model_with_arraydist()
           stds ~ arraydist([truncated(Normal(0, x); lower = 0.0) for x in [1,2,3]])
           F ~ LKJCholesky(3, 3.0)
           Sigma = Bijectors.pd_from_lower(stds .* F.L)
       end
model_with_arraydist (generic function with 2 methods)

julia> mean(sample(model_with_arraydist(), NUTS(0.65), 1000))
┌ Info: Found initial step size
└   ϵ = 0.4
Sampling 100%|████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:34
Mean
  parameters      mean
      Symbol   Float64

     stds[1]    0.9353
     stds[2]    1.2679
     stds[3]    3.4438
    F.L[1,1]    1.0000
    F.L[2,1]   -0.1247
    F.L[3,1]    0.1726
    F.L[2,2]    0.9243
    F.L[3,2]    0.0878
    F.L[3,3]    0.9025

julia> @model function model_with_filldist()
           stds ~ filldist(truncated(Normal(0, 0.5); lower = 0.0), 3)
           F ~ LKJCholesky(3, 3.0)
           Sigma = Bijectors.pd_from_lower(stds .* F.L)
       end
model_with_filldist (generic function with 2 methods)

julia> mean(sample(model_with_filldist(), NUTS(0.65), 1000))
┌ Info: Found initial step size
└   ϵ = 0.4
Sampling 100%|████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:30
Mean
  parameters      mean
      Symbol   Float64

     stds[1]    0.3786
     stds[2]    0.2528
     stds[3]    0.3489
    F.L[1,1]    1.0000
    F.L[2,1]    0.1822
    F.L[3,1]   -0.0433
    F.L[2,2]    0.9302
    F.L[3,2]    0.1201
    F.L[3,3]    0.8273
```
`Bijectors.pd_from_lower` is not documented as it's not meant to be a public-facing method, but it effectively takes in a `Matrix`, assumes it's lower triangular, and then constructs a `PDMat` from it, "hiding" the `PDMat(Cholesky(...))` from ReverseDiff.jl.
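To make the linear algebra concrete, here is a minimal sketch in plain Julia (`LinearAlgebra` only, no Turing or ReverseDiff) of what building a covariance matrix from `stds` and a lower-triangular factor amounts to. The function `pd_from_lower_sketch` is hypothetical and is not Bijectors' actual implementation; it returns a `Symmetric` wrapper instead of a `PDMat` so that no extra packages are needed, and the input values are made up.

```julia
using LinearAlgebra

# Hypothetical stand-in for Bijectors.pd_from_lower; NOT the library's code.
# Treats X as a lower-triangular factor L and returns Σ = L * L'.
function pd_from_lower_sketch(X::AbstractMatrix)
    L = LowerTriangular(X)        # entries above the diagonal are ignored
    return Symmetric(L * L', :L)  # read only the lower triangle, so Σ is exactly symmetric
end

# Made-up inputs playing the role of the sampled `stds` and `F.L`
stds = [1.0, 2.0, 3.0]
Lcorr = cholesky([1.0 0.3 0.1; 0.3 1.0 0.2; 0.1 0.2 1.0]).L

# Scaling row i of the correlation factor by stds[i] gives a factor of
# Σ = Diagonal(stds) * R * Diagonal(stds), where R is the correlation matrix
Σ = pd_from_lower_sketch(stds .* Lcorr)
```

Because a correlation matrix has a unit diagonal, `diag(Σ)` equals `stds .^ 2` up to round-off, which is a quick sanity check on the scaling.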
Also, just for the record, both of the examples you give fail on my end (which is what I expected) 😕
I'll try with `Bijectors.pd_from_lower`, thanks!
> Also, just for the record, both of the examples you give fail on my end (which is what I expected) 😕

I made a typo in my MWE, but the following code definitely runs for me (I'm on Turing v0.92.2 and ReverseDiff v1.15.1):
```julia
using Turing, PDMats, LinearAlgebra, ReverseDiff
Turing.setadbackend(:reversediff)

@model function model_with_filldist()
    stds ~ filldist(truncated(Normal(0, 0.5); lower = 0.0), 3)
    F ~ LKJCholesky(3, 3.0)
    Sigma = PDMat(Cholesky(LowerTriangular(stds .* F.L)))
end

sample(model_with_filldist(), HMC(0.1, 10), 50)
```
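The construction in this MWE can also be exercised without Turing or ReverseDiff. A minimal sketch, assuming plain `Float64` values in place of the sampled `stds` and `F.L` (the numbers are made up) and leaving out `PDMat` so that only `LinearAlgebra` is needed: `Cholesky(LowerTriangular(...))` wraps an existing factor without refactorizing, so the resulting object stands for `L * L'`. This only illustrates the construction; it does not reproduce the ReverseDiff behavior discussed in this thread.

```julia
using LinearAlgebra

# Made-up values standing in for the sampled `stds` and `F.L` (hypothetical)
stds = [0.4, 0.3, 0.5]
Lcorr = cholesky([1.0 0.2 0.0; 0.2 1.0 0.1; 0.0 0.1 1.0]).L

# Wrap the scaled factor as a Cholesky object; no factorization is computed,
# the lower triangle is simply taken to be the factor L
C = Cholesky(LowerTriangular(stds .* Lcorr))

# The covariance matrix this object represents is L * L'
Σ = Matrix(C)
```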
Using `Bijectors.pd_from_lower`, I quickly ran into numerical stability problems, especially with bigger covariance matrices, e.g.:
```julia
@model function model_with_filldist(i)
    stds ~ filldist(truncated(Normal(0, 0.5); lower = 0.0), i)
    F ~ LKJCholesky(i, 3.0)
    Sigma = Bijectors.pd_from_lower(stds .* F.L)
    Y ~ MvNormal(Sigma)
end

mean(sample(model_with_filldist(3), NUTS(0.65), 1000)) # Works
mean(sample(model_with_filldist(5), NUTS(0.65), 1000)) # Errors
```
The error I get is `PosDefException: matrix is not Hermitian; Cholesky factorization failed.`
My quick fix is to wrap `Sigma` in a `Hermitian`, which runs (but probably isn't great for ReverseDiff either?):
```julia
@model function model_with_filldist2(i)
    stds ~ filldist(truncated(Normal(0, 0.5); lower = 0.0), i)
    F ~ LKJCholesky(i, 3.0)
    Sigma = Hermitian(Bijectors.pd_from_lower(stds .* F.L))
    Y ~ MvNormal(Sigma)
end

mean(sample(model_with_filldist2(3), NUTS(0.65), 1000)) # Works
mean(sample(model_with_filldist2(5), NUTS(0.65), 1000)) # Works
```
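The `PosDefException` is consistent with how `LinearAlgebra.cholesky` treats a plain matrix: it first checks `ishermitian`, which is an exact entry-by-entry comparison, so round-off as small as 1e-12 is enough for it to refuse. A minimal sketch, assuming (not verified for the model above) that the asymmetry in `Sigma` comes from such round-off, showing why wrapping in `Hermitian` helps; the 2×2 matrix and the perturbation are made up for illustration.

```julia
using LinearAlgebra

# A valid 2×2 covariance matrix
Σ = [2.0 0.5; 0.5 1.0]

# Simulate round-off that breaks exact symmetry by a tiny amount (hypothetical)
Σ_noisy = copy(Σ)
Σ_noisy[1, 2] += 1e-12

# `ishermitian` compares entries exactly, so this prints false
println(ishermitian(Σ_noisy))

# Plain `cholesky` refuses a matrix that is not exactly Hermitian
threw = try
    cholesky(Σ_noisy)
    false
catch err
    err isa PosDefException
end

# Hermitian(A, :L) reads only the lower triangle, so the perturbed
# upper entry is ignored and the factorization succeeds
C = cholesky(Hermitian(Σ_noisy, :L))
```

`Hermitian` defaults to the upper triangle (`:U`), as in the workaround above; either choice makes the wrapped matrix exactly Hermitian. Whether ReverseDiff handles the wrapper efficiently is a separate question this sketch does not address.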
I bumped into some more shenanigans with `reversediff` and covariance matrices. Generating a vector of numbers using `arraydist`, multiplying that with a matrix from `LKJCholesky`, and passing it back to `PDMat` to generate a covariance matrix fails. The following model works with `forwarddiff` but errors with `reversediff`:

The same code with `filldist` instead of `arraydist` does not error.

Stacktrace: