Closed bgroenks96 closed 3 years ago
This is unfortunately a result of our documentation not being completely up-to-date (we're currently working on improving this, e.g. https://github.com/TuringLang/Turing.jl/issues/1474 and https://github.com/TuringLang/TuringTutorials/issues/86).
But there is a predict method which probably does exactly what you want: https://github.com/TuringLang/Turing.jl/blob/cb58871e321ef058d86b51adfe602a562bd690f4/src/inference/Inference.jl#L477-L546
And for future reference, there's also generated_quantities for cases where you want to look at the predictive posterior for quantities that are not directly sampled.
Again, this should really be made explicit in the documentation, and it will be once we've updated it; see the issues referenced above.
Ok! I think that pretty much solves this issue then. Thanks!
Quick note for posterity: if you have multi-dimensional y with missing values, you will get an output from predict that is difficult to transform back into the expected shape of your output (e.g. n x m for n data points with m dimensions).
You can solve this by using include_all=true, grouping, and reshaping:
pred_chain = predict(model(X,Matrix{Union{Missing,Float64}}(undef, size(X)...)), chain, include_all=true) |>
p -> group(p, :y)
preds = reshape(Array(pred_chain), (500,n,m))
# output: S x n x m, where S is the number of chain samples
@torfjelde predict seems to not work when using models that call external functions to generate predictions, e.g. with Bayesian differential equations. It just returns a chain object with the same samples repeated over and over, presumably because it reuses the last call to solve. I can try to prepare an example for you, but it should be reproducible from the Bayesian diffeq example in the Turing docs. Just try calling predict using a missing data model.
Should I create a new issue for this?
Is this on the most recent version of Turing? It could potentially be related to https://github.com/TuringLang/Turing.jl/issues/1464, which was recently fixed.
But if it still persists even in the most recent version, if you could open a separate issue, that would be awesome!
EDIT: No need to read the issue I referenced btw. It's more that the fix we did in https://github.com/TuringLang/DynamicPPL.jl/pull/191 might have also fixed this issue. Essentially, in certain cases we would do copy-by-reference rather than copy-by-value, and so after running the model once, it would fill the missing array with actual numbers, and then the second time we called the model in predict, it would no longer be missing.
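For anyone reading later, here is a minimal sketch of that copy-by-reference pitfall in plain Julia. It does not use Turing or DynamicPPL; `fill_missing!` is a hypothetical stand-in for "running the model", invented purely for illustration.

```julia
# Hypothetical stand-in for a model run: every `missing` entry is "sampled" in place.
function fill_missing!(y::AbstractVector{Union{Missing,Float64}})
    for i in eachindex(y)
        if ismissing(y[i])
            y[i] = randn()
        end
    end
    return y
end

# Copy-by-reference: both runs share one array, so the second run finds nothing to sample.
shared = Vector{Union{Missing,Float64}}(missing, 3)
fill_missing!(shared)
first_run = copy(shared)
fill_missing!(shared)                 # no-op: every entry was already filled
reused = shared == first_run          # the same values come back

# Copy-by-value: each run gets a fresh copy, so the template stays missing.
template = Vector{Union{Missing,Float64}}(missing, 3)
run1 = fill_missing!(copy(template))
run2 = fill_missing!(copy(template))
still_missing = all(ismissing, template)
```

Here `reused` and `still_missing` are both `true`, while `run1` and `run2` hold different draws. The first pattern matches the "same samples repeated over and over" symptom described above.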
@torfjelde I'm pretty sure I found this issue while my Turing version was downgraded (0.5), although I wasn't aware of that.
I am rebuilding my sysimage at the moment, and I will check after it's done!
@torfjelde Ok, so the problem I described above with values being reused does not appear to exist in 0.15.4. That's the good news!
However, the results don't really make sense.
If I compute the posterior predictive myself (excluding observation noise) by simply running the diffeq model with each parameter setting from the posterior, I get exactly what I would expect. An ensemble of model runs.
If I use predict, the result looks more like a typical sample transition chain...
Could you clarify what exactly predict is doing? It calls transitions_from_chain, right? Should this return the numerical model outputs at each transition in the case of a Bayesian differential equation model?
I'll try to set up the Lotka-Volterra example from the docs and see if I can reproduce it there.
> If I use predict, the result looks more like a typical sample transition chain...
Is this not what you want? I'm a bit confused I guess. You get back a Chains when calling predict, right? But you really just want an array of model runs?
> I'll try to set up the Lotka-Volterra example from the docs and see if I can reproduce it there.
That would be awesome! But pseudo-code would also be useful as a starter, as I feel like I'm possibly misunderstanding your intention here.
@torfjelde
Yeah sorry, I didn't explain that well.... the Chains object is fine, that's what I expect. I just mean that when you plot the results it doesn't look like the samples came from the physical model. It looks like they are just random samples.
Here's an example for an SEIR COVID-19 model that I whipped up recently based on a tutorial for DifferentialEquations.jl:
@model function seir(x0, y; datavars=[1], tspan=(0.0,365.0))
    nvars = length(datavars)
    σ_inv ~ truncated(Normal(5.2,0.5),1.0,Inf)
    γ_inv ~ truncated(Normal(18.0,3.0),1.0,Inf)
    R₀_bar ~ truncated(Normal(2.0,0.5),0.1,Inf)
    δ ~ Beta(1,100)
    η ~ Beta(1,10)
    ν ~ filldist(InverseGamma(2,3),nvars) # noise variance
    params = [1.0/σ_inv, 1.0/γ_inv, R₀_bar, η, δ]
    prob = ODEProblem(F, x0, tspan, params)
    pred = solve(prob, Tsit5(), saveat=1.0)
    for j = 1:nvars
        for i = 1:size(y,1)
            y[i,j] ~ Normal(pred[datavars[j],i], ν[j])
        end
    end
    return y
end
I can build an approximate posterior predictive by simply iterating over the posterior samples and running the model on each one:
chain_df = DataFrame(chain)
invert(x) = 1.0 ./ x
params = select(chain_df, :σ_inv => invert => :σ, :γ_inv => invert => :γ, :R₀_bar, :η, :δ)
results = []
for p in eachrow(params)
    prob = ODEProblem(F, x_0, (0.0,size(data,1)), p)
    push!(results, solve(prob, Tsit5(), saveat=1.0))
end
That produces this plot:
plot(cases_mid, label=nothing, xlabel="Days since July 1st", ylabel="Number of cases")
plot!(cases_upper, label=nothing, c="transparent", fill=cases_lower, fillcolor="orange", fillalpha=0.4)
...which looks reasonable. If I use predict:
model_test = seir(x_0, Matrix{Union{Float64,Missing}}(missing,size(data_normalized)...); datavars=[6,7], tspan=(0.0,size(data,1)))
pred_chain = predict(model_test, chain)
preds = reshape(Array(pred_chain), (size(chain,1),size(data_normalized)...))
pred_cases = preds[:,:,1]
pred_deaths = preds[:,:,2]
cases_lower = mapslices(x -> quantile(x, 0.025), pred_cases; dims=1)[1,:]*N
cases_upper = mapslices(x -> quantile(x, 0.975), pred_cases; dims=1)[1,:]*N
cases_mid = mapslices(x -> quantile(x, 0.5), pred_cases; dims=1)[1,:]*N
deaths_lower = mapslices(x -> quantile(x, 0.025), pred_deaths; dims=1)[1,:]*N
deaths_upper = mapslices(x -> quantile(x, 0.975), pred_deaths; dims=1)[1,:]*N
deaths_mid = mapslices(x -> quantile(x, 0.5), pred_deaths; dims=1)[1,:]*N;
then I get:
which looks more like Turing just sampled from the prior/likelihood and didn't run the SEIR model at all...
Maybe I'm just doing something wrong? This is just an example I had on hand. If it's too opaque, let me know and I can go grab the Lotka-Volterra one from the docs.
Just to show that it's not my post-processing code, but actually the values returned by predict, here's the output of pred_chain:
Hmm, yeah this is weird. We've run into this issue before but then we fixed it in DynamicPPL, so I'm somewhat confused why it's now back. It might just be that we're not lower-bounding the correct version of DPPL or something. I'll have a look at this again after dinner.
And thank you for the very thorough replies/troubleshooting! Really helpful:)
So it turns out it's not a "bug" in Turing, per se. The issue is the way you reconstruct the trajectories from the predicted Chains.
I used the Lotka-Volterra Model example from the DiffEq tutorial (https://github.com/TuringLang/TuringTutorials/blob/master/10_diffeq.ipynb), and "reproduced" the issue. So I have this chain:
Chains MCMC chain (1000×17×3 Array{Float64,3}):
Iterations = 1:1000
Thinning interval = 1
Chains = 1, 2, 3
Samples per chain = 1000
parameters = α, β, γ, δ, σ
internals = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth
Summary Statistics
parameters mean std naive_se mcse ess rhat
Symbol Float64 Float64 Float64 Float64 Float64 Float64
α 1.7502 0.2828 0.0052 0.0504 6.4874 3.8125
β 1.3580 0.3780 0.0069 0.0689 6.2154 5.6250
γ 3.1057 0.3732 0.0068 0.0608 7.7069 2.2009
δ 1.1922 0.3713 0.0068 0.0671 6.3035 4.7589
σ 1.2600 0.6357 0.0116 0.1173 6.0817 9.5661
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
α 1.4601 1.5396 1.5954 2.0444 2.3049
β 1.0015 1.0752 1.1334 1.8323 1.9901
γ 2.6141 2.8291 2.9727 3.4039 3.8948
δ 0.8454 0.9189 0.9709 1.6097 1.9036
σ 0.7423 0.7989 0.8389 2.0741 2.3172
and this chain returned from predict:
Chains MCMC chain (1000×202×3 Array{Float64,3}):
Iterations = 1:1000
Thinning interval = 1
Chains = 1, 2, 3
Samples per chain = 1000
parameters = data[1,1], data[1,2], data[1,3], data[1,4], data[1,5], data[1,6], data[1,7], data[1,8], data[1,9], data[1,10], data[1,11], data[1,12], data[1,13], data[1,14], data[1,15], data[1,16], data[1,17], data[1,18], data[1,19], data[1,20], data[1,21], data[1,22], data[1,23], data[1,24], data[1,25], data[1,26], data[1,27], data[1,28], data[1,29], data[1,30], data[1,31], data[1,32], data[1,33], data[1,34], data[1,35], data[1,36], data[1,37], data[1,38], data[1,39], data[1,40], data[1,41], data[1,42], data[1,43], data[1,44], data[1,45], data[1,46], data[1,47], data[1,48], data[1,49], data[1,50], data[1,51], data[1,52], data[1,53], data[1,54], data[1,55], data[1,56], data[1,57], data[1,58], data[1,59], data[1,60], data[1,61], data[1,62], data[1,63], data[1,64], data[1,65], data[1,66], data[1,67], data[1,68], data[1,69], data[1,70], data[1,71], data[1,72], data[1,73], data[1,74], data[1,75], data[1,76], data[1,77], data[1,78], data[1,79], data[1,80], data[1,81], data[1,82], data[1,83], data[1,84], data[1,85], data[1,86], data[1,87], data[1,88], data[1,89], data[1,90], data[1,91], data[1,92], data[1,93], data[1,94], data[1,95], data[1,96], data[1,97], data[1,98], data[1,99], data[1,100], data[1,101], data[2,1], data[2,2], data[2,3], data[2,4], data[2,5], data[2,6], data[2,7], data[2,8], data[2,9], data[2,10], data[2,11], data[2,12], data[2,13], data[2,14], data[2,15], data[2,16], data[2,17], data[2,18], data[2,19], data[2,20], data[2,21], data[2,22], data[2,23], data[2,24], data[2,25], data[2,26], data[2,27], data[2,28], data[2,29], data[2,30], data[2,31], data[2,32], data[2,33], data[2,34], data[2,35], data[2,36], data[2,37], data[2,38], data[2,39], data[2,40], data[2,41], data[2,42], data[2,43], data[2,44], data[2,45], data[2,46], data[2,47], data[2,48], data[2,49], data[2,50], data[2,51], data[2,52], data[2,53], data[2,54], data[2,55], data[2,56], data[2,57], data[2,58], data[2,59], data[2,60], data[2,61], data[2,62], data[2,63], data[2,64], data[2,65], 
data[2,66], data[2,67], data[2,68], data[2,69], data[2,70], data[2,71], data[2,72], data[2,73], data[2,74], data[2,75], data[2,76], data[2,77], data[2,78], data[2,79], data[2,80], data[2,81], data[2,82], data[2,83], data[2,84], data[2,85], data[2,86], data[2,87], data[2,88], data[2,89], data[2,90], data[2,91], data[2,92], data[2,93], data[2,94], data[2,95], data[2,96], data[2,97], data[2,98], data[2,99], data[2,100], data[2,101]
internals =
Summary Statistics
parameters mean std naive_se mcse ess rhat
Symbol Float64 Float64 Float64 Float64 Float64 Float64
data[1,1] 0.9618 1.3837 0.0253 0.0232 2979.2866 1.0001
data[1,2] 1.0479 1.4190 0.0259 0.0206 2572.4246 0.9993
data[1,3] 1.1143 1.4031 0.0256 0.0268 2859.7547 0.9994
data[1,4] 1.2308 1.4245 0.0260 0.0216 3250.5532 0.9996
data[1,5] 1.3571 1.4057 0.0257 0.0320 2765.7726 1.0014
data[1,6] 1.5369 1.3856 0.0253 0.0243 2896.6665 0.9996
data[1,7] 1.6964 1.4063 0.0257 0.0266 2958.3855 0.9991
data[1,8] 1.9617 1.4418 0.0263 0.0300 2895.5749 1.0032
data[1,9] 2.1746 1.4417 0.0263 0.0261 3093.7371 0.9995
data[1,10] 2.5078 1.4531 0.0265 0.0256 2762.0360 1.0002
data[1,11] 2.8704 1.4542 0.0265 0.0304 2754.0771 1.0029
data[1,12] 3.2314 1.4000 0.0256 0.0302 2878.4954 1.0017
data[1,13] 3.6497 1.4432 0.0263 0.0317 3150.8124 1.0035
data[1,14] 3.9541 1.4485 0.0264 0.0334 2561.9226 1.0054
data[1,15] 4.3780 1.4690 0.0268 0.0705 136.1604 1.0370
data[1,16] 4.6045 1.6718 0.0305 0.1544 23.6200 1.1744
data[1,17] 4.8337 1.9946 0.0364 0.2446 13.1787 1.3798
data[1,18] 5.0234 2.3339 0.0426 0.3447 9.1446 1.7371
data[1,19] 5.2197 2.7004 0.0493 0.4252 8.1008 2.0079
data[1,20] 5.2321 2.9892 0.0546 0.4873 7.5920 2.2256
data[1,21] 5.1108 3.0526 0.0557 0.5028 7.4459 2.3274
data[1,22] 4.6500 2.8477 0.0520 0.4608 7.6882 2.1735
data[1,23] 3.9549 2.4854 0.0454 0.3786 8.5979 1.8555
⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮
179 rows omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
data[1,1] -2.0858 0.2854 0.9968 1.6506 3.8896
data[1,2] -1.9736 0.3404 1.0566 1.7578 4.1194
data[1,3] -2.0937 0.4180 1.1202 1.8241 4.0490
data[1,4] -1.8285 0.5285 1.2156 1.9113 4.4069
data[1,5] -1.7131 0.6559 1.3779 2.0828 4.4362
data[1,6] -1.5760 0.8314 1.5295 2.2477 4.5621
data[1,7] -1.2695 1.0010 1.6866 2.3985 4.8273
data[1,8] -1.1320 1.2362 1.9626 2.6368 5.1356
data[1,9] -0.8101 1.4318 2.1641 2.9058 5.3575
data[1,10] -0.6652 1.7555 2.4671 3.2163 5.8269
data[1,11] -0.2795 2.1189 2.8538 3.5823 6.2452
data[1,12] 0.2534 2.5006 3.2040 3.9085 6.4196
data[1,13] 0.6215 2.8904 3.6006 4.3557 7.0326
data[1,14] 0.6667 3.2576 4.0026 4.7282 7.0018
data[1,15] 0.7178 3.7362 4.5004 5.2242 7.0474
data[1,16] 0.1213 3.9832 4.8968 5.6332 7.1415
data[1,17] -0.3515 4.1338 5.3484 6.1485 7.3915
data[1,18] -0.7110 3.8251 5.8347 6.6163 7.8908
data[1,19] -1.0973 3.3998 6.2834 7.1162 8.4171
data[1,20] -1.5547 3.0825 6.5293 7.3731 8.5543
data[1,21] -1.6606 2.7883 6.4689 7.2792 8.5541
data[1,22] -1.6817 2.5682 5.8590 6.6887 7.8936
data[1,23] -2.0757 2.5429 4.8758 5.6853 6.9547
⋮ ⋮ ⋮ ⋮ ⋮ ⋮
179 rows omitted
If I reshape the same way you do, i.e.
chain_predict_array = reshape(Array(chain_predict), :, size(odedata)...);
I get the following predictions:
The correct way of converting the chain is:
# Convert into array with correct shape
syms = reshape(chain_predict.name_map.parameters, size(odedata_missing')...)
tmp = cat([Array(chain_predict[syms[:, i]]) for i = 1:size(syms, 2)]...; dims=3)
chain_predict_array = permutedims(tmp, (1, 3, 2)) # (num_samples, num_times, dim)
which gives me the following predictions:
Hopefully this makes more sense:)
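To see why the naive reshape scrambles things, here is a small sketch in plain Julia with no Turing or MCMCChains involved. The toy sizes and the `10i + j` encoding are my own assumptions; the only thing taken from the output above is the column order of the names, `data[1,1], data[1,2], ..., data[2,1], ...`, where the second index varies fastest.

```julia
S, D, T = 4, 2, 3   # samples, state dimensions, time points (toy sizes)

# Column labels in the same order as the names in the predicted chain:
# (1,1), (1,2), (1,3), (2,1), (2,2), (2,3)
labels = [(i, j) for i in 1:D for j in 1:T]

# Flat S × (D*T) sample matrix; each entry encodes its own label as 10i + j.
flat = [10.0 * l[1] + l[2] for s in 1:S, l in labels]

# Naive reshape to (S, D, T): Julia is column-major, so this misassigns the columns.
naive = reshape(flat, S, D, T)

# Reshape to (S, T, D) instead, because the time index varies fastest in the names,
# then permute to (S, D, T) if that layout is preferred.
by_time = reshape(flat, S, T, D)            # by_time[s, j, i] corresponds to data[i, j]
by_dim  = permutedims(by_time, (1, 3, 2))   # by_dim[s, i, j] corresponds to data[i, j]
```

With this encoding, `naive[1, 2, 1]` is `12.0` (the value of `data[1,2]`), whereas `by_dim[1, 2, 1]` is `21.0` as intended. Whether a plain reshape happens to work therefore depends entirely on the order in which the model visits its indices.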
@torfjelde Thanks for the example!
Perhaps we could add a built-in function for post-processing the chain into an array of predictions? This seems like the most common and natural use case of predict, so I don't think it's very user friendly to require this tricky reshaping!
I also initially suspected that maybe this was the problem. But if you look at the chain statistics in the post above, they also don't make sense. The number of cases, according to the SEIR model, cannot be negative. So the fact that many of the predictive samples have negative means and quantiles around zero is highly suspect.
After digging a bit more into my SEIR COVID-19 example (this isn't my target use-case by the way, I was just playing with Turing/DifferentialEquations.jl), it seems there is also a problem with the likelihood.
The data (y) parameter here is a population proportion. Thus, even a small amount of observation variance (e.g. sigma = 1.0E-3) causes huge fluctuations in the rescaled output when the population N is large, as is the case with entire countries!
Using very small variances on the Normal distribution of < 1.0E-5 seems to cause numerical errors.
This isn't really an issue I've run into before, but I suppose the solution is to use a more suitable likelihood. Maybe a Beta distribution?
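To put rough numbers on the scaling problem, here is a back-of-the-envelope sketch. The population size and infected proportion are hypothetical, and I am reading the `sigma = 1.0E-3` above as a standard deviation on the proportion scale.

```julia
N = 50_000_000        # hypothetical population size
σ_prop = 1.0e-3       # assumed observation std. dev. on the proportion scale
I_prop = 1.0e-3       # hypothetical infected proportion at some time point

# Multiplying a random variable by N multiplies its standard deviation by N.
σ_count = N * σ_prop                  # noise std. dev. in people, about 50 thousand
I_count = N * I_prop                  # signal in people, also about 50 thousand

noise_to_signal = σ_count / I_count   # about 1: the noise is as large as the signal
lower_95 = I_count - 1.96 * σ_count   # lower end of a 95% Normal band, below zero
```

Under these assumptions the lower end of the predictive band is negative, which is consistent with the negative quantiles seen in the predicted chain.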
I would appreciate your insight as this would help me verify that Turing is indeed working correctly on my example!
> Perhaps we could add a built-in function for post-processing the chain into an array of predictions? This seems like the most common and natural use case of predict, so I don't think it's very user friendly to require this tricky reshaping!
I think "most common" is maybe not quite true outside of "time-series" like this, e.g. in most cases my data[i, j] passed to predict will only take i = 1 and j = 1:2 rather than i = 1:num_obs so as to produce a single prediction for each sample in chain. But yeah, for time-series (and other similar use-cases) it's indeed very annoying :confused:
Unfortunately it's very difficult to do correctly + we sort of want MCMCChains to be agnostic to such things, as it makes it much more versatile.
With that being said, I 100% agree that we should at least provide functionality for converting these "simple" scenarios like the one you're referring to. I think the best approach as of right now is the following.
First we reshape names(chain) into the shape we want the resulting samples to have, e.g. in my case I have
names(chain_predict)
202-element Array{Symbol,1}:
Symbol("data[1,1]")
Symbol("data[1,2]")
Symbol("data[1,3]")
Symbol("data[1,4]")
Symbol("data[1,5]")
Symbol("data[1,6]")
Symbol("data[1,7]")
Symbol("data[1,8]")
Symbol("data[1,9]")
Symbol("data[1,10]")
Symbol("data[1,11]")
Symbol("data[1,12]")
Symbol("data[1,13]")
⋮
Symbol("data[2,90]")
Symbol("data[2,91]")
Symbol("data[2,92]")
Symbol("data[2,93]")
Symbol("data[2,94]")
Symbol("data[2,95]")
Symbol("data[2,96]")
Symbol("data[2,97]")
Symbol("data[2,98]")
Symbol("data[2,99]")
Symbol("data[2,100]")
Symbol("data[2,101]")
So I do:
syms = reshape(names(chain_predict), :, 2)
to get
101×2 Array{Symbol,2}:
Symbol("data[1,1]") Symbol("data[2,1]")
Symbol("data[1,2]") Symbol("data[2,2]")
Symbol("data[1,3]") Symbol("data[2,3]")
Symbol("data[1,4]") Symbol("data[2,4]")
Symbol("data[1,5]") Symbol("data[2,5]")
Symbol("data[1,6]") Symbol("data[2,6]")
Symbol("data[1,7]") Symbol("data[2,7]")
Symbol("data[1,8]") Symbol("data[2,8]")
Symbol("data[1,9]") Symbol("data[2,9]")
Symbol("data[1,10]") Symbol("data[2,10]")
Symbol("data[1,11]") Symbol("data[2,11]")
Symbol("data[1,12]") Symbol("data[2,12]")
Symbol("data[1,13]") Symbol("data[2,13]")
⋮
Symbol("data[1,90]") Symbol("data[2,90]")
Symbol("data[1,91]") Symbol("data[2,91]")
Symbol("data[1,92]") Symbol("data[2,92]")
Symbol("data[1,93]") Symbol("data[2,93]")
Symbol("data[1,94]") Symbol("data[2,94]")
Symbol("data[1,95]") Symbol("data[2,95]")
Symbol("data[1,96]") Symbol("data[2,96]")
Symbol("data[1,97]") Symbol("data[2,97]")
Symbol("data[1,98]") Symbol("data[2,98]")
Symbol("data[1,99]") Symbol("data[2,99]")
Symbol("data[1,100]") Symbol("data[2,100]")
Symbol("data[1,101]") Symbol("data[2,101]")
Combine this with permutedims(syms, (2, 1)) to get the wanted shape of (2, 101). Once we have this, we add the following utility methods:
# This makes it so that `AxisArrays.jl` now respects the ordering of
# the symbols when indexing, e.g. `A[[:a, :b]]` and `A[[:b, :a]]` will
# return the reversed ordering + now stuff like `A[[:a, :a]]` will also
# have the expected behavior.
function AxisArrays.axisindexes(::Type{AxisArrays.Categorical}, ax::AbstractVector, idx::AbstractVector)
    # res = findall(in(idx), ax) # <= original impl
    res = mapreduce(vcat, idx) do i
        findfirst(isequal(i), ax)
    end
    length(res) == length(idx) || throw(ArgumentError("index $(setdiff(idx,ax)) not found"))
    res
end
# Essentially just collapses the `syms` and then reshapes.
# The ordering of `syms` is now preserved.
function Base.Array(
    chains::MCMCChains.Chains,
    syms::AbstractArray{Symbol},
    args...;
    kwargs...
)
    # HACK: Index into `AxisArray` directly rather than chain because
    # chain will not respect ordering of indices like `AxisArray` does.
    a = Array(chains.value[:, vec(syms), :], args...; kwargs...)
    return reshape(a, size(a, 1), size(syms)..., size(a, 3))
end
Equipped with this, we can do the following in my example from above:
A = Array(chain_predict, permutedims(syms, (2, 1)))
# The first axis indexes samples (iterations) and the last axis indexes chains.
let sample_idx = 1, chain_idx = 1, i = 1, j = 3, sym = Symbol("data[$i,$j]")
    (
        A[sample_idx, i, j, chain_idx],
        chain_predict[sample_idx, sym, chain_idx],
        sym
    )
end
resulting in
(-0.4236570209344528, -0.4236570209344528, Symbol("data[1,3]"))
as wanted! :tada:
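The difference between the original membership-based lookup and the order-preserving one can be sketched with Base functions alone. This is only an illustration of the two indexing strategies, not the AxisArrays.jl implementation; `ax` and `idx` are made-up labels.

```julia
ax  = [:a, :b, :c]      # axis labels, in storage order
idx = [:c, :a, :a]      # requested labels, in the order we want them back

# Membership-based lookup: positions come back in axis order and repeats are dropped.
by_membership = findall(in(idx), ax)

# Per-label lookup: one position per requested label, in request order.
by_request = [findfirst(isequal(i), ax) for i in idx]
```

Here `by_membership` is `[1, 3]` while `by_request` is `[3, 1, 1]`, which is why only the second strategy lets a reshaped array of symbols round-trip to a correctly ordered array of samples.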
Btw, I'll make a PR for AxisArrays.jl to add the change above to the package, as this seems like it would be a useful feature to have. There might be a reason why they haven't done it, though, since they seem aware that this would be a nice feature, e.g. https://github.com/JuliaArrays/AxisArrays.jl/blob/9b91d546b28d96cd980e0a86d9c860c3689881d7/src/indexing.jl#L140. There's definitely a performance implication, but it's unclear to me if that's really ever going to be a big bottleneck.
> I also initially suspected that maybe this was the problem. But if you look at the chain statistics in the post above, they also don't make sense. The number of cases, according to the SEIR model, cannot be negative. So the fact that many of the predictive samples have negative means and quantiles around zero is highly suspect.
If I understand you correctly, you're essentially questioning whether or not a Normal likelihood is correct for these problems where the observations are actually bounded, right?
Let me preface by saying that @yebai will likely have a much better answer to this question, so hopefully he can chime in + correct me.
I think this very much depends on the scenario. Say you have a bunch of measurement observations, all of which are far from the boundary of the domain, e.g. the population proportions never get close to 0 or 1 but stay somewhere near 0.5; then a Normal likelihood with a sufficiently small variance is likely to provide a good approximation of the underlying noise model (in the large data regime). But if you don't have a lot of measurements + it makes sense for the random variables to take values near the boundaries, then this causes issues (which is made clear in the predictions above, where you observed negative values for something that shouldn't be negative).
So the question of what to do when the Normal doesn't make sense is, I would say, highly problem-dependent, and the likelihood is ideally chosen by someone who has intricate knowledge of how the observations were gathered. In the case of the SEIR model with its data, I'm not entirely certain what to suggest, unfortunately :confused:
For practical purposes when making predictions if the inference looks good (despite the model misspecification), you could do something like:
- Use a TruncatedNormal which depends on how far the mean value is from the boundary to ensure the constraints are respected.
- Don't use predict but instead just inspect the sampled solutions (from the approximate posterior/chain) to the system rather than the actual sampled observations.
Thanks for the detailed response, @torfjelde!
I think your AxisArrays solution is reasonable.
> If I understand you correctly, you're essentially questioning whether or not a Normal likelihood is correct for these problems where the observations are actually bounded, right?
Yes, that is part of the problem. I suppose TruncatedNormal would solve that. The main issue is numerical problems with small population proportions.
So yes, the state variables must sum to 1 in an SEIR model, and they are normalized by some constant N, the number of individuals in the population.
The problem is that, for large populations, N is on the order of tens or hundreds of millions. Thus, the state variables E and I are often very, very small, i.e. 1.0E-3 or less, and a realistic variance in the cumulative number of cases would be something on the order of 1.0E-5 or 1.0E-6. This causes numerical instability in the MCMC sampler and sample fails with:
┌ Warning: Automatic dt set the starting dt as NaN, causing instability.
└ @ OrdinaryDiffEq /home/brian/.julia/packages/OrdinaryDiffEq/OK16j/src/solve.jl:482
┌ Warning: NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.
└ @ DiffEqBase /home/brian/.julia/packages/DiffEqBase/cuMMc/src/integrator_interface.jl:322
followed by a BoundsError.
Ignoring the variance in the predictions would be OK for estimating the mean, but kind of nullifies part of the point of Bayesian parameter estimation, which is to push epistemic and aleatoric uncertainty forward to the predictions to quantify uncertainty in the predictions.
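The trade-off between inspecting sampled solutions and sampling observations can be sketched without Turing or an ODE solver. Everything below is a toy assumption: `trajectory` is a hypothetical stand-in for an ODE solve, and the "posterior" draws of `θ` are simply simulated.

```julia
using Random, Statistics

Random.seed!(1)

# Hypothetical stand-in for an ODE solve: exponential decay with rate θ.
trajectory(θ, ts) = exp.(-θ .* ts)

ts = 0.0:1.0:5.0
θ_samples = 0.5 .+ 0.05 .* randn(1000)   # pretend posterior draws of θ
σ = 0.2                                  # pretend observation std. dev.

# Parameter uncertainty only: one deterministic trajectory per posterior draw.
ensemble = reduce(hcat, [trajectory(θ, ts) for θ in θ_samples])   # 6 × 1000

# Parameter + observation uncertainty: add likelihood noise to every trajectory.
predictive = ensemble .+ σ .* randn(size(ensemble)...)

spread_ensemble   = std(ensemble; dims=2)
spread_predictive = std(predictive; dims=2)
has_negative      = any(predictive .< 0)   # trajectories are positive, noisy draws need not be
```

The ensemble only reflects uncertainty in the parameters, while the predictive draws are wider at every time point and, with a Normal likelihood, can fall below zero even though every underlying trajectory is positive.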
In the case where there is a relationship between the different variables, e.g. in SIR model you want the variables to sum to 1, you could clamp and normalize.
In my opinion, ideally this would be encoded in the model though. For instance, one could model observations with a Dirichlet distribution whose mean is the ODE solution and that could be more or less concentrated around this mean, depending on an inferred (or user-provided) scaling parameter. I am not familiar with the literature about inference of SIR models and their variants but I would assume that someone has tried and used something similar.
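As a rough sketch of what "more or less concentrated around this mean" could look like: scale the (normalized) ODE state by a concentration parameter and use the result as Dirichlet parameters. The state vector and the value of `κ` below are hypothetical; with Distributions.jl the likelihood would then be `Dirichlet(α)`, but the moments can be checked in plain Julia.

```julia
# Hypothetical ODE state (S, E, I, R) as proportions summing to 1.
x = [0.90, 0.04, 0.03, 0.03]
κ = 500.0                               # hypothetical concentration parameter

α = κ .* x                              # Dirichlet parameters
dir_mean = α ./ sum(α)                  # equals the ODE state x
dir_var  = x .* (1 .- x) ./ (κ + 1)     # marginal variances shrink as κ grows
```

A larger `κ` keeps the observations closer to the ODE solution, and every draw stays on the simplex by construction, so no clamping or renormalizing would be needed.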
Good point @devmotion, the Dirichlet could work. I'll look into that.
> Thus, the state variables E and I are often very, very small, i.e. 1.0E-3 or less.
I guess to fix this issue you could potentially scale with the population number (or I suppose any other constant) and observe this instead. This should work since the reason for requiring such a small variance is that we're on the boundary region rather than the actual noise having such a low variance. But yeah, this is getting real hacky. And adding to the issue of using a Normal likelihood in the SEIR model: if you're observing all of (S, E, I, R), you're essentially observing one of the variables twice, due to the fact that once you know (S, E, I) then R is fully determined.
> In my opinion, ideally this would be encoded in the model though.
Definitely agree with this, but it's less clear to me if a Dirichlet likelihood "makes sense" with how it allocates probability mass.
Another option is to use a NegativeBinomial with the ODE solution as a mean (https://mc-stan.org/users/documentation/case-studies/boarding_school_case_study.html#sampling-distribution) and observe the actual counts rather than the proportions. Note that here they're using a different parameterization of NegativeBinomial than the one available in Distributions.jl (example impl: https://github.com/cambridge-mlg/Covid19/blob/3b1644701ef32063a65fbbc72332ba0eaa22f82b/src/utils.jl#L3-L39). Here I believe it would make the most sense to observe only 3 of the 4 variables in the SEIR model due to the over-specification I mentioned above.
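For reference, a minimal sketch of the conversion between parameterizations. I am assuming the case study uses a mean/overdispersion pair (μ, ϕ) with variance μ + μ²/ϕ; `negbin_r_p` is a hypothetical helper name, not something from Distributions.jl or the linked repository. The (r, p) pair it returns is what `NegativeBinomial(r, p)` in Distributions.jl expects.

```julia
# Convert a mean/overdispersion pair (μ, ϕ), with variance μ + μ^2/ϕ,
# into the (r, p) pair of the successes/success-probability parameterization.
function negbin_r_p(μ::Real, ϕ::Real)
    r = ϕ
    p = ϕ / (ϕ + μ)
    return r, p
end

μ, ϕ = 120.0, 10.0          # hypothetical ODE-predicted count and overdispersion
r, p = negbin_r_p(μ, ϕ)

# Moments of the (r, p) parameterization, checked against the (μ, ϕ) ones.
nb_mean = r * (1 - p) / p        # recovers μ
nb_var  = r * (1 - p) / p^2      # equals μ + μ^2 / ϕ
```

The linked utility in the Covid19 repository is the place to look for the implementation the authors actually use; this only shows the algebra behind such a wrapper.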
Though the discussion is somewhat old already, I would add that the SEIR model in your ODE solver is already defined such that S+E+I+R=N, where N is the population size. For the observations we don't usually need to worry about this conservation requirement, mostly because we are not able to observe all compartments at once anyway. In most of the studies that I am aware of, the model is only informed by the case and death counts. However, the latter is not part of your model yet, and to the best of my knowledge most studies use a negative binomial, as mentioned correctly by Tor (see for example Gibson et al., Flaxman et al. and the R package epidemia).
PS. maybe we could write up a simple tutorial similar to the one in Stan as mentioned by Tor - I really like it. PPS. There is also a pretty cool implementation of the Flaxman et al. model in Turing: Covid19. This is a great starting point even though the model is discrete in time.
> PS. maybe we could write up a simple tutorial similar to the one in Stan as mentioned by Tor - I really like it.
A tutorial would be dope! We're currently revamping the tutorial system for Turing.jl (TuringTutorials.jl), and it's getting real close to being done. Once that's gone through, we should def look into making a slightly more detailed tutorial for these sorts of problems.
Really nice! Building on your Package Covid19, we have recently implemented a very similar model by Unwin et al. and explore applications for Denmark. So let me know if I can contribute anything
Thanks @andreaskoher and @torfjelde for the helpful comments. I'm totally not an epidemiologist, and I was just using this as a (somewhat) gentle way of introducing myself to Turing + ODEs. It interested me more than Lotka-Volterra ;)
It would be really nice if we could make it easier to sample from the posterior predictive distribution. It's technically possible right now (I think, please check my example below), but it's a bit of a pain.
Consider the linear regression example in the documentation:
The documentation gives the following function for estimating the posterior mean:
But this is less than ideal because we are basically reproducing part of the model in a separate function. It also ignores observation noise.
We can get proper samples from the posterior predictive by modifying the model spec:
But this is somewhat cumbersome. Could we make the @model macro implicitly define the random variables in the generated function signature? And maybe add some utility function for sampling from the chain?
Please let me know if I'm missing something here! Maybe there is already an easier way to do this that I am not aware of...?
Note that this is related to a suggestion in #638 "posterior predictive checks"