Closed tiemvanderdeure closed 1 year ago
Is there just an issue with my model, or is something else going on here?
Hey there 👋 , something you could do to (most likely) improve your model is to use the non-centered parameterisation for the parameters inside the hierarchy, to avoid funnel-shaped pathologies. Check this case study for a more thorough investigation of the issue: https://betanalpha.github.io/assets/case_studies/hierarchical_modeling.html
The implementation is in Stan, but the theory is the same.
So I would try something like this:
@model function ...
...
# Population level
ε_pop ~ filldist(Normal(0,1), N_p)
r_pop = @. r_mu + ε_pop * r_std_pop
...
# Individual level
# I think you might still get funnels with the truncated distribution because it is a truncated Normal
# and the relationship between its std parameter and the sampled values is the same,
# so better change this one too.
ε_indiv ~ filldist(truncated(Normal(0,1); lower = 0.0), length(id))
r = @. r_pop[id_pop] + ε_indiv * r_std_indiv
...
end
If you use the non-centered parameterisation and that helps you avoid funnels, then the NUTS sampler would also be able to sample your model more effectively (chain health + speed). The annoying thing with this parameterisation is that r_pop and r are not parameters in your chain object. But you can always calculate them for each sample after you are done, by adding a return at the end of your model and using generated_quantities.
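A minimal sketch of that return + generated_quantities pattern, using a made-up toy model rather than the one from this thread (function names from Turing.jl; the model itself is illustrative only):

```julia
using Turing

@model function toy(y)
    σ ~ truncated(Normal(0, 1); lower = 0.0)
    ε ~ Normal(0, 1)
    μ = ε * σ        # non-centered: μ is derived, so it does not appear in the chain
    y ~ Normal(μ, 1.0)
    return (; μ)     # return the derived quantity so it can be recovered later
end

m = toy(0.5)
chain = sample(m, NUTS(), 1000)

# Recompute μ for every posterior sample after sampling is done:
μs = generated_quantities(m, chain)
```

Each element of `μs` is the model's return value evaluated at one posterior sample, so you get the derived parameters back without them cluttering the chain object.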
Thanks for the link, I guess you're right that a non-centered parameterisation would be preferred in this case.
As a side note, I think a non-centered distribution would be a little more complicated, since this line ε_indiv ~ filldist(truncated(Normal(0,1); lower = 0.0), length(id)) says all r must be above average, not above 0. Or am I missing something?
When I try it, a non-centered model still produces similarly high values for error, so I assume something else must be causing this. What happens when you run it?
As a side note, I think a non-centered distribution would be a little more complicated, since this line ε_indiv ~ filldist(truncated(Normal(0,1); lower = 0.0), length(id)) says all r must be above average, not above 0. Or am I missing something?
You are right, sorry, I rushed it.
When I try it, a non-centered model still produces similarly high values for error, so I assume something else must be causing this. What happens when you run it?
Could be, my suggestion is not directly addressing the error issue, it was more of a general comment for the other standard deviation parameters. Sorry, should have specified. I'll run it asap and see.
So this model for me runs in 2.5-4.1 seconds with accurate estimates of error (and the other parameters):
@model function multilevel_model(id_pop, id, x, y, N_p)
# Hyperpriors
r_mu ~ Normal(1.0, 0.1)
r_std_pop ~ truncated(Normal(0, 0.5); lower = 0.0)
r_std_indiv ~ truncated(Normal(0, 0.5); lower = 0.0)
error ~ truncated(Normal(0, 0.5); lower = 0.0)
# Population level
ε_pop ~ filldist(Normal(0,1), N_p)
r_pop = @. r_mu + ε_pop * r_std_pop
# Individual level
r ~ arraydist(truncated.(Normal.(r_pop[id_pop], r_std_indiv); lower = 0.0))
# Sampling
y ~ MvNormal(1.0 .- exp.(-r[id] .* x), error)
end
m = multilevel_model(id_pop, id, x, y, N_populations)
chains = sample(m, NUTS(), 2000)
I didn't reparameterise the individual level. Changing back to a centered parameterisation on the population level will also give you good estimates of error, but the chain doesn't look as healthy in the other standard deviation parameters, where funneling can happen. This depends on how strong the likelihood term is, and with real data I would expect it to be weaker, so non-centered would be safer.
But the main thing that I changed here was using the NUTS sampler, so I don't have to tune the number and size of leapfrog steps and can use a more automated heuristic to sample.
EDIT: Forgot the output:
┌ Info: Found initial step size
└   ϵ = 0.2
Chains MCMC chain (2000×28×1 Array{Float64, 3}):
Iterations = 1001:1:3000
Number of chains = 1
Samples per chain = 2000
Wall duration = 3.99 seconds
Compute duration = 3.99 seconds
parameters = r_mu, r_std_pop, r_std_indiv, error, ε_pop[1], ε_pop[2], r[1], r[2], r[3], r[4], r[5], r[6], r[7], r[8], r[9], r[10]
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
r_mu 0.9252 0.0730 0.0050 263.0505 397.5729 1.0102 65.9274
r_std_pop 0.1640 0.1784 0.0252 114.3925 36.0098 1.0166 28.6698
r_std_indiv 0.1369 0.0443 0.0018 899.7832 792.7033 1.0081 225.5096
error 0.0099 0.0006 0.0000 186.6455 57.1569 1.0039 46.7783
ε_pop[1] -0.2700 0.7084 0.0208 1154.7310 913.0814 1.0066 289.4063
ε_pop[2] -0.2498 0.6915 0.0211 1041.6230 922.7823 1.0047 261.0584
r[1] 0.8389 0.0063 0.0002 1157.0100 839.1514 1.0002 289.9774
r[2] 0.7822 0.0052 0.0001 1753.4311 1119.7007 0.9995 439.4564
r[3] 1.0471 0.0099 0.0003 852.1786 1040.6588 1.0007 213.5786
r[4] 0.9027 0.0081 0.0005 291.4681 177.5084 1.0042 73.0496
r[5] 0.7816 0.0055 0.0002 1171.1337 1228.7370 1.0050 293.5172
r[6] 0.7668 0.0046 0.0001 1615.3746 1324.3256 1.0058 404.8558
⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮
4 rows omitted
It seems this issue is related to choosing bad step sizes with the HMC sampler. I also don't get it with the NUTS sampler or when sampling with HMC(0.015, 10). But the problem pops up when using either HMC(0.02, 10) or HMC(0.001, 10). I don't really understand the HMC sampler well enough to explain why this happens, but I guess it's not really a bug then.
It doesn't look like a bug to me. Both step size and number of steps are essential parameters for HMC sampling and, depending on the geometry of the posterior, inference results can be sensitive to changing them. That's why NUTS is recommended more often; it's a sampler from the same family, building upon HMC.
In this example, the highest initial value for error I get from 30 chains is 6.6, which is 13 standard deviations away from the prior
When using HMC, we don't use the prior for initialization; instead we sample from Uniform(-2, 2) in the unconstrained space, which is a very common approach (e.g. https://mc-stan.org/docs/reference-manual/initialization.html#random-initial-values). This means that getting something like 6.6 or higher for a positively constrained variable can occur with much greater probability than under the prior: 1 - cdf(Uniform(-2, 2), log(6.6)) is roughly 0.03, i.e. ~3% probability of this occurring.
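This back-of-the-envelope check can be run directly; a sketch using Distributions.jl (the exp transform for positively constrained variables is the standard one used in the unconstrained space):

```julia
using Distributions

# A positively constrained variable is initialized as exp(u) with
# u ~ Uniform(-2, 2) in the unconstrained space, so
# P(exp(u) ≥ 6.6) = P(u ≥ log(6.6)).
p = 1 - cdf(Uniform(-2, 2), log(6.6))
# p ≈ 0.028, i.e. roughly a 3% chance per chain
```

With 30 chains, seeing at least one such extreme initial value is therefore quite likely, which matches the observation above.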
For reference, a quick way to initialize your sampler using the prior is to just do rand(Vector, m) and pass this to init_params in the sample call, i.e.
sample(m, HMC(0.001, 10), 100, init_params = rand(Vector, m))
As a side-note, if you're wondering if there's something strange going on with the initial values, one immediate thing you can do to inspect things a bit more is to manually step through the sampling procedure.
rng = Random.Xoshiro(42)
# We have to wrap the sampler in `Sampler` because that's what the `step` function expects.
spl = DynamicPPL.Sampler(HMC(0.001, 10))
# Take the initial step.
transition, state = AbstractMCMC.step(rng, m, spl);
# Extract the parameters.
DynamicPPL.OrderedDict(Turing.Inference.getparams(m, transition))
If I run this, I get the following output
OrderedCollections.OrderedDict{AbstractPPL.VarName, Float64} with 16 entries:
r_mu => 1.07888
r_std_pop => 0.830055
r_std_indiv => 0.913634
error => 2.2527
r_pop[1] => 0.500225
r_pop[2] => 1.02133
r[1] => 0.647186
r[2] => 3.36457
r[3] => 5.55017
r[4] => 3.80987
r[5] => 1.9428
r[6] => 0.234109
r[7] => 0.384531
r[8] => 1.97316
r[9] => 2.34884
r[10] => 0.263924
From here you can also manually step through the sampling process:
# Now we just provide the previous state
transition_next, state_next = AbstractMCMC.step(rng, m, spl, state)
But yes, as @harisorgn says, always use NUTS :) Alternatively, you can also use HMCDA, which is HMC but with step-size adaptation.
Thank you for such an elaborate explanation @torfjelde. There are many good reasons to use NUTS over HMC, it seems.
When fitting hierarchical models in Turing, I find that the initial values for standard deviations in the chain are sometimes extremely high, which can give chains with very wrong estimates for many hundreds of iterations.
The example below is a model to estimate the rate r at which y approaches 1. r and x are always positive, so y is always between 0 and 1.
In this example, the highest initial value for error I get from 30 chains is 6.6, which is 13 standard deviations away from the prior, but I've encountered way more egregious examples with more complex models.
If I set initial values close to the actual values, I get a pretty healthy chain and correct estimates.
Is there just an issue with my model, or is something else going on here?