ClaudMor closed this issue 2 years ago.
I'd benchmark all approaches and then choose the fastest 😉 More seriously, it's really problem- and model-dependent.
Maybe another alternative would be to break up the priors (and `p`) into multiple groups with the same type of prior, since the problems arise from `priors` not being concretely typed.
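As a rough, Base-only illustration of why grouping helps: a vector holding several kinds of priors has an abstract element type, while each group has a single concrete one. `ToyPrior`, `ToyNormal`, `ToyUniform` and `toymean` are hypothetical stand-ins, not Distributions.jl types; in a real Turing model each concretely typed group would presumably get its own `~ arraydist(...)` statement.

```julia
# Hypothetical stand-ins for univariate priors (NOT Distributions.jl types)
abstract type ToyPrior end

struct ToyNormal <: ToyPrior
    μ::Float64
    σ::Float64
end

struct ToyUniform <: ToyPrior
    a::Float64
    b::Float64
end

toymean(d::ToyNormal) = d.μ
toymean(d::ToyUniform) = (d.a + d.b) / 2

# One vector holding mixed prior types: its element type is abstract,
# so the compiler cannot know which `toymean` method each element needs.
mixed = ToyPrior[ToyNormal(0.0, 1.0), ToyUniform(0.0, 2.0), ToyNormal(1.0, 1.0)]

# Split into groups that each have a single concrete element type.
# The typed comprehension fixes the element type (a plain `filter`
# on `mixed` would keep the abstract element type `ToyPrior`).
normals  = ToyNormal[d for d in mixed if d isa ToyNormal]
uniforms = ToyUniform[d for d in mixed if d isa ToyUniform]

println(eltype(mixed), " concrete? ", isconcretetype(eltype(mixed)))
println(eltype(normals), " concrete? ", isconcretetype(eltype(normals)))
```

Code that iterates over `normals` or `uniforms` can then be specialized on a single method per group.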
Alternatively, maybe you can add a type assertion on `p` such as `p::Vector{Float64}` (I guess Turing might not support this), or introduce an additional variable `_p::Vector{Float64} = p` that you then use in the ODE solver. BTW, if I'm not mistaken, your type annotations such as `::Array...` in the function body are basically equivalent to `convert` calls, but for actual type assertions (which error if not satisfied and are guarantees to the compiler) you have to annotate the assigned variable, i.e., something like `x::Vector{Float64} = ...`.

The annotation `::ODEProblem` is probably not helpful, since `ODEProblem` is not a concrete type (generally, the output of `remake` should be inferrable if its inputs, such as `p`, are inferred). However, IIRC you can just call `solve(prob1, Tsit5(); p=p, saveat=0.1)` (or with `p=_p`) without constructing `prob` explicitly. The Bayesian DiffEq tutorial also contains examples of how to set the sensitivity mode and AD system of the ODE solver (which is independent of the Turing settings).
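The difference between declaring a local variable with a concrete type and with an abstract one can be seen in a plain-Julia sketch. `with_concrete` and `with_abstract` are hypothetical helpers, and `Base.return_types` is used only as a quick way to ask the compiler what it can infer; its exact output may vary between Julia versions.

```julia
# Declaring the local with a concrete type: the assignment is lowered to
# `convert` followed by a type assertion, so everything downstream of `_p`
# is inferred as working on a Vector{Float64}, even if `p` itself is not inferred.
function with_concrete(p)
    _p::Vector{Float64} = p
    return sum(_p) / length(_p)
end

# Declaring it with an abstract type (analogous to `::ODEProblem`) only
# tells the compiler "some AbstractVector", which does not pin down the
# element type, so the result cannot be inferred concretely.
function with_abstract(p)
    _p::AbstractVector = p
    return sum(_p) / length(_p)
end

# Ask inference what it can prove when the argument type is unknown (Any).
println(Base.return_types(with_concrete, (Any,)))
println(Base.return_types(with_abstract, (Any,)))

# At runtime the concrete declaration converts an abstractly typed input.
println(with_concrete(Real[1, 2, 3]))   # prints 2.0
```

This mirrors the `_p::Vector{Float64} = p` suggestion: `p` may stay `::Any`, but code that only touches `_p` is type stable.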
Hello @devmotion,
Thanks for your suggestions. I tried to implement them. Now the model looks like:

```julia
using Turing, DifferentialEquations

# define turing model
@model function fitlv2(data::Array{Float64,2}, prob1, priors::Array{Distribution{Univariate,Continuous},1})
    σ ~ InverseGamma(2, 3)
    p ~ arraydist(priors)
    _p::Vector{Float64} = p
    predicted::Vector{Vector{Float64}} = solve(prob1, Tsit5(); p = _p, saveat = 0.1).u
    for i in 1:length(predicted)
        data[:, i] ~ MvNormal(predicted[i], σ)
    end
end
```
And if I `@code_warntype` it, I get:

```julia
using Random

model2 = fitlv2(odedata, fast_prob1, priors)
# inspect type inference
@code_warntype model2.f(
    Random.GLOBAL_RNG,
    model2,
    Turing.VarInfo(model2),
    Turing.SampleFromPrior(),
    Turing.DefaultContext(),
    model2.args...,
)
```
```diff
Variables
#self#::Core.Compiler.Const(var"#29#30"(), false)
_rng::Core.Compiler.Const(Random._GLOBAL_RNG(), false)
_model::DynamicPPL.Model{var"#29#30",(:data, :prob1, :priors),(),(),Tuple{Array{Float64,2},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,ModelingToolkit.var"#f#198"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##MTKArg#267"), Symbol("##MTKArg#268"), Symbol("##MTKArg#269")),ModelingToolkit.var"#_RGF_ModTag",(0x675c4d77, 0x05b2dfd8, 0x3d03ded5, 0x73dfb8d9, 0x864d24da)},RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##MTIIPVar#271"), Symbol("##MTKArg#267"), Symbol("##MTKArg#268"), Symbol("##MTKArg#269")),ModelingToolkit.var"#_RGF_ModTag",(0xda0ef279, 0xca85e7cf, 0xcf7fc331, 0x44a7018d, 0x67545bf7)}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Array{Symbol,1},Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Array{Distribution{Univariate,Continuous},1}},Tuple{}}
_varinfo::DynamicPPL.VarInfo{NamedTuple{(:σ, :p),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:σ,Tuple{}},Int64},Array{InverseGamma{Float64},1},Array{DynamicPPL.VarName{:σ,Tuple{}},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}},DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:p,Tuple{}},Int64},Array{Product{Continuous,Distribution{Univariate,Continuous},Array{Distribution{Univariate,Continuous},1}},1},Array{DynamicPPL.VarName{:p,Tuple{}},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}}},Float64}
_sampler::Core.Compiler.Const(DynamicPPL.SampleFromPrior(), false)
_context::Core.Compiler.Const(DynamicPPL.DefaultContext(), false)
data::Array{Float64,2}
prob1::ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,ModelingToolkit.var"#f#198"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##MTKArg#267"), Symbol("##MTKArg#268"), Symbol("##MTKArg#269")),ModelingToolkit.var"#_RGF_ModTag",(0x675c4d77, 0x05b2dfd8, 0x3d03ded5, 0x73dfb8d9, 0x864d24da)},RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##MTIIPVar#271"), Symbol("##MTKArg#267"), Symbol("##MTKArg#268"), Symbol("##MTKArg#269")),ModelingToolkit.var"#_RGF_ModTag",(0xda0ef279, 0xca85e7cf, 0xcf7fc331, 0x44a7018d, 0x67545bf7)}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Array{Symbol,1},Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem}
priors::Array{Distribution{Univariate,Continuous},1}
tmpright#485::InverseGamma{Float64}
vn#487::DynamicPPL.VarName{:σ,Tuple{}}
inds#488::Tuple{}
σ::Float64
tmpright#489::Product{Continuous,Distribution{Univariate,Continuous},Array{Distribution{Univariate,Continuous},1}}
vn#491::DynamicPPL.VarName{:p,Tuple{}}
inds#492::Tuple{}
- p::Any
_p::Array{Float64,1}
predicted::Array{Array{Float64,1},1}
! @_20::Union{Nothing, Tuple{Int64,Int64}}
@_21::TypeVar
@_22::TypeVar
i::Int64
tmpright#493::MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}}
vn#495::DynamicPPL.VarName{:data,Tuple{Tuple{Colon,Int64}}}
inds#496::Tuple{Tuple{Colon,Int64}}
isassumption#497::Bool
@_28::TypeVar
vn#498::DynamicPPL.VarName{:data,Tuple{Tuple{Colon,Int64}}}
@_30::Bool
@_31::Bool
Body::Nothing
1 ── Core.NewvarNode(:(vn#487))
│ Core.NewvarNode(:(inds#488))
│ Core.NewvarNode(:(σ))
│ Core.NewvarNode(:(tmpright#489))
│ Core.NewvarNode(:(vn#491))
│ Core.NewvarNode(:(inds#492))
│ Core.NewvarNode(:(p))
│ Core.NewvarNode(:(_p))
│ Core.NewvarNode(:(predicted))
│ Core.NewvarNode(:(@_20))
│ (tmpright#485 = Main.InverseGamma(2, 3))
│ %12 = tmpright#485::Core.Compiler.Const(InverseGamma{Float64}(
invd: Gamma{Float64}(α=2.0, θ=0.3333333333333333)
θ: 3.0
)
, false)::Core.Compiler.Const(InverseGamma{Float64}(
invd: Gamma{Float64}(α=2.0, θ=0.3333333333333333)
θ: 3.0
)
, false)
│ (@_21 = Core.TypeVar(Symbol("#s196"), Distribution))
│ %14 = @_21::Core.Compiler.PartialTypeVar(var"#s196"<:Distribution, true, true)::Core.Compiler.PartialTypeVar(var"#s196"<:Distribution, true, true)
-│ %15 = Core.apply_type(Main.AbstractVector, @_21::Core.Compiler.PartialTypeVar(var"#s196"<:Distribution, true, true))::Type{AbstractArray{var"#s196"<:Distribution,1}}
│ %16 = Core.UnionAll(%14, %15)::Type{AbstractArray{var"#s196",1} where var"#s196"<:Distribution}
│ %17 = Core.apply_type(Main.Union, Distribution, %16)::Type{Union{AbstractArray{var"#s196",1} where var"#s196"<:Distribution, Distribution}}
│ %18 = (%12 isa %17)::Core.Compiler.Const(true, false)
│ %18
└─── goto #3
2 ── Core.Compiler.Const(:(Main.ArgumentError("Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions.")), false)
└─── Core.Compiler.Const(:(Main.throw(%21)), false)
3 ┄─ (vn#487 = σ)
│ (inds#488 = ())
│ (σ = (DynamicPPL.tilde_assume)(_rng, _context, _sampler, tmpright#485::Core.Compiler.Const(InverseGamma{Float64}(
invd: Gamma{Float64}(α=2.0, θ=0.3333333333333333)
θ: 3.0
)
, false), vn#487, inds#488, _varinfo))
│ (tmpright#489 = Main.arraydist(priors))
│ %27 = tmpright#489::Product{Continuous,Distribution{Univariate,Continuous},Array{Distribution{Univariate,Continuous},1}}
│ (@_22 = Core.TypeVar(Symbol("#s195"), Distribution))
│ %29 = @_22::Core.Compiler.PartialTypeVar(var"#s195"<:Distribution, true, true)::Core.Compiler.PartialTypeVar(var"#s195"<:Distribution, true, true)
-│ %30 = Core.apply_type(Main.AbstractVector, @_22::Core.Compiler.PartialTypeVar(var"#s195"<:Distribution, true, true))::Type{AbstractArray{var"#s195"<:Distribution,1}}
│ %31 = Core.UnionAll(%29, %30)::Type{AbstractArray{var"#s196",1} where var"#s196"<:Distribution}
│ %32 = Core.apply_type(Main.Union, Distribution, %31)::Type{Union{AbstractArray{var"#s196",1} where var"#s196"<:Distribution, Distribution}}
│ %33 = (%27 isa %32)::Core.Compiler.Const(true, false)
│ %33
└─── goto #5
4 ── Core.Compiler.Const(:(Main.ArgumentError("Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions.")), false)
└─── Core.Compiler.Const(:(Main.throw(%36)), false)
5 ┄─ (vn#491 = p)
│ (inds#492 = ())
│ (p = (DynamicPPL.tilde_assume)(_rng, _context, _sampler, tmpright#489, vn#491, inds#492, _varinfo))
-│ %41 = p::Any
│ %42 = Core.apply_type(Main.Vector, Main.Float64)::Core.Compiler.Const(Array{Float64,1}, false)
-│ %43 = Base.convert(%42, %41)::Any
│ (_p = Core.typeassert(%43, %42))
│ %45 = Main.Tsit5()::Core.Compiler.Const(Tsit5(), false)
│ %46 = (:p, :saveat)::Core.Compiler.Const((:p, :saveat), false)
│ %47 = Core.apply_type(Core.NamedTuple, %46)::Core.Compiler.Const(NamedTuple{(:p, :saveat),T} where T<:Tuple, false)
│ %48 = Core.tuple(_p, 0.1)::Core.Compiler.PartialStruct(Tuple{Array{Float64,1},Float64}, Any[Array{Float64,1}, Core.Compiler.Const(0.1, false)])
│ %49 = (%47)(%48)::Core.Compiler.PartialStruct(NamedTuple{(:p, :saveat),Tuple{Array{Float64,1},Float64}}, Any[Array{Float64,1}, Core.Compiler.Const(0.1, false)])
│ %50 = Core.kwfunc(Main.solve)::Core.Compiler.Const(DiffEqBase.var"#solve##kw"(), false)
│ %51 = (%50)(%49, Main.solve, prob1, %45)::ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,ModelingToolkit.var"#f#198"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##MTKArg#267"), Symbol("##MTKArg#268"), Symbol("##MTKArg#269")),ModelingToolkit.var"#_RGF_ModTag",(0x675c4d77, 0x05b2dfd8, 0x3d03ded5, 0x73dfb8d9, 0x864d24da)},RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##MTIIPVar#271"), Symbol("##MTKArg#267"), Symbol("##MTKArg#268"), Symbol("##MTKArg#269")),ModelingToolkit.var"#_RGF_ModTag",(0xda0ef279, 0xca85e7cf, 0xcf7fc331, 0x44a7018d, 0x67545bf7)}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Array{Symbol,1},Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,ModelingToolkit.var"#f#198"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##MTKArg#267"), Symbol("##MTKArg#268"), Symbol("##MTKArg#269")),ModelingToolkit.var"#_RGF_ModTag",(0x675c4d77, 0x05b2dfd8, 0x3d03ded5, 0x73dfb8d9, 0x864d24da)},RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##MTIIPVar#271"), Symbol("##MTKArg#267"), Symbol("##MTKArg#268"), Symbol("##MTKArg#269")),ModelingToolkit.var"#_RGF_ModTag",(0xda0ef279, 0xca85e7cf, 0xcf7fc331, 0x44a7018d, 0x67545bf7)}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Array{Symbol,1},Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats}
│ %52 = Base.getproperty(%51, :u)::Array{Array{Float64,1},1}
│ %53 = Core.apply_type(Main.Array, Main.Float64, 1)::Core.Compiler.Const(Array{Float64,1}, false)
│ %54 = Core.apply_type(Main.Array, %53, 1)::Core.Compiler.Const(Array{Array{Float64,1},1}, false)
│ %55 = Base.convert(%54, %52)::Array{Array{Float64,1},1}
│ (predicted = Core.typeassert(%55, %54))
│ %57 = Main.length(predicted)::Int64
│ %58 = (1:%57)::Core.Compiler.PartialStruct(UnitRange{Int64}, Any[Core.Compiler.Const(1, false), Int64])
│ (@_20 = Base.iterate(%58))
│ %60 = (@_20 === nothing)::Bool
│ %61 = Base.not_int(%60)::Bool
└─── goto #17 if not %61
6 ┄─ Core.NewvarNode(:(vn#495))
│ Core.NewvarNode(:(inds#496))
│ Core.NewvarNode(:(isassumption#497))
│ %66 = @_20::Tuple{Int64,Int64}::Tuple{Int64,Int64}
│ (i = Core.getfield(%66, 1))
│ %68 = Core.getfield(%66, 2)::Int64
│ %69 = Base.getindex(predicted, i)::Array{Float64,1}
│ (tmpright#493 = Main.MvNormal(%69, σ))
│ %71 = tmpright#493::MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}}
│ (@_28 = Core.TypeVar(Symbol("#s193"), Distribution))
│ %73 = @_28::TypeVar
-│ %74 = Core.apply_type(Main.AbstractVector, @_28)::Type{AbstractArray{_A,1}} where _A
-│ %75 = Core.UnionAll(%73, %74)::Any
-│ %76 = Core.apply_type(Main.Union, Distribution, %75)::Type
│ %77 = (%71 isa %76)::Bool
└─── goto #8 if not %77
7 ── goto #9
8 ── %80 = Main.ArgumentError("Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions.")::ArgumentError
└─── Main.throw(%80)
9 ┄─ %82 = Core.tuple(Main.:(:), i)::Core.Compiler.PartialStruct(Tuple{Colon,Int64}, Any[Core.Compiler.Const(Colon(), false), Int64])
│ %83 = Core.tuple(%82)::Core.Compiler.PartialStruct(Tuple{Tuple{Colon,Int64}}, Any[Core.Compiler.PartialStruct(Tuple{Colon,Int64}, Any[Core.Compiler.Const(Colon(), false), Int64])])
│ (vn#495 = (DynamicPPL.VarName)(:data, %83))
│ %85 = Core.tuple(Main.:(:), i)::Core.Compiler.PartialStruct(Tuple{Colon,Int64}, Any[Core.Compiler.Const(Colon(), false), Int64])
│ (inds#496 = Core.tuple(%85))
│ %87 = Core.tuple(Main.:(:), i)::Core.Compiler.PartialStruct(Tuple{Colon,Int64}, Any[Core.Compiler.Const(Colon(), false), Int64])
│ %88 = Core.tuple(%87)::Core.Compiler.PartialStruct(Tuple{Tuple{Colon,Int64}}, Any[Core.Compiler.PartialStruct(Tuple{Colon,Int64}, Any[Core.Compiler.Const(Colon(), false), Int64])])
│ (vn#498 = (DynamicPPL.VarName)(:data, %88))
│ %90 = (DynamicPPL.inargnames)(vn#498::Core.Compiler.PartialStruct(DynamicPPL.VarName{:data,Tuple{Tuple{Colon,Int64}}}, Any[Core.Compiler.PartialStruct(Tuple{Tuple{Colon,Int64}}, Any[Core.Compiler.PartialStruct(Tuple{Colon,Int64}, Any[Core.Compiler.Const(Colon(), false), Int64])])]), _model)::Core.Compiler.Const(true, false)
│ %91 = !%90::Core.Compiler.Const(false, true)
└─── goto #11 if not %91
10 ─ Core.Compiler.Const(:(@_30 = %91), false)
└─── Core.Compiler.Const(:(goto %96), false)
11 ┄ (@_30 = (DynamicPPL.inmissings)(vn#498::Core.Compiler.PartialStruct(DynamicPPL.VarName{:data,Tuple{Tuple{Colon,Int64}}}, Any[Core.Compiler.PartialStruct(Tuple{Tuple{Colon,Int64}}, Any[Core.Compiler.PartialStruct(Tuple{Colon,Int64}, Any[Core.Compiler.Const(Colon(), false), Int64])])]), _model))
└─── goto #13 if not @_30::Core.Compiler.Const(false, false)
12 ─ Core.Compiler.Const(:(@_31 = true), false)
└─── Core.Compiler.Const(:(goto %101), false)
13 ┄ %99 = Base.getindex(data, Main.:(:), i)::Array{Float64,1}
│ (@_31 = %99 === Main.missing)
│ (isassumption#497 = @_31::Core.Compiler.Const(false, false))
└─── goto #15 if not isassumption#497::Core.Compiler.Const(false, false)
14 ─ Core.Compiler.Const(:((DynamicPPL.tilde_assume)(_rng, _context, _sampler, tmpright#493, vn#495, inds#496, _varinfo)), false)
│ Core.Compiler.Const(:(Base.setindex!(data, %103, Main.:(:), i)), false)
└─── Core.Compiler.Const(:(goto %111), false)
15 ┄ %106 = tmpright#493::MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}}
│ %107 = Base.getindex(data, Main.:(:), i)::Array{Float64,1}
│ %108 = vn#495::Core.Compiler.PartialStruct(DynamicPPL.VarName{:data,Tuple{Tuple{Colon,Int64}}}, Any[Core.Compiler.PartialStruct(Tuple{Tuple{Colon,Int64}}, Any[Core.Compiler.PartialStruct(Tuple{Colon,Int64}, Any[Core.Compiler.Const(Colon(), false), Int64])])])::Core.Compiler.PartialStruct(DynamicPPL.VarName{:data,Tuple{Tuple{Colon,Int64}}}, Any[Core.Compiler.PartialStruct(Tuple{Tuple{Colon,Int64}}, Any[Core.Compiler.PartialStruct(Tuple{Colon,Int64}, Any[Core.Compiler.Const(Colon(), false), Int64])])])
│ %109 = inds#496::Core.Compiler.PartialStruct(Tuple{Tuple{Colon,Int64}}, Any[Core.Compiler.PartialStruct(Tuple{Colon,Int64}, Any[Core.Compiler.Const(Colon(), false), Int64])])::Core.Compiler.PartialStruct(Tuple{Tuple{Colon,Int64}}, Any[Core.Compiler.PartialStruct(Tuple{Colon,Int64}, Any[Core.Compiler.Const(Colon(), false), Int64])])
│ (DynamicPPL.tilde_observe)(_context, _sampler, %106, %107, %108, %109, _varinfo)
│ (@_20 = Base.iterate(%58, %68))
│ %112 = (@_20 === nothing)::Bool
│ %113 = Base.not_int(%112)::Bool
└─── goto #17 if not %113
16 ─ goto #6
17 ┄ return
```
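Besides reading `@code_warntype` output like the above, one programmatic check is `Test.@inferred` from Julia's standard `Test` library: it throws if the runtime return type differs from the type inference predicts. A minimal sketch with two hypothetical functions, `unstable` and `stable` (not part of Turing or DiffEq):

```julia
using Test

# Type-unstable: returns the Float64 argument or the Int literal 0 depending
# on the runtime value, so inference can only give Union{Float64, Int}.
unstable(x::Float64) = x > 0 ? x : 0

# Type-stable: `zero(x)` has the same type as `x`.
stable(x::Float64) = x > 0 ? x : zero(x)

# Passes and returns the call's value.
println(@inferred stable(1.5))

# For the unstable version @inferred raises an ErrorException.
caught = try
    @inferred unstable(1.5)
    false
catch err
    err isa ErrorException
end
println("unstable flagged: ", caught)
```

Whether the same check is convenient on a DynamicPPL model function is not shown here; it is simply a general Julia tool for the same question that `@code_warntype` answers visually.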
I checked that I get the same result even if I use a concrete type for the `priors` argument (therefore modifying the `priors` array to be, for example, an array of truncated normals).
Do you think it is OK now?
Hello,
Following the performance tips, I'm trying to let the compiler know all types in my Turing model in advance.
However, I'm facing problems getting it to infer the type of arrays (here `p`) that appear inside the model in the form:
What follows is a more explicit example inspired by the tutorials, modified to reflect the structure of the more complex model I'm trying to build.
Then if I inspect the type inference, I get:
(I did my best to color the text as `@code_warntype` would, with the caveat that where a line is colored, `@code_warntype` would have colored only the type after the `::`. I didn't color all lines, so as not to ruin the indentation, since I'm using this trick to color the text; I just colored the first ones.)
So how may I fix the type inference?
I know that I could be using a `for` loop instead of `arraydist`, as seen here, but the model we are developing is a bit complicated (~60 parameters), so it may benefit from switching to the `TrackerAD` or `Zygote` backend, which don't like loops.
Side question: in case we had to choose between `TrackerAD`, `Zygote`, and `ReverseDiff`, which would you suggest?
Thank you very much for your attention.