TuringLang / TuringGLM.jl

Bayesian Generalized Linear models using `@formula` syntax.
https://turinglang.org/TuringGLM.jl/dev
MIT License

Make TuringGLM.jl work with ReverseDiff #58

Closed: sreedta8 closed this issue 1 year ago

sreedta8 commented 1 year ago

Hi

This issue follows up on an extensive discussion in Turing Samples Slowly #1851. The key observations were:

  1. Turing-based Bayesian models can use either ForwardDiff or ReverseDiff as the AD (automatic differentiation) backend. For hierarchical models, posterior sampling turned out to be dozens of times faster with ReverseDiff (7-9 minutes) than with the default ForwardDiff (7-10 hours) on the specific data and models I tested extensively across Turing, TuringGLM, brms, and PyMC.
  2. Testing showed that TuringGLM models only work with ForwardDiff (7-10 hours to sample) and throw errors with ReverseDiff. It would be wonderful if TuringGLM could work with ReverseDiff.

sreedta8 commented 1 year ago

@storopoli Hi Jose, this is Sree. Sorry I could not join you at JuliaCon 2022. I watched your introduction-to-Julia video and liked it a lot; I have since shared it with several team members whom I'm encouraging to use Julia, Turing, and TuringGLM. Let me know if you have anything for me to test in a new version of TuringGLM that works with ReverseDiff. Thanks again for your help back in July!

storopoli commented 1 year ago

Here is an MWE using this hierarchical model from the test/ suite:

julia> using ReverseDiff, Memoization, TuringGLM, DataFrames, CSV

julia> cheese = CSV.read(download("https://github.com/TuringLang/TuringGLM.jl/raw/main/data/cheese.csv"), DataFrame);

julia> f = @formula(y ~ (1 | cheese) + background);

julia> m = turing_model(f, cheese);
The idx are Dict{String1, Int64}("B" => 2, "A" => 1, "C" => 3, "D" => 4)

julia> Turing.setadbackend(:reversediff)
:reversediff

julia> Turing.setrdcache(true)
true

julia> chn = sample(m, NUTS(), 1)
Sampling 100%|████████████████████████████████████████████████| Time: 0:00:00
ERROR: TrackedArrays do not support setindex!
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] setindex!(::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, ::ReverseDiff.TrackedReal{Float64, Float64, Nothing}, ::Int64)
    @ ReverseDiff ~/.julia/packages/ReverseDiff/5MMPp/src/tracked.jl:378
  [3] macro expansion
    @ ./broadcast.jl:961 [inlined]
  [4] macro expansion
    @ ./simdloop.jl:77 [inlined]
  [5] copyto!
    @ ./broadcast.jl:960 [inlined]
  [6] copyto!
    @ ./broadcast.jl:913 [inlined]
  [7] materialize!
    @ ./broadcast.jl:871 [inlined]
  [8] materialize!(dest::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, bc::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle, Nothing, typeof(+), Tuple{ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}})
    @ Base.Broadcast ./broadcast.jl:868
  [9] (::TuringGLM.var"#normal_model_ranef#16"{Int64, Int64, CustomPrior, Vector{String}, Int64, Vector{Int64}})(__model__::DynamicPPL.Model{TuringGLM.var"#normal_model_ranef#16"{Int64, Int64, CustomPrior, Vector{String}, Int64, Vector{Int64}}, (:y, :X, :predictors, :idxs, :n_gr, :intercept_ranef, :μ_X, :σ_X, :prior, :residual), (:predictors, :idxs, :n_gr, :intercept_ranef, :μ_X, :σ_X, :prior, :residual), (), Tuple{Vector{Int64}, Matrix{Float64}, Int64, Vector{Int64}, Int64, Vector{String}, Int64, Int64, CustomPrior, Float64}, Tuple{Int64, Vector{Int64}, Int64, Vector{String}, Int64, Int64, CustomPrior, Float64}, DynamicPPL.DefaultContext}, __varinfo__::DynamicPPL.ThreadSafeVarInfo{DynamicPPL.TypedVarInfo{NamedTuple{(:α, :β, :σ, :τ, :zⱼ), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:α, Setfield.IdentityLens}, Int64}, Vector{LocationScale{Float64, Continuous, TDist{Float64}}}, Vector{AbstractPPL.VarName{:α, Setfield.IdentityLens}}, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:β, Setfield.IdentityLens}, Int64}, Vector{Product{Continuous, TDist{Float64}, FillArrays.Fill{TDist{Float64}, 1, Tuple{Base.OneTo{Int64}}}}}, Vector{AbstractPPL.VarName{:β, Setfield.IdentityLens}}, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:σ, Setfield.IdentityLens}, Int64}, Vector{Exponential{Float64}}, Vector{AbstractPPL.VarName{:σ, Setfield.IdentityLens}}, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:τ, Setfield.IdentityLens}, Int64}, Vector{LocationScale{Float64, Continuous, Truncated{TDist{Float64}, Continuous, Float64}}}, Vector{AbstractPPL.VarName{:τ, Setfield.IdentityLens}}, ReverseDiff.TrackedArray{Float64, Float64, 1, 
Vector{Float64}, Vector{Float64}}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:zⱼ, Setfield.IdentityLens}, Int64}, Vector{DistributionsAD.TuringScalMvNormal{Vector{Float64}, Float64}}, Vector{AbstractPPL.VarName{:zⱼ, Setfield.IdentityLens}}, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, Vector{Set{DynamicPPL.Selector}}}}}, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, Vector{Base.RefValue{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}}}, __context__::DynamicPPL.SamplingContext{DynamicPPL.Sampler{NUTS{Turing.Essential.ReverseDiffAD{true}, (), AdvancedHMC.DiagEuclideanMetric}}, DynamicPPL.DefaultContext, Random._GLOBAL_RNG}, y::Vector{Int64}, X::Matrix{Float64}, predictors::Int64, idxs::Vector{Int64}, n_gr::Int64, intercept_ranef::Vector{String}, μ_X::Int64, σ_X::Int64, prior::CustomPrior, residual::Float64)
    @ TuringGLM ~/.julia/packages/TuringGLM/s2Pou/src/turing_model.jl:186
 [10] macro expansion
    @ ~/.julia/packages/DynamicPPL/1qg3U/src/model.jl:493 [inlined]
 [11] _evaluate!!
    @ ~/.julia/packages/DynamicPPL/1qg3U/src/model.jl:476 [inlined]
 [12] evaluate_threadsafe!!
    @ ~/.julia/packages/DynamicPPL/1qg3U/src/model.jl:467 [inlined]
 [13] evaluate!!
    @ ~/.julia/packages/DynamicPPL/1qg3U/src/model.jl:402 [inlined]
 [14] evaluate!!
    @ ~/.julia/packages/DynamicPPL/1qg3U/src/model.jl:415 [inlined]
 [15] evaluate!!
    @ ~/.julia/packages/DynamicPPL/1qg3U/src/model.jl:423 [inlined]
 [16] (::Turing.LogDensityFunction{DynamicPPL.TypedVarInfo{NamedTuple{(:α, :β, :σ, :τ, :zⱼ), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:α, Setfield.IdentityLens}, Int64}, Vector{LocationScale{Float64, Continuous, TDist{Float64}}}, Vector{AbstractPPL.VarName{:α, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:β, Setfield.IdentityLens}, Int64}, Vector{Product{Continuous, TDist{Float64}, FillArrays.Fill{TDist{Float64}, 1, Tuple{Base.OneTo{Int64}}}}}, Vector{AbstractPPL.VarName{:β, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:σ, Setfield.IdentityLens}, Int64}, Vector{Exponential{Float64}}, Vector{AbstractPPL.VarName{:σ, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:τ, Setfield.IdentityLens}, Int64}, Vector{LocationScale{Float64, Continuous, Truncated{TDist{Float64}, Continuous, Float64}}}, Vector{AbstractPPL.VarName{:τ, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:zⱼ, Setfield.IdentityLens}, Int64}, Vector{DistributionsAD.TuringScalMvNormal{Vector{Float64}, Float64}}, Vector{AbstractPPL.VarName{:zⱼ, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, DynamicPPL.Model{TuringGLM.var"#normal_model_ranef#16"{Int64, Int64, CustomPrior, Vector{String}, Int64, Vector{Int64}}, (:y, :X, :predictors, :idxs, :n_gr, :intercept_ranef, :μ_X, :σ_X, :prior, :residual), (:predictors, :idxs, :n_gr, :intercept_ranef, :μ_X, :σ_X, :prior, :residual), (), Tuple{Vector{Int64}, Matrix{Float64}, Int64, Vector{Int64}, Int64, Vector{String}, Int64, Int64, CustomPrior, Float64}, Tuple{Int64, Vector{Int64}, Int64, Vector{String}, Int64, Int64, CustomPrior, Float64}, DynamicPPL.DefaultContext}, 
DynamicPPL.Sampler{NUTS{Turing.Essential.ReverseDiffAD{true}, (), AdvancedHMC.DiagEuclideanMetric}}, DynamicPPL.DefaultContext})(θ::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}})
    @ Turing ~/.julia/packages/Turing/szPqN/src/Turing.jl:38
 [17] logdensity
    @ ~/.julia/packages/Turing/szPqN/src/Turing.jl:42 [inlined]
 [18] Fix1
    @ ./operators.jl:1081 [inlined]
 [19] ReverseDiff.GradientTape(f::Base.Fix1{typeof(LogDensityProblems.logdensity), Turing.LogDensityFunction{DynamicPPL.TypedVarInfo{NamedTuple{(:α, :β, :σ, :τ, :zⱼ), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:α, Setfield.IdentityLens}, Int64}, Vector{LocationScale{Float64, Continuous, TDist{Float64}}}, Vector{AbstractPPL.VarName{:α, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:β, Setfield.IdentityLens}, Int64}, Vector{Product{Continuous, TDist{Float64}, FillArrays.Fill{TDist{Float64}, 1, Tuple{Base.OneTo{Int64}}}}}, Vector{AbstractPPL.VarName{:β, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:σ, Setfield.IdentityLens}, Int64}, Vector{Exponential{Float64}}, Vector{AbstractPPL.VarName{:σ, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:τ, Setfield.IdentityLens}, Int64}, Vector{LocationScale{Float64, Continuous, Truncated{TDist{Float64}, Continuous, Float64}}}, Vector{AbstractPPL.VarName{:τ, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:zⱼ, Setfield.IdentityLens}, Int64}, Vector{DistributionsAD.TuringScalMvNormal{Vector{Float64}, Float64}}, Vector{AbstractPPL.VarName{:zⱼ, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, DynamicPPL.Model{TuringGLM.var"#normal_model_ranef#16"{Int64, Int64, CustomPrior, Vector{String}, Int64, Vector{Int64}}, (:y, :X, :predictors, :idxs, :n_gr, :intercept_ranef, :μ_X, :σ_X, :prior, :residual), (:predictors, :idxs, :n_gr, :intercept_ranef, :μ_X, :σ_X, :prior, :residual), (), Tuple{Vector{Int64}, Matrix{Float64}, Int64, Vector{Int64}, Int64, Vector{String}, Int64, Int64, CustomPrior, Float64}, Tuple{Int64, Vector{Int64}, Int64, Vector{String}, Int64, Int64, CustomPrior, Float64}, 
DynamicPPL.DefaultContext}, DynamicPPL.Sampler{NUTS{Turing.Essential.ReverseDiffAD{true}, (), AdvancedHMC.DiagEuclideanMetric}}, DynamicPPL.DefaultContext}}, input::Vector{Float64}, cfg::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/5MMPp/src/api/tape.jl:199
 [20] ReverseDiff.GradientTape(f::Function, input::Vector{Float64})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/5MMPp/src/api/tape.jl:198
 [21] _compiledtape
    @ ~/.julia/packages/LogDensityProblems/b1j6d/src/AD_ReverseDiff.jl:37 [inlined]
 [22] _compiledtape(ℓ::Turing.LogDensityFunction{DynamicPPL.TypedVarInfo{NamedTuple{(:α, :β, :σ, :τ, :zⱼ), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:α, Setfield.IdentityLens}, Int64}, Vector{LocationScale{Float64, Continuous, TDist{Float64}}}, Vector{AbstractPPL.VarName{:α, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:β, Setfield.IdentityLens}, Int64}, Vector{Product{Continuous, TDist{Float64}, FillArrays.Fill{TDist{Float64}, 1, Tuple{Base.OneTo{Int64}}}}}, Vector{AbstractPPL.VarName{:β, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:σ, Setfield.IdentityLens}, Int64}, Vector{Exponential{Float64}}, Vector{AbstractPPL.VarName{:σ, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:τ, Setfield.IdentityLens}, Int64}, Vector{LocationScale{Float64, Continuous, Truncated{TDist{Float64}, Continuous, Float64}}}, Vector{AbstractPPL.VarName{:τ, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:zⱼ, Setfield.IdentityLens}, Int64}, Vector{DistributionsAD.TuringScalMvNormal{Vector{Float64}, Float64}}, Vector{AbstractPPL.VarName{:zⱼ, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, DynamicPPL.Model{TuringGLM.var"#normal_model_ranef#16"{Int64, Int64, CustomPrior, Vector{String}, Int64, Vector{Int64}}, (:y, :X, :predictors, :idxs, :n_gr, :intercept_ranef, :μ_X, :σ_X, :prior, :residual), (:predictors, :idxs, :n_gr, :intercept_ranef, :μ_X, :σ_X, :prior, :residual), (), Tuple{Vector{Int64}, Matrix{Float64}, Int64, Vector{Int64}, Int64, Vector{String}, Int64, Int64, CustomPrior, Float64}, Tuple{Int64, Vector{Int64}, Int64, Vector{String}, Int64, Int64, CustomPrior, Float64}, DynamicPPL.DefaultContext}, 
DynamicPPL.Sampler{NUTS{Turing.Essential.ReverseDiffAD{true}, (), AdvancedHMC.DiagEuclideanMetric}}, DynamicPPL.DefaultContext}, #unused#::Val{true}, #unused#::Nothing)
    @ LogDensityProblems ~/.julia/packages/LogDensityProblems/b1j6d/src/AD_ReverseDiff.jl:35
 [23] #ADgradient#53
    @ ~/.julia/packages/LogDensityProblems/b1j6d/src/AD_ReverseDiff.jl:31 [inlined]
 [24] ADgradient
    @ ~/.julia/packages/Turing/szPqN/src/essential/ad.jl:116 [inlined]
 [25] ADgradient
    @ ~/.julia/packages/Turing/szPqN/src/essential/ad.jl:82 [inlined]
 [26] initialstep(rng::Random._GLOBAL_RNG, model::DynamicPPL.Model{TuringGLM.var"#normal_model_ranef#16"{Int64, Int64, CustomPrior, Vector{String}, Int64, Vector{Int64}}, (:y, :X, :predictors, :idxs, :n_gr, :intercept_ranef, :μ_X, :σ_X, :prior, :residual), (:predictors, :idxs, :n_gr, :intercept_ranef, :μ_X, :σ_X, :prior, :residual), (), Tuple{Vector{Int64}, Matrix{Float64}, Int64, Vector{Int64}, Int64, Vector{String}, Int64, Int64, CustomPrior, Float64}, Tuple{Int64, Vector{Int64}, Int64, Vector{String}, Int64, Int64, CustomPrior, Float64}, DynamicPPL.DefaultContext}, spl::DynamicPPL.Sampler{NUTS{Turing.Essential.ReverseDiffAD{true}, (), AdvancedHMC.DiagEuclideanMetric}}, vi::DynamicPPL.TypedVarInfo{NamedTuple{(:α, :β, :σ, :τ, :zⱼ), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:α, Setfield.IdentityLens}, Int64}, Vector{LocationScale{Float64, Continuous, TDist{Float64}}}, Vector{AbstractPPL.VarName{:α, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:β, Setfield.IdentityLens}, Int64}, Vector{Product{Continuous, TDist{Float64}, FillArrays.Fill{TDist{Float64}, 1, Tuple{Base.OneTo{Int64}}}}}, Vector{AbstractPPL.VarName{:β, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:σ, Setfield.IdentityLens}, Int64}, Vector{Exponential{Float64}}, Vector{AbstractPPL.VarName{:σ, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:τ, Setfield.IdentityLens}, Int64}, Vector{LocationScale{Float64, Continuous, Truncated{TDist{Float64}, Continuous, Float64}}}, Vector{AbstractPPL.VarName{:τ, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:zⱼ, Setfield.IdentityLens}, Int64}, Vector{DistributionsAD.TuringScalMvNormal{Vector{Float64}, Float64}}, Vector{AbstractPPL.VarName{:zⱼ, 
Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}; init_params::Nothing, nadapts::Int64, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Turing.Inference ~/.julia/packages/Turing/szPqN/src/inference/hmc.jl:162
 [27] step(rng::Random._GLOBAL_RNG, model::DynamicPPL.Model{TuringGLM.var"#normal_model_ranef#16"{Int64, Int64, CustomPrior, Vector{String}, Int64, Vector{Int64}}, (:y, :X, :predictors, :idxs, :n_gr, :intercept_ranef, :μ_X, :σ_X, :prior, :residual), (:predictors, :idxs, :n_gr, :intercept_ranef, :μ_X, :σ_X, :prior, :residual), (), Tuple{Vector{Int64}, Matrix{Float64}, Int64, Vector{Int64}, Int64, Vector{String}, Int64, Int64, CustomPrior, Float64}, Tuple{Int64, Vector{Int64}, Int64, Vector{String}, Int64, Int64, CustomPrior, Float64}, DynamicPPL.DefaultContext}, spl::DynamicPPL.Sampler{NUTS{Turing.Essential.ReverseDiffAD{true}, (), AdvancedHMC.DiagEuclideanMetric}}; resume_from::Nothing, init_params::Nothing, kwargs::Base.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:nadapts,), Tuple{Int64}}})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/1qg3U/src/sampler.jl:104
 [28] macro expansion
    @ ~/.julia/packages/AbstractMCMC/fnRmh/src/sample.jl:120 [inlined]
 [29] macro expansion
    @ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
 [30] (::AbstractMCMC.var"#21#22"{Bool, String, Nothing, Int64, Int64, Base.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:nadapts,), Tuple{Int64}}}, Random._GLOBAL_RNG, DynamicPPL.Model{TuringGLM.var"#normal_model_ranef#16"{Int64, Int64, CustomPrior, Vector{String}, Int64, Vector{Int64}}, (:y, :X, :predictors, :idxs, :n_gr, :intercept_ranef, :μ_X, :σ_X, :prior, :residual), (:predictors, :idxs, :n_gr, :intercept_ranef, :μ_X, :σ_X, :prior, :residual), (), Tuple{Vector{Int64}, Matrix{Float64}, Int64, Vector{Int64}, Int64, Vector{String}, Int64, Int64, CustomPrior, Float64}, Tuple{Int64, Vector{Int64}, Int64, Vector{String}, Int64, Int64, CustomPrior, Float64}, DynamicPPL.DefaultContext}, DynamicPPL.Sampler{NUTS{Turing.Essential.ReverseDiffAD{true}, (), AdvancedHMC.DiagEuclideanMetric}}, Int64, Int64})()
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/fnRmh/src/logging.jl:12
 [31] with_logstate(f::Function, logstate::Any)
    @ Base.CoreLogging ./logging.jl:511
 [32] with_logger(f::Function, logger::LoggingExtras.TeeLogger{Tuple{LoggingExtras.EarlyFilteredLogger{TerminalLoggers.TerminalLogger, AbstractMCMC.var"#1#3"{Module}}, LoggingExtras.EarlyFilteredLogger{Logging.ConsoleLogger, AbstractMCMC.var"#2#4"{Module}}}})
    @ Base.CoreLogging ./logging.jl:623
 [33] with_progresslogger(f::Function, _module::Module, logger::Logging.ConsoleLogger)
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/fnRmh/src/logging.jl:36
 [34] macro expansion
    @ ~/.julia/packages/AbstractMCMC/fnRmh/src/logging.jl:11 [inlined]
 [35] mcmcsample(rng::Random._GLOBAL_RNG, model::DynamicPPL.Model{TuringGLM.var"#normal_model_ranef#16"{Int64, Int64, CustomPrior, Vector{String}, Int64, Vector{Int64}}, (:y, :X, :predictors, :idxs, :n_gr, :intercept_ranef, :μ_X, :σ_X, :prior, :residual), (:predictors, :idxs, :n_gr, :intercept_ranef, :μ_X, :σ_X, :prior, :residual), (), Tuple{Vector{Int64}, Matrix{Float64}, Int64, Vector{Int64}, Int64, Vector{String}, Int64, Int64, CustomPrior, Float64}, Tuple{Int64, Vector{Int64}, Int64, Vector{String}, Int64, Int64, CustomPrior, Float64}, DynamicPPL.DefaultContext}, sampler::DynamicPPL.Sampler{NUTS{Turing.Essential.ReverseDiffAD{true}, (), AdvancedHMC.DiagEuclideanMetric}}, N::Int64; progress::Bool, progressname::String, callback::Nothing, discard_initial::Int64, thinning::Int64, chain_type::Type, kwargs::Base.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:nadapts,), Tuple{Int64}}})
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/fnRmh/src/sample.jl:111
 [36] sample(rng::Random._GLOBAL_RNG, model::DynamicPPL.Model{TuringGLM.var"#normal_model_ranef#16"{Int64, Int64, CustomPrior, Vector{String}, Int64, Vector{Int64}}, (:y, :X, :predictors, :idxs, :n_gr, :intercept_ranef, :μ_X, :σ_X, :prior, :residual), (:predictors, :idxs, :n_gr, :intercept_ranef, :μ_X, :σ_X, :prior, :residual), (), Tuple{Vector{Int64}, Matrix{Float64}, Int64, Vector{Int64}, Int64, Vector{String}, Int64, Int64, CustomPrior, Float64}, Tuple{Int64, Vector{Int64}, Int64, Vector{String}, Int64, Int64, CustomPrior, Float64}, DynamicPPL.DefaultContext}, sampler::DynamicPPL.Sampler{NUTS{Turing.Essential.ReverseDiffAD{true}, (), AdvancedHMC.DiagEuclideanMetric}}, N::Int64; chain_type::Type, resume_from::Nothing, progress::Bool, nadapts::Int64, discard_adapt::Bool, discard_initial::Int64, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Turing.Inference ~/.julia/packages/Turing/szPqN/src/inference/hmc.jl:133
 [37] sample
    @ ~/.julia/packages/Turing/szPqN/src/inference/hmc.jl:103 [inlined]
 [38] #sample#2
    @ ~/.julia/packages/Turing/szPqN/src/inference/Inference.jl:145 [inlined]
 [39] sample
    @ ~/.julia/packages/Turing/szPqN/src/inference/Inference.jl:138 [inlined]
 [40] #sample#1
    @ ~/.julia/packages/Turing/szPqN/src/inference/Inference.jl:135 [inlined]
 [41] sample(model::DynamicPPL.Model{TuringGLM.var"#normal_model_ranef#16"{Int64, Int64, CustomPrior, Vector{String}, Int64, Vector{Int64}}, (:y, :X, :predictors, :idxs, :n_gr, :intercept_ranef, :μ_X, :σ_X, :prior, :residual), (:predictors, :idxs, :n_gr, :intercept_ranef, :μ_X, :σ_X, :prior, :residual), (), Tuple{Vector{Int64}, Matrix{Float64}, Int64, Vector{Int64}, Int64, Vector{String}, Int64, Int64, CustomPrior, Float64}, Tuple{Int64, Vector{Int64}, Int64, Vector{String}, Int64, Int64, CustomPrior, Float64}, DynamicPPL.DefaultContext}, alg::NUTS{Turing.Essential.ReverseDiffAD{true}, (), AdvancedHMC.DiagEuclideanMetric}, N::Int64)
    @ Turing.Inference ~/.julia/packages/Turing/szPqN/src/inference/Inference.jl:129
 [42] top-level scope
    @ REPL[21]:1
storopoli commented 1 year ago

We need to somehow get rid of the indexing by idxs. This is what TuringGLM.jl creates for a hierarchical varying-intercept model:

using Turing, LinearAlgebra, Statistics
using StatsBase: mad

@model function normal_model_ranef(
    y,
    X;
    predictors=size(X, 2),
    idxs=idxs,
    n_gr=2,
)
    α ~ TDist(3)
    β ~ filldist(TDist(3), predictors)
    σ ~ Exponential(1 / std(y))
    μ = α .+ X * β
    τ ~ mad(y) * truncated(TDist(3), 0, Inf)
    zⱼ ~ filldist(Normal(), n_gr)
    αⱼ = zⱼ .* τ
    μ .+= αⱼ[idxs]
    y ~ MvNormal(μ, σ^2 * I)
    return (; α, β, σ, τ, zⱼ, αⱼ, y)
end
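To make the role of idxs concrete, here is a small Base-Julia sketch with made-up numbers (not the cheese data) of what αⱼ[idxs] computes in the model above: the n_gr group-level offsets zⱼ .* τ are expanded into one offset per observation, using each observation's group index. The values of idxs, zⱼ and τ are purely illustrative.

```julia
# Hypothetical data: 6 observations spread over n_gr = 3 groups.
idxs = [1, 1, 2, 3, 3, 3]   # group index of each observation
zⱼ = [0.5, -1.0, 2.0]       # standardized group effects, one per group
τ = 2.0                     # group-level scale

αⱼ = zⱼ .* τ                # group intercept offsets: [1.0, -2.0, 4.0]
offsets = αⱼ[idxs]          # one offset per observation, selected by group index

println(offsets)
```

Each observation simply picks up the offset of its own group, which is why the result has the length of y rather than of zⱼ.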

We need to get rid of μ .+= αⱼ[idxs]; this is the culprit that keeps ReverseDiff.jl from working as an AD backend.

@devmotion and @yebai any suggestions?

devmotion commented 1 year ago

It seems the problem is not idxs but the in-place update of μ. Computing μ in a single step could fix the issue, I assume. Thanks to broadcast fusion, it should also not cause additional allocations.
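ReverseDiff is not loaded here, so the following is only a hedged analogy for why the in-place update fails. A Base range such as 1.0:3.0 is, like ReverseDiff.TrackedArray, an array type that does not support setindex!. An in-place broadcast (.+=) lowers to Base.materialize!, which writes into the destination element by element (the same materialize!/copyto!/setindex! chain visible in the stack trace above), so it throws; the out-of-place form builds a new object and never writes into the original. The exact error type differs from the TrackedArray case.

```julia
# Analogy only: a Base range stands in for ReverseDiff.TrackedArray,
# since neither supports setindex!.
r = 1.0:3.0   # immutable StepRangeLen

# `r .+= 1.0` lowers to Base.materialize!(r, Base.broadcasted(+, r, 1.0)),
# which calls setindex! on r and therefore throws for an immutable array.
inplace_failed = try
    r .+= 1.0
    false
catch
    true
end

# The out-of-place broadcast returns a new object and leaves r untouched.
shifted = r .+ 1.0

println(inplace_failed)
println(collect(shifted))
```

This mirrors devmotion's suggestion: building μ in one out-of-place expression avoids writing into a tracked array altogether.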

devmotion commented 1 year ago

BTW Memoization was removed from Turing and loading it does not have any effect anymore.

devmotion commented 1 year ago

> The problem is not idxs but updating mu in-place, it seems? So just computing mu in one step could fix the issue, I assume. It should also not cause additional allocations it seems due to broadcast fusion.

Indeed, that seems to fix it. With

function _model(μ_X, σ_X, prior, intercept_ranef, idx, ::Type{Normal})
    idxs = first(idx)
    n_gr = length(unique(first(idx)))
    @model function normal_model_ranef(
        y,
        X;
        predictors=size(X, 2),
        idxs=idxs,
        n_gr=n_gr,
        intercept_ranef=intercept_ranef,
        μ_X=μ_X,
        σ_X=σ_X,
        prior=prior,
        residual=1 / std(y),
    )
        α ~ prior.intercept
        β ~ filldist(prior.predictors, predictors)
        σ ~ Exponential(residual)
        if isempty(intercept_ranef)
            μ = α .+ X * β
        else
            τ ~ mad(y) * truncated(TDist(3); lower=0)
            zⱼ ~ filldist(Normal(), n_gr)
            μ = α .+ zⱼ[idxs] .* τ .+ X * β
        end
        #TODO: implement random-effects slope
        y ~ MvNormal(μ, σ^2 * I)
        return nothing
    end
end

I get

julia> using ReverseDiff, TuringGLM, DataFrames, CSV

julia> cheese = CSV.read(download("https://github.com/TuringLang/TuringGLM.jl/raw/main/data/cheese.csv"), DataFrame);

julia> f = @formula(y ~ (1 | cheese) + background);

julia> m = turing_model(f, cheese);
The idx are Dict{String1, Int64}("B" => 2, "A" => 1, "C" => 3, "D" => 4)

julia> Turing.setadbackend(:reversediff)
:reversediff

julia> Turing.setrdcache(true)
true

julia> chn = sample(m, NUTS(), 10)
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)
└ @ AdvancedHMC /home/david/.julia/packages/AdvancedHMC/51xgc/src/hamiltonian.jl:47
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)
└ @ AdvancedHMC /home/david/.julia/packages/AdvancedHMC/51xgc/src/hamiltonian.jl:47
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)
└ @ AdvancedHMC /home/david/.julia/packages/AdvancedHMC/51xgc/src/hamiltonian.jl:47
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)
└ @ AdvancedHMC /home/david/.julia/packages/AdvancedHMC/51xgc/src/hamiltonian.jl:47
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)
└ @ AdvancedHMC /home/david/.julia/packages/AdvancedHMC/51xgc/src/hamiltonian.jl:47
┌ Info: Found initial step size
└   ϵ = 0.00078125
Sampling 100%|██████████████████████████████████████████████████████████████████████████████| Time: 0:00:00
Chains MCMC chain (10×20×1 Array{Float64, 3}):

Iterations        = 6:1:15
Number of chains  = 1
Samples per chain = 10
Wall duration     = 0.87 seconds
Compute duration  = 0.87 seconds
parameters        = α, β[1], σ, τ, zⱼ[1], zⱼ[2], zⱼ[3], zⱼ[4]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse       ess      rhat   ess_per_sec 
      Symbol   Float64   Float64    Float64   Float64   Float64   Float64       Float64 

           α    2.7325    0.7725     0.2443    0.4965    2.5849    1.4197        2.9711
        β[1]    0.2953    0.5236     0.1656    0.3035    1.7245    2.2898        1.9821
           σ   35.1595    1.4233     0.4501    0.7508    1.9840    1.7890        2.2804
           τ    2.3849    1.3548     0.4284    0.6482    1.9051    1.8808        2.1898
       zⱼ[1]    0.9952    0.7816     0.2472    0.4714    3.3547    1.2561        3.8560
       zⱼ[2]    0.9223    0.3719     0.1176    0.2105    2.1059    1.6947        2.4206
       zⱼ[3]    1.1382    0.1736     0.0549    0.0584    9.4017    1.0150       10.8066
       zⱼ[4]    1.0728    0.4561     0.1442    0.3015    2.5534    1.4949        2.9349

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

           α    1.2934    2.2874    3.1240    3.2205    3.3384
        β[1]   -0.5268   -0.0644    0.2689    0.7529    0.9727
           σ   33.2810   34.3821   34.9830   35.7114   37.7687
           τ    0.7510    1.4941    2.0956    3.2957    4.7483
       zⱼ[1]   -0.6246    0.8976    1.3934    1.4366    1.4876
       zⱼ[2]    0.3115    0.7134    0.8837    1.2528    1.3842
       zⱼ[3]    0.9338    1.0144    1.1191    1.2107    1.4069
       zⱼ[4]    0.2969    0.8210    1.2888    1.3944    1.4218
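As a sanity check that the rewrite changes only how μ is computed, not the model, here is a Base-Julia sketch with small made-up arrays (not the cheese data, and without Turing) comparing the original build-then-update-in-place μ with devmotion's single-expression μ.

```julia
# Illustrative values only.
α = 0.3
X = [1.0 0.0; 0.0 1.0; 1.0 1.0; 2.0 -1.0]   # 4 observations, 2 predictors
β = [0.5, -0.25]
zⱼ = [1.0, -1.0]                            # n_gr = 2 groups
τ = 2.0
idxs = [1, 2, 2, 1]

# Original formulation: build μ, then mutate it in place (uses setindex!).
μ_old = α .+ X * β
αⱼ = zⱼ .* τ
μ_old .+= αⱼ[idxs]

# Revised formulation: one out-of-place expression, no mutation.
μ_new = α .+ zⱼ[idxs] .* τ .+ X * β

println(μ_old ≈ μ_new)
```

The two vectors agree up to floating-point rounding, because only the order of the additions differs.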
storopoli commented 1 year ago

That's very clever. I had tried writing μ = α .+ αⱼ[idxs] .+ X * β inside the model, but I got an error.

sreedta8 commented 1 year ago

@devmotion @storopoli

How do I install the corrected version of TuringGLM that works with ReverseDiff?

storopoli commented 1 year ago

Once https://github.com/JuliaRegistries/General/pull/67851 is merged, you can update TuringGLM and your model will run with ReverseDiff.jl.