TuringLang / Turing.jl

Bayesian inference with probabilistic programming.
https://turinglang.org
MIT License

Make it easier to sample from posterior predictive #1475

Closed bgroenks96 closed 3 years ago

bgroenks96 commented 3 years ago

It would be really nice if we could make it easier to sample from the posterior predictive distribution. It's technically possible right now (I think, please check my example below), but it's a bit of a pain.

Consider the linear regression example in the documentation:

# Bayesian linear regression.
@model function linear_regression(x, y)
    # Set variance prior.
    σ₂ ~ truncated(Normal(0, 100), 0, Inf)

    # Set intercept prior.
    intercept ~ Normal(0, sqrt(3))

    # Set the priors on our coefficients.
    nfeatures = size(x, 2)
    coefficients ~ MvNormal(nfeatures, sqrt(10))

    # Calculate all the mu terms.
    mu = intercept .+ x * coefficients
    y ~ MvNormal(mu, sqrt(σ₂))
end

The documentation gives the following function for estimating the posterior mean:

# Make a prediction given an input vector.
function prediction(chain, x)
    p = get_params(chain[200:end, :, :])
    targets = p.intercept' .+ x * reduce(hcat, p.coefficients)'
    return vec(mean(targets; dims = 2))
end

But this is less than ideal because we are basically reproducing part of the model in a separate function. It also ignores observation noise.

We can get proper samples from the posterior predictive by modifying the model spec:

# Bayesian linear regression.
@model function linear_regression(x, y; σ₂=missing, intercept=missing, coefficients=missing)
    # Set variance prior.
    σ₂ ~ truncated(Normal(0, 100), 0, Inf)

    # Set intercept prior.
    intercept ~ Normal(0, sqrt(3))

    # Set the priors on our coefficients.
    nfeatures = size(x, 2)
    coefficients ~ MvNormal(nfeatures, sqrt(10))

    # Calculate all the mu terms.
    mu = intercept .+ x * coefficients
    y ~ MvNormal(mu, sqrt(σ₂))
end
chain = ... # assume we have a sampled chain
pp_samples = []
for sample in eachrow(DataFrame(chain))
    coeffs = select(sample, ["coefficients[$i]" for i in 1:nvars]...)
    model = linear_regression(x, missing; σ₂=sample["σ₂"], intercept=sample["intercept"], coefficients=coeffs)
    push!(pp_samples, model())
end

But this is somewhat cumbersome. Could we make the @model macro implicitly define the random variables in the generated function signature? And maybe add some utility function for sampling from the chain?

Please let me know if I'm missing something here! Maybe there is already an easier way to do this that I am not aware of...?

Note that this is related to a suggestion in #638 "posterior predictive checks"

torfjelde commented 3 years ago

This is unfortunately a result of our documentation not being completely up-to-date (we're currently working on improving this, e.g. https://github.com/TuringLang/Turing.jl/issues/1474 and https://github.com/TuringLang/TuringTutorials/issues/86).

But there is a predict method which probably does exactly what you want: https://github.com/TuringLang/Turing.jl/blob/cb58871e321ef058d86b51adfe602a562bd690f4/src/inference/Inference.jl#L477-L546

And for future reference, there's also generated_quantities for cases where you want to look at the posterior predictive for quantities that are not directly sampled.
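
For reference, a minimal sketch of how the two fit together with the linear regression model from the original post (the test inputs x_test and the sampler settings are illustrative assumptions):

# Fit the model as usual.
model = linear_regression(x, y)
chain = sample(model, NUTS(0.65), 1_000)

# Posterior predictive: instantiate the model with `missing` observations
# and let `predict` draw them for each posterior sample.
y_missing = Vector{Union{Missing, Float64}}(missing, size(x_test, 1))
predictions = predict(linear_regression(x_test, y_missing), chain)

# `generated_quantities` re-runs the model for each posterior sample and
# collects the model's return values (useful for derived quantities).
quantities = generated_quantities(model, chain)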

Again, this should really be made explicit in the documentation and is going to be the case once we've updated stuff, ref the above issues.

bgroenks96 commented 3 years ago

Ok! I think that pretty much solves this issue then. Thanks!

bgroenks96 commented 3 years ago

Quick note for posterity: if you have multi-dimensional y with missing values, you will get an output from predict that is difficult to transform back into the expected shape of your output (e.g. n x m for n data points with m dimensions).

You can solve this by using include_all=true, grouping, and reshaping:

pred_chain = predict(model(X, Matrix{Union{Missing,Float64}}(missing, size(X)...)), chain, include_all=true) |>
             p -> group(p, :y)
preds = reshape(Array(pred_chain), (500, n, m))  # 500 = number of samples in this chain
# output: S x n x m, where S is the number of chain samples

bgroenks96 commented 3 years ago

@torfjelde predict seems to not work when using models that call external functions to generate predictions, e.g. with Bayesian differential equations. It just returns a chain object with the same samples repeated over and over, presumably because it reuses the last call to solve. I can try to prepare an example for you, but it should be reproducible from the Bayesian diffeq example in the Turing docs. Just try calling predict using a missing data model.
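
Roughly, what I'm trying is something like this (sketch; fitlv is a hypothetical model constructor in the style of the tutorial, and odedata its observation matrix):

odedata_missing = Matrix{Union{Missing, Float64}}(missing, size(odedata)...)
predictions = predict(fitlv(odedata_missing), chain)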

Should I create a new issue for this?

torfjelde commented 3 years ago

Is this on the most recent version of Turing? It could potentially be related to https://github.com/TuringLang/Turing.jl/issues/1464, which was recently fixed.

But if it still persists even in the most recent version, if you could open a separate issue, that would be awesome!

EDIT: No need to read the issue I referenced btw. It's more the fix we did in https://github.com/TuringLang/DynamicPPL.jl/pull/191 that might have also fixed this issue. Essentially, in certain cases we would do copy-by-reference rather than copy-by-value, and so after running the model once, it would fill the missing array with actual numbers, and then the second time we called the model in predict, it would no longer be missing.

bgroenks96 commented 3 years ago

@torfjelde I'm pretty sure I found this issue while my Turing version was downgraded (0.5), although I wasn't aware of that at the time.

I am rebuilding my sysimage at the moment, and I will check after it's done!

bgroenks96 commented 3 years ago

@torfjelde Ok, so the problem I described above with values being reused does not appear to exist in 0.15.4. That's the good news!

However, the results don't really make sense.

If I compute the posterior predictive myself (excluding observation noise) by simply running the diffeq model with each parameter setting from the posterior, I get exactly what I would expect: an ensemble of model runs.

If I use predict, the result looks more like a typical sample transition chain...

Could you clarify what exactly predict is doing? It calls transitions_from_chain, right? Should this return the numerical model outputs at each transition in the case of a Bayesian differential equation model?

I'll try to set up the Lotka-Volterra example from the docs and see if I can reproduce it there.

torfjelde commented 3 years ago

If I use predict, the result looks more like a typical sample transition chain...

Is this not what you want? I'm a bit confused I guess. You get back a Chains when calling predict, right? But you really just want an array of model runs?

I'll try to set up the Lotka-Volterra example from the docs and see if I can reproduce it there.

That would be awesome! But pseudo-code would also be useful as a starter, as I feel like I'm possibly misunderstanding your intention here.

bgroenks96 commented 3 years ago

@torfjelde

Yeah, sorry, I didn't explain that well... the Chains object is fine; that's what I expect. I just mean that when you plot the results, it doesn't look like the samples came from the physical model. It looks like they are just random samples.

Here's an example for an SEIR COVID-19 model that I whipped up recently based on a tutorial for DifferentialEquations.jl:

@model function seir(x0, y; datavars=[1], tspan=(0.0, 365.0))
    nvars = length(datavars)
    σ_inv ~ truncated(Normal(5.2, 0.5), 1.0, Inf)
    γ_inv ~ truncated(Normal(18.0, 3.0), 1.0, Inf)
    R₀_bar ~ truncated(Normal(2.0, 0.5), 0.1, Inf)
    δ ~ Beta(1, 100)
    η ~ Beta(1, 10)
    ν ~ filldist(InverseGamma(2, 3), nvars) # noise variance
    params = [1.0/σ_inv, 1.0/γ_inv, R₀_bar, η, δ]
    # F is the SEIR ODE right-hand side, defined elsewhere.
    prob = ODEProblem(F, x0, tspan, params)
    pred = solve(prob, Tsit5(), saveat=1.0)
    for j = 1:nvars
        for i = 1:size(y, 1)
            y[i, j] ~ Normal(pred[datavars[j], i], ν[j])
        end
    end
    return y
end

I can build an approximate posterior predictive by simply iterating over the posterior samples and running the model on each one:

chain_df = DataFrame(chain)
invert(x) = 1.0 ./ x
params = select(chain_df, :σ_inv => invert => :σ, :γ_inv => invert => :γ, :R₀_bar, :η, :δ)
results = []
for p in eachrow(params)
    prob = ODEProblem(F, x_0, (0.0, size(data, 1)), p)
    push!(results, solve(prob, Tsit5(), saveat=1.0))
end

That produces this plot (cases_lower/cases_mid/cases_upper are quantiles computed from results, analogous to the quantile code further down):

plot(cases_mid, label=nothing, xlabel="Days since July 1st", ylabel="Number of cases")
plot!(cases_upper, label=nothing, c="transparent", fill=cases_lower, fillcolor="orange", fillalpha=0.4)

(image: plot of median predicted cases with a 95% uncertainty band)

...which looks reasonable. If I use predict:

model_test = seir(x_0, Matrix{Union{Float64,Missing}}(missing, size(data_normalized)...); datavars=[6, 7], tspan=(0.0, size(data, 1)))
pred_chain = predict(model_test, chain)
preds = reshape(Array(pred_chain), (size(chain, 1), size(data_normalized)...))
pred_cases = preds[:, :, 1]
pred_deaths = preds[:, :, 2]
cases_lower = mapslices(x -> quantile(x, 0.025), pred_cases; dims=1)[1, :] * N
cases_upper = mapslices(x -> quantile(x, 0.975), pred_cases; dims=1)[1, :] * N
cases_mid = mapslices(x -> quantile(x, 0.5), pred_cases; dims=1)[1, :] * N
deaths_lower = mapslices(x -> quantile(x, 0.025), pred_deaths; dims=1)[1, :] * N
deaths_upper = mapslices(x -> quantile(x, 0.975), pred_deaths; dims=1)[1, :] * N
deaths_mid = mapslices(x -> quantile(x, 0.5), pred_deaths; dims=1)[1, :] * N;

then I get:

(image: plot of the predict-based case predictions)

which looks more like Turing just sampled from the prior/likelihood and didn't run the SEIR model at all...

Maybe I'm just doing something wrong? This is just an example I had on hand. If it's too opaque, let me know and I can go grab the Lotka-Volterra one from the docs.

bgroenks96 commented 3 years ago

Just to show that it's not my post-processing code, but actually the values returned by predict, here's the output of pred_chain:

(image: screenshot of the pred_chain summary output)

torfjelde commented 3 years ago

Hmm, yeah this is weird. We've run into this issue before but then we fixed it in DynamicPPL, so I'm somewhat confused why it's now back. It might just be that we're not lower-bounding the correct version of DPPL or something. I'll have a look at this again after dinner.

And thank you for the very thorough replies/troubleshooting! Really helpful :)

torfjelde commented 3 years ago

So it turns out it's not a "bug" in Turing, per se. The issue is the way you reconstruct the trajectories from the predicted Chains.

I used the Lotka-Volterra model example from the DiffEq tutorial (https://github.com/TuringLang/TuringTutorials/blob/master/10_diffeq.ipynb) and "reproduced" the issue. So I have this chain:

Chains MCMC chain (1000×17×3 Array{Float64,3}):

Iterations        = 1:1000
Thinning interval = 1
Chains            = 1, 2, 3
Samples per chain = 1000
parameters        = α, β, γ, δ, σ
internals         = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth

Summary Statistics
  parameters      mean       std   naive_se      mcse       ess      rhat 
      Symbol   Float64   Float64    Float64   Float64   Float64   Float64 

           α    1.7502    0.2828     0.0052    0.0504    6.4874    3.8125
           β    1.3580    0.3780     0.0069    0.0689    6.2154    5.6250
           γ    3.1057    0.3732     0.0068    0.0608    7.7069    2.2009
           δ    1.1922    0.3713     0.0068    0.0671    6.3035    4.7589
           σ    1.2600    0.6357     0.0116    0.1173    6.0817    9.5661

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

           α    1.4601    1.5396    1.5954    2.0444    2.3049
           β    1.0015    1.0752    1.1334    1.8323    1.9901
           γ    2.6141    2.8291    2.9727    3.4039    3.8948
           δ    0.8454    0.9189    0.9709    1.6097    1.9036
           σ    0.7423    0.7989    0.8389    2.0741    2.3172

and this chain returned from predict:

Chains MCMC chain (1000×202×3 Array{Float64,3}):

Iterations        = 1:1000
Thinning interval = 1
Chains            = 1, 2, 3
Samples per chain = 1000
parameters        = data[1,1], data[1,2], data[1,3], …, data[1,101], data[2,1], data[2,2], …, data[2,101]
internals         = 

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat 
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64 

   data[1,1]    0.9618    1.3837     0.0253    0.0232   2979.2866    1.0001
   data[1,2]    1.0479    1.4190     0.0259    0.0206   2572.4246    0.9993
   data[1,3]    1.1143    1.4031     0.0256    0.0268   2859.7547    0.9994
   data[1,4]    1.2308    1.4245     0.0260    0.0216   3250.5532    0.9996
   data[1,5]    1.3571    1.4057     0.0257    0.0320   2765.7726    1.0014
   data[1,6]    1.5369    1.3856     0.0253    0.0243   2896.6665    0.9996
   data[1,7]    1.6964    1.4063     0.0257    0.0266   2958.3855    0.9991
   data[1,8]    1.9617    1.4418     0.0263    0.0300   2895.5749    1.0032
   data[1,9]    2.1746    1.4417     0.0263    0.0261   3093.7371    0.9995
  data[1,10]    2.5078    1.4531     0.0265    0.0256   2762.0360    1.0002
  data[1,11]    2.8704    1.4542     0.0265    0.0304   2754.0771    1.0029
  data[1,12]    3.2314    1.4000     0.0256    0.0302   2878.4954    1.0017
  data[1,13]    3.6497    1.4432     0.0263    0.0317   3150.8124    1.0035
  data[1,14]    3.9541    1.4485     0.0264    0.0334   2561.9226    1.0054
  data[1,15]    4.3780    1.4690     0.0268    0.0705    136.1604    1.0370
  data[1,16]    4.6045    1.6718     0.0305    0.1544     23.6200    1.1744
  data[1,17]    4.8337    1.9946     0.0364    0.2446     13.1787    1.3798
  data[1,18]    5.0234    2.3339     0.0426    0.3447      9.1446    1.7371
  data[1,19]    5.2197    2.7004     0.0493    0.4252      8.1008    2.0079
  data[1,20]    5.2321    2.9892     0.0546    0.4873      7.5920    2.2256
  data[1,21]    5.1108    3.0526     0.0557    0.5028      7.4459    2.3274
  data[1,22]    4.6500    2.8477     0.0520    0.4608      7.6882    2.1735
  data[1,23]    3.9549    2.4854     0.0454    0.3786      8.5979    1.8555
      ⋮           ⋮         ⋮         ⋮          ⋮          ⋮          ⋮
                                                             179 rows omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

   data[1,1]   -2.0858    0.2854    0.9968    1.6506    3.8896
   data[1,2]   -1.9736    0.3404    1.0566    1.7578    4.1194
   data[1,3]   -2.0937    0.4180    1.1202    1.8241    4.0490
   data[1,4]   -1.8285    0.5285    1.2156    1.9113    4.4069
   data[1,5]   -1.7131    0.6559    1.3779    2.0828    4.4362
   data[1,6]   -1.5760    0.8314    1.5295    2.2477    4.5621
   data[1,7]   -1.2695    1.0010    1.6866    2.3985    4.8273
   data[1,8]   -1.1320    1.2362    1.9626    2.6368    5.1356
   data[1,9]   -0.8101    1.4318    2.1641    2.9058    5.3575
  data[1,10]   -0.6652    1.7555    2.4671    3.2163    5.8269
  data[1,11]   -0.2795    2.1189    2.8538    3.5823    6.2452
  data[1,12]    0.2534    2.5006    3.2040    3.9085    6.4196
  data[1,13]    0.6215    2.8904    3.6006    4.3557    7.0326
  data[1,14]    0.6667    3.2576    4.0026    4.7282    7.0018
  data[1,15]    0.7178    3.7362    4.5004    5.2242    7.0474
  data[1,16]    0.1213    3.9832    4.8968    5.6332    7.1415
  data[1,17]   -0.3515    4.1338    5.3484    6.1485    7.3915
  data[1,18]   -0.7110    3.8251    5.8347    6.6163    7.8908
  data[1,19]   -1.0973    3.3998    6.2834    7.1162    8.4171
  data[1,20]   -1.5547    3.0825    6.5293    7.3731    8.5543
  data[1,21]   -1.6606    2.7883    6.4689    7.2792    8.5541
  data[1,22]   -1.6817    2.5682    5.8590    6.6887    7.8936
  data[1,23]   -2.0757    2.5429    4.8758    5.6853    6.9547
      ⋮           ⋮         ⋮         ⋮         ⋮         ⋮
                                                179 rows omitted

If I do the same reshaping approach as you do, i.e.

chain_predict_array = reshape(Array(chain_predict), :, size(odedata)...);

I get the following predictions:

(image: predictions reconstructed with the naive reshape; trajectories look scrambled)

The correct way of converting the chain is:

# Convert into array with correct shape
syms = reshape(chain_predict.name_map.parameters, size(odedata_missing')...)
tmp = cat([Array(chain_predict[syms[:, i]]) for i = 1:size(syms, 2)]...; dims=3)
chain_predict_array = permutedims(tmp, (1, 3, 2)) # (num_samples, num_times, dim)

which gives me the following predictions:

(image: predictions reconstructed correctly; trajectories look sensible)

Hopefully this makes more sense :)

bgroenks96 commented 3 years ago

@torfjelde Thanks for the example!

Perhaps we could add a built-in function for post-processing the chain into an array of predictions? This seems like the most common and natural use case of predict, so I don't think it's very user-friendly to require this tricky reshaping!

I also initially suspected that maybe this was the problem. But if you look at the chain statistics in the post above, they also don't make sense. The number of cases, according to the SEIR model, cannot be negative. So the fact that many of the predictive samples have negative means and quantiles around zero is highly suspect.

After digging a bit more into my SEIR COVID-19 example (this isn't my target use-case by the way, I was just playing with Turing/DifferentialEquations.jl), it seems there is also a problem with the likelihood.

The data (y) parameter here is a population proportion. Thus, even a small amount of observation variance (e.g. sigma = 1.0E-3) causes huge fluctuations in the rescaled output when the population N is large, as is the case with entire countries!

Using very small variances (< 1.0E-5) on the Normal distribution seems to cause numerical errors.

This isn't really an issue I've run into before, but I suppose the solution is to use a more suitable likelihood. Maybe a Beta distribution?
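
For example, maybe a mean/concentration parameterization along these lines (sketch, names made up; κ would be an inferred concentration parameter):

# Beta with mean μ and concentration κ: mean(Beta(α, β)) = α/(α + β),
# so α = μκ and β = (1 - μ)κ gives mean exactly μ.
BetaMean(μ, κ) = Beta(μ * κ, (1 - μ) * κ)

# y[i, j] ~ BetaMean(pred[datavars[j], i], κ)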

I would appreciate your insight as this would help me verify that Turing is indeed working correctly on my example!

torfjelde commented 3 years ago

Perhaps we could add a built-in function for post-processing the chain into an array of predictions? This seems like the most common and natural use case of predict, so I don't think it's very user-friendly to require this tricky reshaping!

I think "most common" is maybe not quite true outside of "time-series" like this, e.g. in most cases my data[i, j] passed to predict will only take i = 1 and j = 1:2 rather than i = 1:num_obs so as to produce a single prediction for each sample in chain. But yeah, for time-series (and other similar use-cases) it's indeed very annoying :confused: Unfortunately it's very difficult to do correctly + we sort of want MCMCChains to be agnostic to such, as it makes it much more versatile.

With that being said, I 100% agree that we should at least provide functionality for converting these "simple" scenarios like the one you're referring to. I think the best approach as of right now is the following.

First we reshape names(chain_predict) into the shape we want the resulting samples to have, e.g. in my case I have

names(chain_predict)
202-element Array{Symbol,1}:
 Symbol("data[1,1]")
 Symbol("data[1,2]")
 Symbol("data[1,3]")
 Symbol("data[1,4]")
 Symbol("data[1,5]")
 Symbol("data[1,6]")
 Symbol("data[1,7]")
 Symbol("data[1,8]")
 Symbol("data[1,9]")
 Symbol("data[1,10]")
 Symbol("data[1,11]")
 Symbol("data[1,12]")
 Symbol("data[1,13]")
 ⋮
 Symbol("data[2,90]")
 Symbol("data[2,91]")
 Symbol("data[2,92]")
 Symbol("data[2,93]")
 Symbol("data[2,94]")
 Symbol("data[2,95]")
 Symbol("data[2,96]")
 Symbol("data[2,97]")
 Symbol("data[2,98]")
 Symbol("data[2,99]")
 Symbol("data[2,100]")
 Symbol("data[2,101]")

So I do:

syms = reshape(names(chain_predict), :, 2)

to get

101×2 Array{Symbol,2}:
 Symbol("data[1,1]")    Symbol("data[2,1]")
 Symbol("data[1,2]")    Symbol("data[2,2]")
 Symbol("data[1,3]")    Symbol("data[2,3]")
 Symbol("data[1,4]")    Symbol("data[2,4]")
 Symbol("data[1,5]")    Symbol("data[2,5]")
 Symbol("data[1,6]")    Symbol("data[2,6]")
 Symbol("data[1,7]")    Symbol("data[2,7]")
 Symbol("data[1,8]")    Symbol("data[2,8]")
 Symbol("data[1,9]")    Symbol("data[2,9]")
 Symbol("data[1,10]")   Symbol("data[2,10]")
 Symbol("data[1,11]")   Symbol("data[2,11]")
 Symbol("data[1,12]")   Symbol("data[2,12]")
 Symbol("data[1,13]")   Symbol("data[2,13]")
 ⋮                      
 Symbol("data[1,90]")   Symbol("data[2,90]")
 Symbol("data[1,91]")   Symbol("data[2,91]")
 Symbol("data[1,92]")   Symbol("data[2,92]")
 Symbol("data[1,93]")   Symbol("data[2,93]")
 Symbol("data[1,94]")   Symbol("data[2,94]")
 Symbol("data[1,95]")   Symbol("data[2,95]")
 Symbol("data[1,96]")   Symbol("data[2,96]")
 Symbol("data[1,97]")   Symbol("data[2,97]")
 Symbol("data[1,98]")   Symbol("data[2,98]")
 Symbol("data[1,99]")   Symbol("data[2,99]")
 Symbol("data[1,100]")  Symbol("data[2,100]")
 Symbol("data[1,101]")  Symbol("data[2,101]")

Combine this with permutedims(syms, (2, 1)) to get the desired shape of (2, 101). Once we have this, we add the following utility methods:

# This makes it so that `AxisArrays.jl` now respects the ordering of
# the symbols when indexing, e.g. `A[[:a, :b]]` and `A[[:b, :a]]` will
# return the reversed ordering + now stuff like `A[[:a, :a]]` will also
# have the expected behavior.
function AxisArrays.axisindexes(::Type{AxisArrays.Categorical}, ax::AbstractVector, idx::AbstractVector)
    # res = findall(in(idx), ax) # <= original impl
    res = mapreduce(vcat, idx) do i
        findfirst(isequal(i), ax)
    end
    length(res) == length(idx) || throw(ArgumentError("index $(setdiff(idx,ax)) not found"))
    res
end

# Essentially just collapses the `syms` and then reshapes.
# The ordering of `syms` is now preserved.
function Base.Array(
    chains::MCMCChains.Chains,
    syms::AbstractArray{Symbol},
    args...;
    kwargs...
)
    # HACK: Index into `AxisArray` directly rather than the chain because
    # the chain will not respect the ordering of indices like `AxisArray` does.
    a = Array(chains.value[:, vec(syms), :], args...; kwargs...)
    return reshape(a, size(a, 1), size(syms)..., size(a, 3))
end

Equipped with this, we can do the following in my example from above:

A = Array(chain_predict, permutedims(syms, (2, 1)))

let chain_idx = 1, sample_idx = 1, i = 1, j = 3, sym = Symbol("data[$i,$j]")
    (
        A[chain_idx, i, j, sample_idx],
        chain_predict[chain_idx, sym, sample_idx],
        sym
    )
end

resulting in

(-0.4236570209344528, -0.4236570209344528, Symbol("data[1,3]"))

as wanted! :tada:

Btw, I'll make a PR for AxisArrays.jl to add the change above to the package, as this seems like it would be a useful feature to have. There might be a reason why they haven't done it though, since they seem aware that this would be a nice feature, e.g. https://github.com/JuliaArrays/AxisArrays.jl/blob/9b91d546b28d96cd980e0a86d9c860c3689881d7/src/indexing.jl#L140. There's definitely a performance implication, but it's unclear to me if that's really ever going to be a big bottleneck.

torfjelde commented 3 years ago

I also initially suspected that maybe this was the problem. But if you look at the chain statistics in the post above, they also don't make sense. The number of cases, according to the SEIR model, cannot be negative. So the fact that many of the predictive samples have negative means and quantiles around zero is highly suspect.

If I understand you correctly, you're essentially questioning whether or not a Normal likelihood is correct for these problems where the observations are actually bounded, right?

Let me preface this by saying that @yebai will likely have a much better answer to this question, so hopefully he can chime in + correct me.

I think this very much depends on the scenario. Say you have a bunch of measurement observations, all of which are far from the boundary of the domain, e.g. the population proportions never get close to 0 or 1 but stay somewhere near 0.5; then a Normal likelihood with a sufficiently small variance is likely to provide a good approximation of the underlying noise model (in the large data regime). But if you don't have a lot of measurements + it makes sense for the random variables to take values near the boundaries, then this causes issues (which is made clear in the predictions above, where you observed negative values for something that shouldn't be negative).

As for what to do when the Normal doesn't make sense: I would say it's highly problem-dependent, and ideally chosen by someone who has intricate knowledge about how the observations were gathered. In the case of the SEIR model with its data, I'm not entirely certain what to suggest, unfortunately :confused:

For practical purposes, when making predictions, if the inference looks good (despite the model misspecification), you could do something like:

  1. Use TruncatedNormal to ensure the constraints are respected; how much the truncation matters depends on how far the mean is from the boundary (see the sketch after this list).
  2. In the case where there is a relationship between the different variables, e.g. in an SIR model you want the variables to sum to 1, you could clamp and normalize.
  3. Just look at the mean, and ignore the noise, e.g. as is done in the DiffEq + Turing.jl tutorial: https://github.com/TuringLang/TuringTutorials/blob/master/10_diffeq.ipynb. Here they don't use predict but instead just inspect the sampled solutions (from the approximate posterior/chain) to the system rather than the actual sampled observations.
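
For instance, a minimal sketch of option 1 applied to the seir model above (the [0, 1] bounds assume y is a population proportion):

# Option 1, sketched against the seir likelihood from earlier in the thread:
# truncating the Normal keeps sampled observations in [0, 1].
for j = 1:nvars
    for i = 1:size(y, 1)
        y[i, j] ~ truncated(Normal(pred[datavars[j], i], ν[j]), 0.0, 1.0)
    end
end
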
bgroenks96 commented 3 years ago

Thanks for the detailed response, @torfjelde !

I think your AxisArrays solution is reasonable.

If I understand you correctly, you're essentially questioning whether or not a Normal likelihood is correct for these problems where the observations are actually bounded, right?

Yes, that is part of the problem. I suppose TruncatedNormal would solve that. The main issue, though, is numerical problems for small population proportions.

So yes, the state variables must sum to 1 in an SEIR model, and they are normalized by a constant N, the number of individuals in the population.

The problem is that, for large populations, N is on the order of tens or hundreds of millions. Thus, the state variables E and I are often very, very small, e.g. 1.0E-3 or less. Realistic variance in the cumulative number of cases would therefore be on the order of 1.0E-5 or 1.0E-6. This causes numerical instability in the MCMC sampler, and sample fails with:

┌ Warning: Automatic dt set the starting dt as NaN, causing instability.
└ @ OrdinaryDiffEq /home/brian/.julia/packages/OrdinaryDiffEq/OK16j/src/solve.jl:482
┌ Warning: NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.
└ @ DiffEqBase /home/brian/.julia/packages/DiffEqBase/cuMMc/src/integrator_interface.jl:322

followed by a BoundsError.

Ignoring the variance in the predictions would be OK for estimating the mean, but it kind of nullifies part of the point of Bayesian parameter estimation, which is to push epistemic and aleatoric uncertainty forward into the predictions.

devmotion commented 3 years ago

In the case where there is a relationship between the different variables, e.g. in SIR model you want the variables to sum to 1, you could clamp and normalize.

In my opinion, ideally this would be encoded in the model though. For instance, one could model observations with a Dirichlet distribution whose mean is the ODE solution and that could be more or less concentrated around this mean, depending on an inferred (or user-provided) scaling parameter. I am not familiar with the literature about inference of SIR models and their variants but I would assume that someone has tried and used something similar.
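
Roughly, a sketch of what I mean (illustrative only; `sol` stands for the ODE solution and `α` for the inferred concentration parameter):

# Observe the compartment proportions with a Dirichlet whose mean is the
# (normalized) ODE solution at each time point; larger α concentrates the
# observations more tightly around that mean.
α ~ truncated(Normal(0, 100), 0, Inf)
for i in 1:size(y, 2)
    u = sol[:, i] ./ sum(sol[:, i])  # ensure a valid probability vector
    y[:, i] ~ Dirichlet(α .* u)
end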

bgroenks96 commented 3 years ago

Good point @devmotion, the Dirichlet could work. I'll look into that.

torfjelde commented 3 years ago

Thus, the state variables E and I are often very, very small, e.g. 1.0E-3 or less.

I guess to fix this issue you could potentially scale by the population number (or I suppose any other constant) and observe this instead. This should work since the reason for requiring such a small variance is that we're in the boundary region, not that the actual noise has such a low variance. But yeah, this is getting real hacky. And adding to the issue of using a Normal likelihood in the SEIR model: if you're observing all of (S, E, I, R), you're essentially observing one of the variables twice, due to the fact that once you know (S, E, I), R is fully determined.

In my opinion, ideally this would be encoded in the model though.

Definitely agree with this, but it's less clear to me whether a Dirichlet likelihood "makes sense" with how it allocates probability mass.

Another option is to use a NegativeBinomial with the ODE solution as the mean (https://mc-stan.org/users/documentation/case-studies/boarding_school_case_study.html#sampling-distribution) and observe the actual counts rather than the proportions. Note that here they're using a different parameterization of NegativeBinomial than the one available in Distributions.jl (example impl: https://github.com/cambridge-mlg/Covid19/blob/3b1644701ef32063a65fbbc72332ba0eaa22f82b/src/utils.jl#L3-L39). Here I believe it would make the most sense to observe only 3 of the 4 variables in the SEIR model, due to the over-specification reason I mentioned above.
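
A rough sketch of such a mean parameterization (my paraphrase, not the exact code from the linked repo):

using Distributions

# Mean/dispersion ("NB2") parameterization: since
# mean(NegativeBinomial(r, p)) = r(1 - p)/p, choosing r = ϕ and
# p = ϕ/(μ + ϕ) yields a distribution with mean exactly μ.
NegativeBinomial2(μ, ϕ) = NegativeBinomial(ϕ, ϕ / (μ + ϕ))

# Then observe raw counts with the rescaled ODE solution as the mean, e.g.:
# cases[i] ~ NegativeBinomial2(N * pred[1, i], ϕ)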

andreaskoher commented 3 years ago

Though the discussion is somewhat old already, I would add that the SEIR model in your ODE solver is already defined such that S+E+I+R=N, where N is the population size. For the observations we don't usually need to worry about this conservation requirement, mostly because we are not able to observe all compartments at once anyway. In most of the studies that I am aware of, the model is only informed by the case and death counts. However, the latter is not part of your model yet and, to the best of my knowledge, most studies use a negative binomial likelihood, as correctly mentioned by Tor (see for example Gibson et al., Flaxman et al. and the R-package epidemia).

PS. maybe we could write up a simple tutorial similar to the one in Stan as mentioned by Tor - I really like it. PPS. There is also a pretty cool implementation of the Flaxman et al. model in Turing: Covid19. This is a great starting point even though the model is discrete in time.

torfjelde commented 3 years ago

PS. maybe we could write up a simple tutorial similar the one in Stan as mentioned by Tor - I really like it.

A tutorial would be dope! We're currently revamping the tutorial system for Turing.jl (TuringTutorials.jl), and it's getting real close to being done. Once that's gone through, we should def look into making a slightly more detailed tutorial for these sorts of problems.

andreaskoher commented 3 years ago

Really nice! Building on your package Covid19, we have recently implemented a very similar model by Unwin et al. and are exploring applications for Denmark. So let me know if I can contribute anything.

bgroenks96 commented 3 years ago

Thanks @andreaskoher and @torfjelde for the helpful comments. I'm totally not an epidemiologist, and I was just using this as a (somewhat) gentle way of introducing myself to Turing + ODEs. It interested me more than Lotka-Volterra ;)