TuringLang / Turing.jl

Bayesian inference with probabilistic programming.
https://turinglang.org
MIT License

Issue with hierarchical model with LKJCholesky sampling! #1807

Closed · aelmokadem closed 2 years ago

aelmokadem commented 2 years ago

Hi,

I have attached below a model I am trying to adapt from a similar Stan implementation. Briefly, it is a compartmental model for a drug, described by an ODEProblem prob. I am trying to run Bayesian inference on the model parameters, which include fixed effects and random effects; the latter characterize interindividual variability (IIV) between patients. These IIV parameters are often correlated, so it is important to estimate the correlation as well. To characterize the correlation I want to use an LKJ Cholesky prior; however, when I try to fit, I get the error below.

[Screenshot: beginning of the error stack trace, 2022-03-24]

I only included a snapshot of the beginning of the error, but please let me know if more details are needed.

Any ideas on how I can fix this model? Thanks.
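
For reference, prob itself is not shown in this post. Here is a minimal sketch of the kind of one-compartment oral-absorption ODEProblem it could be; the dynamics and numbers are illustrative assumptions, chosen to be consistent with the parameters (ka, CL, V) and the initial condition [doses[i], 0.0] used below:

using OrdinaryDiffEq

# Hypothetical one-compartment PK model with first-order absorption:
# u[1] = amount in the gut depot, u[2] = amount in the central compartment.
function onecmt!(du, u, p, t)
    ka, CL, V = p
    du[1] = -ka * u[1]
    du[2] = ka * u[1] - (CL / V) * u[2]
end

# Placeholder dose, time span, and parameters; remade per subject in prob_func below.
prob = ODEProblem(onecmt!, [100.0, 0.0], (0.0, 24.0), [2.0, 4.0, 35.0])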

### helper functions (Diagonal requires LinearAlgebra)
diag_pre_multiply(ω, η) = Diagonal(ω) * η
rep_matrix(v, n) = hcat(fill.(v, n)...)'  # length(v) × n matrix whose columns are all v

@model function fitPKPop(data, prob, nSubject, doses, times, bws)
    # priors
    ## residual error
    σ ~ truncated(Cauchy(0.0, 0.5), 0.0, 2.0)

    ## population params (elementwise ~ into a Vector{Float64} breaks under
    ## ForwardDiff, so draw the vector in one go with arraydist)
    θ̂ ~ arraydist([LogNormal(log(2.0), 0.2), LogNormal(log(4.0), 0.2), LogNormal(log(35.0), 0.2)])

    # IIV (nIIV, the number of correlated random effects, is captured from the enclosing scope)
    ω ~ filldist(truncated(Cauchy(0.0, 0.5), 0.0, 2.0), nIIV)

    # non-centered parameterization
    η ~ filldist(Normal(0.0, 1.0), nIIV*nSubject)
    ηᵢ = reshape(η, nIIV, nSubject)
    L ~ Distributions.LKJ(nIIV, 1)

    θᵢ = (rep_matrix(θ̂, nSubject) .* exp.(diag_pre_multiply(ω, L * ηᵢ)))'
    kaᵢ = θᵢ[:, 1]
    CLᵢ = θᵢ[:, 2] .* (bws ./ 70.0).^0.75
    Vᵢ = θᵢ[:, 3] .* (bws ./ 70.0)

    function prob_func(prob,i,repeat)
        u0_tmp = [doses[i],0.0]
        ps = [kaᵢ[i], CLᵢ[i], Vᵢ[i]]
        remake(prob, u0=u0_tmp, p=ps, saveat=times[i][:,1])
    end

    tmp_ensemble_prob = EnsembleProblem(prob, prob_func=prob_func)
    tmp_ensemble_sol = solve(tmp_ensemble_prob, Tsit5(), trajectories=length(doses)) 

    predicted = []
    for i in 1:nSubject  # nSubject is a scalar count of subjects
        tmp_sol = Array(tmp_ensemble_sol[i])[2, :]
        append!(predicted, tmp_sol)
    end

    # likelihood
    for i = 1:length(predicted)
        data[i] ~ Normal(predicted[i], σ)
    end
end
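
The error shows up when the model is sampled; a hypothetical fit call (placeholder arguments, sketching my setup) is:

model = fitPKPop(data, prob, nSubject, doses, times, bws)
chain = sample(model, NUTS(0.65), 1_000)  # fails with the error above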
devmotion commented 2 years ago

LKJCholesky is not supported yet: https://github.com/TuringLang/Turing.jl/issues/1629
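
A minimal way to see this (a sketch, not from the linked issue; the exact failure depends on the Turing version):

using Turing, Distributions

@model function lkjchol_repro()
    L ~ LKJCholesky(3, 1.0)  # Cholesky-variate prior, unsupported at the time of this thread
end

sample(lkjchol_repro(), NUTS(), 10)  # fails: no bijector for Cholesky-valued variables yet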

aelmokadem commented 2 years ago

I see. Thanks @devmotion for the input. I did, however, try plain LKJ and ran into a different issue:

[Screenshot: beginning of a different error stack trace, 2022-03-24]
storopoli commented 2 years ago

I am also experiencing this same error:

using Turing, LinearAlgebra  # LinearAlgebra provides Diagonal and the ⋅ dot product

# define the model
@model function correlated_varying_intercept_slope_regression(X, idx, y;
                                                              predictors=size(X, 2),
                                                              N=size(X, 1),
                                                              n_gr=length(unique(idx)))
    # priors
    # Turing does not have LKJCholesky yet
    # see: https://github.com/TuringLang/Turing.jl/issues/1629
    #Ω ~ LKJCholesky(predictors, 2.0)
    Ω ~ LKJ(predictors, 2.0)
    σ ~ Exponential(1)

    # prior for variance of random intercepts and slopes
    # usually requires thoughtful specification
    τ ~ filldist(truncated(Cauchy(0, 2), 0, Inf), predictors) # group-level SDs
    γ ~ filldist(Normal(0, 5), predictors, n_gr)              # matrix of group coefficients
    Z ~ filldist(Normal(0, 1), predictors, n_gr)              # matrix of non-centered group coefficients

    # reconstruct β from Ω and τ
    #β = γ + τ .* Ω.L * Z                                     # Turing does not have LKJCholesky yet
    β = γ + Diagonal(τ) * Ω * Diagonal(τ) * Z

    # likelihood
    for i in 1:N
        y[i] ~ Normal(X[i, :] ⋅ β[:, idx[i]], σ)
    end
    return(; y, β, σ, Ω, τ, γ, Z)
end
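
A call like the following (with placeholder data, just to show the shape) reproduces the error for me:

X = randn(100, 2); idx = rand(1:4, 100); y = randn(100)  # hypothetical data
chain = sample(correlated_varying_intercept_slope_regression(X, idx, y), NUTS(), 500)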
aelmokadem commented 2 years ago

@storopoli yeah, and using LKJ actually fluctuates between the error above and this one (the DomainError comes from the LKJ log-density taking logdet of the proposed matrix, which fails whenever the proposal is not positive definite):

ERROR: DomainError with -1.0:
log will only return a complex result if called with a complex argument. Try log(Complex(x)).
Stacktrace:
  [1] throw_complex_domainerror(f::Symbol, x::Float64)
    @ Base.Math ./math.jl:33
  [2] _log(x::Float64, base::Val{:ℯ}, func::Symbol)
    @ Base.Math ./special/log.jl:292
  [3] log
    @ ./special/log.jl:257 [inlined]
  [4] log
    @ ~/.julia/packages/ForwardDiff/wAaVJ/src/dual.jl:240 [inlined]
  [5] logdet(A::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 11}})
    @ LinearAlgebra /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/LinearAlgebra/src/generic.jl:1614
  [6] logkernel
    @ ~/.julia/packages/Distributions/HAuAd/src/matrix/lkj.jl:116 [inlined]
  [7] _logpdf
    @ ~/.julia/packages/Distributions/HAuAd/src/matrixvariates.jl:82 [inlined]
  [8] logpdf
    @ ~/.julia/packages/Distributions/HAuAd/src/common.jl:250 [inlined]
  [9] logpdf_with_trans(d::LKJ{Float64, Int64}, x::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 11}}, transform::Bool)
    @ Bijectors ~/.julia/packages/Bijectors/U0SqN/src/Bijectors.jl:136
 [10] assume(dist::LKJ{Float64, Int64}, vn::AbstractPPL.VarName{:L, Setfield.IdentityLens}, vi::DynamicPPL.ThreadSafeVarInfo{DynamicPPL.TypedVarInfo{NamedTuple{(:σ, :k̂a, :ĈL, :V̂, :ω, :η, :L), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:σ, Setfield.IdentityLens}, Int64}, Vector{Truncated{Cauchy{Float64}, Continuous, Float64}}, Vector{AbstractPPL.VarName{:σ, Setfield.IdentityLens}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 11}}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:k̂a, Setfield.IdentityLens}, Int64}, Vector{LogNormal{Float64}}, Vector{AbstractPPL.VarName{:k̂a, Setfield.IdentityLens}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 11}}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:ĈL, Setfield.IdentityLens}, Int64}, Vector{LogNormal{Float64}}, Vector{AbstractPPL.VarName{:ĈL, Setfield.IdentityLens}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 11}}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:V̂, Setfield.IdentityLens}, Int64}, Vector{LogNormal{Float64}}, Vector{AbstractPPL.VarName{:V̂, Setfield.IdentityLens}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 11}}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:ω, Setfield.IdentityLens}, Int64}, Vector{Product{Continuous, Truncated{Cauchy{Float64}, Continuous, Float64}, FillArrays.Fill{Truncated{Cauchy{Float64}, Continuous, Float64}, 1, Tuple{Base.OneTo{Int64}}}}}, Vector{AbstractPPL.VarName{:ω, Setfield.IdentityLens}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 11}}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:η, Setfield.IdentityLens}, Int64}, Vector{DistributionsAD.TuringScalMvNormal{Vector{Float64}, Float64}}, Vector{AbstractPPL.VarName{:η, Setfield.IdentityLens}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 11}}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:L, Setfield.IdentityLens}, Int64}, Vector{LKJ{Float64, Int64}}, Vector{AbstractPPL.VarName{:L, Setfield.IdentityLens}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 11}}, Vector{Set{DynamicPPL.Selector}}}}}, ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 11}}, Vector{Base.RefValue{ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 11}}}})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/R7VK9/src/context_implementations.jl:198
 [11] assume
    @ ~/.julia/packages/Turing/S4Y4B/src/inference/hmc.jl:462 [inlined]
 [12] tilde_assume
    @ ~/.julia/packages/DynamicPPL/R7VK9/src/context_implementations.jl:49 [inlined]
 [13] tilde_assume
    @ ~/.julia/packages/DynamicPPL/R7VK9/src/context_implementations.jl:46 [inlined]
 [14] tilde_assume
aelmokadem commented 2 years ago

So @sethaxen helped me with a fix for this issue. Here is his suggested updated code:

    #L ~ Distributions.LKJ(nIIV, 1.)  # will give error
    #L ~ Distributions.LKJCholesky(nIIV, 1.)  # still does not work with Turing.jl

    # @sethaxen's fix; CorrCholeskyFactor, dimension, and transform_and_logjac
    # come from TransformVariables.jl (using TransformVariables)
    trans = CorrCholeskyFactor(nIIV)
    L_tilde ~ filldist(Turing.Flat(), dimension(trans))
    L_U, logdetJ = transform_and_logjac(trans, L_tilde)
    Turing.@addlogprob! logpdf(LKJCholesky(nIIV, 1), Cholesky(L_U)) + logdetJ

    #θᵢ = (rep_matrix(θ̂, nSubject) .* exp.(diag_pre_multiply(ω, L * ηᵢ)))'  # replaced with following line
    θᵢ = (repeat(θ̂, 1, nSubject) .* exp.(ω .* (L_U' * (L_U * ηᵢ))))'  # L_U' * L_U reconstructs the full correlation matrix
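
For clarity, here is the same pattern in a minimal self-contained model (my sketch; it assumes TransformVariables.jl for CorrCholeskyFactor, dimension, and transform_and_logjac):

using Turing, Distributions, LinearAlgebra, TransformVariables

@model function lkj_chol_demo(n)
    trans = CorrCholeskyFactor(n)
    # sample the unconstrained vector under a flat prior ...
    L_tilde ~ filldist(Turing.Flat(), dimension(trans))
    # ... map it to an upper-triangular Cholesky factor, then
    # add the target log-density plus the log-Jacobian by hand
    L_U, logdetJ = transform_and_logjac(trans, L_tilde)
    Turing.@addlogprob! logpdf(LKJCholesky(n, 1.0), Cholesky(L_U)) + logdetJ
    return L_U' * L_U  # implied correlation matrix
end

chain = sample(lkj_chol_demo(3), NUTS(), 500)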

I really appreciate @sethaxen's help with this, and I will close this issue since it was resolved for me as well as for @storopoli. However, I still think a native LKJCholesky implementation in Turing is important, as is a more direct solution to the LKJ issue.

YSanchezAraujo commented 1 year ago

Any updates on whether it is supported, and are there any example use cases?