TuringLang / Turing.jl

Bayesian inference with probabilistic programming.
https://turinglang.org
MIT License

Very high initial values for standard deviation #2085

Closed tiemvanderdeure closed 1 year ago

tiemvanderdeure commented 1 year ago

When fitting hierarchical models in Turing, I find that the initial values for standard deviations in the chain are sometimes extremely high, which can leave chains with very wrong estimates for many hundreds of iterations.

The example below is a model to estimate the rate r at which y approaches 1. r and x are always positive, so y is always between 0 and 1.

In this example, the highest initial value for error that I get across 30 chains is 6.6, which is 13 standard deviations away from the prior, but I've encountered far more egregious examples with more complex models.

If I set initial values close to the actual values, I get a pretty healthy chain and correct estimates.

Is there just an issue with my model, or is something else going on here?

using Turing, Distributions, Random, StatsPlots
Random.seed!(1234)

N_individuals = 10
N_populations = 2
N_observations = 200

id_pop = rand(1:N_populations, N_individuals) # assign a population to each individual
id = rand(1:N_individuals, N_observations) # assign an individual to each observation

# Hyperparameters
r_mu = 1.0
r_std_pop = r_std_indiv = 0.1
error = 0.01

# Assign r values to populations
r_pop = rand(Normal(r_mu, r_std_pop), N_populations)
# Assign r values to individuals
r = rand(MvNormal(r_pop[id_pop], r_std_indiv))

# Simulate observations
x = rand(Uniform(0, 2), N_observations)
y = 1.0 .- exp.(-r[id] .* x) .+ rand(Normal(0, error), N_observations)

@model function multilevel_model(id_pop, id, x, y, N_p)
    # Hyperpriors
    r_mu ~ Normal(1.0, 0.1)
    r_std_pop ~ truncated(Normal(0, 0.5); lower = 0.0)
    r_std_indiv ~ truncated(Normal(0, 0.5); lower = 0.0)
    error ~ truncated(Normal(0, 0.5); lower = 0.0)    

    # Population level
    r_pop ~ filldist(Normal(r_mu, r_std_pop), N_p)

    # Individual level
    r ~ arraydist(truncated.(Normal.(r_pop[id_pop], r_std_indiv); lower = 0.0))   

    # Sampling
    y ~ MvNormal(1.0 .- exp.(-r[id] .* x), error)
end

m = multilevel_model(id_pop, id, x, y, N_populations)
chains = [sample(m, HMC(0.001, 10), 100) for i in 1:30]
highest_error = findmax([maximum(c[:error]) for c in chains])

c = chains[highest_error[2]]

r_init = Array(group(c, :r))[1,:,1]

# Mean Squared Error
mean(((1.0 .- exp.(-r_init[id] .* x)) .- y).^2)
harisorgn commented 1 year ago

Is there just an issue with my model, or is something else going on here?

Hey there 👋 , something you could do to (most likely) improve your model is to use the non-centered parameterisation for the parameters inside the hierarchy, to avoid funnel-shaped pathologies. Check this case study for a more thorough investigation of the issue: https://betanalpha.github.io/assets/case_studies/hierarchical_modeling.html

The implementation is in Stan, but the theory is the same.

So I would try something like this :

@model function ...
...
# Population level
ε_pop ~ filldist(Normal(0,1), N_p)
r_pop = @. r_mu + ε_pop * r_std_pop
...
# Individual level
# I think you might still get funnels with the truncated distribution because it is a truncated Normal
# and the relationship between its std parameter and the sampled values is the same,
# so better change this one too.
ε_indiv ~ filldist(truncated(Normal(0,1); lower = 0.0), length(id)) 
r = @. r_pop[id_pop] + ε_indiv * r_std_indiv
...
end

If the non-centered parameterisation helps you avoid funnels, then the NUTS sampler will also be able to sample your model more effectively (chain health + speed). The annoying thing with this parameterisation is that r_pop and r are no longer parameters in your chain object. But you can always calculate them for each sample after you are done, by adding a return statement at the end of your model and using generated_quantities.
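
A rough sketch of that last step (hedged: the model below is illustrative, not the thread's full model, and generated_quantities was the exported name at the time of this thread):

```julia
using Turing

# Illustrative model: the derived quantity r_pop is rebuilt from the
# non-centered parameters and returned so it can be recovered later.
@model function noncentered_demo(N_p)
    r_mu ~ Normal(1.0, 0.1)
    r_std_pop ~ truncated(Normal(0, 0.5); lower = 0.0)
    ε_pop ~ filldist(Normal(0, 1), N_p)
    r_pop = @. r_mu + ε_pop * r_std_pop
    return (; r_pop)  # make the derived quantity recoverable
end

demo = noncentered_demo(2)
chain = sample(demo, NUTS(), 500)

# Re-evaluate the model's return value for every posterior sample:
gq = generated_quantities(demo, chain)  # one NamedTuple per sample
```

In recent DynamicPPL versions this function has been renamed to returned, but the idea is the same.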

tiemvanderdeure commented 1 year ago

Thanks for the link, I guess you're right that a non-centered parameterisation would be preferred in this case.

As a side note, I think a non-centered distribution would be a little more complicated, since this line ε_indiv ~ filldist(truncated(Normal(0,1); lower = 0.0), length(id)) says all r must be above average, not above 0. Or am I missing something?

When I try it, a non-centered model still produces similarly high values for error, so I assume something else must be causing this. What happens when you run it?

harisorgn commented 1 year ago

As a side note, I think a non-centered distribution would be a little more complicated, since this line ε_indiv ~ filldist(truncated(Normal(0,1); lower = 0.0), length(id)) says all r must be above average, not above 0. Or am I missing something?

You are right, sorry, I rushed it.

When I try it, a non-centered model still produces similarly high values for error, so I assume something else must be causing this. What happens when you run it?

Could be, my suggestion is not directly addressing the error issue, it was more of a general comment for the other standard deviation parameters. Sorry, should have specified. I'll run it asap and see.

harisorgn commented 1 year ago

This model runs for me in 2.5-4.1 seconds with accurate estimates of error (and of the other parameters):

@model function multilevel_model(id_pop, id, x, y, N_p)
    # Hyperpriors
    r_mu ~ Normal(1.0, 0.1)
    r_std_pop ~ truncated(Normal(0, 0.5); lower = 0.0)
    r_std_indiv ~ truncated(Normal(0, 0.5); lower = 0.0)
    error ~ truncated(Normal(0, 0.5); lower = 0.0)    

    # Population level
    ε_pop ~ filldist(Normal(0,1), N_p)
    r_pop = @. r_mu + ε_pop * r_std_pop

    # Individual level
    r ~ arraydist(truncated.(Normal.(r_pop[id_pop], r_std_indiv); lower = 0.0))   

    # Sampling
    y ~ MvNormal(1.0 .- exp.(-r[id] .* x), error)
end

m = multilevel_model(id_pop, id, x, y, N_populations)
chains = sample(m, NUTS(), 2000)

I didn't reparameterise the individual level. Changing back to a centered parameterisation at the population level will also give you good estimates of error, but the chain doesn't look as healthy in the other standard deviation parameters, where funneling can happen. This depends on how strong the likelihood term is; with real data I would expect it to be weaker, so non-centered would be safer.

But the main thing I changed here was using the NUTS sampler, so I don't have to tune the number and size of leapfrog steps and can rely on a more automated heuristic to sample.

EDIT: Forgot the output :

┌ Info: Found initial step size
└   ϵ = 0.2
Chains MCMC chain (2000×28×1 Array{Float64, 3}):

Iterations        = 1001:1:3000
Number of chains  = 1
Samples per chain = 2000
Wall duration     = 3.99 seconds
Compute duration  = 3.99 seconds
parameters        = r_mu, r_std_pop, r_std_indiv, error, ε_pop[1], ε_pop[2], r[1], r[2], r[3], r[4], r[5], r[6], r[7], r[8], r[9], r[10]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
   parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ess_per_sec 
       Symbol   Float64   Float64   Float64     Float64     Float64   Float64       Float64

         r_mu    0.9252    0.0730    0.0050    263.0505    397.5729    1.0102       65.9274
    r_std_pop    0.1640    0.1784    0.0252    114.3925     36.0098    1.0166       28.6698
  r_std_indiv    0.1369    0.0443    0.0018    899.7832    792.7033    1.0081      225.5096
        error    0.0099    0.0006    0.0000    186.6455     57.1569    1.0039       46.7783
     ε_pop[1]   -0.2700    0.7084    0.0208   1154.7310    913.0814    1.0066      289.4063
     ε_pop[2]   -0.2498    0.6915    0.0211   1041.6230    922.7823    1.0047      261.0584
         r[1]    0.8389    0.0063    0.0002   1157.0100    839.1514    1.0002      289.9774
         r[2]    0.7822    0.0052    0.0001   1753.4311   1119.7007    0.9995      439.4564
         r[3]    1.0471    0.0099    0.0003    852.1786   1040.6588    1.0007      213.5786
         r[4]    0.9027    0.0081    0.0005    291.4681    177.5084    1.0042       73.0496
         r[5]    0.7816    0.0055    0.0002   1171.1337   1228.7370    1.0050      293.5172
         r[6]    0.7668    0.0046    0.0001   1615.3746   1324.3256    1.0058      404.8558
       ⋮           ⋮         ⋮         ⋮          ⋮           ⋮          ⋮           ⋮
                                                                               4 rows omitted
tiemvanderdeure commented 1 year ago

It seems this issue is related to choosing bad step sizes with the HMC sampler. I don't get it with the NUTS sampler or when sampling with HMC(0.015, 10), but the problem pops up with either HMC(0.02, 10) or HMC(0.001, 10). I don't understand the HMC sampler well enough to explain why this happens, but I guess it's not really a bug then.
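
One way to see step-size trouble directly (a sketch; the internal names are the ones listed in the NUTS output above, and I'm assuming the same fields are recorded for plain HMC chains):

```julia
using Turing, Statistics

# chains = [sample(m, HMC(0.001, 10), 100) for i in 1:30]  # as in the original script
# :numerical_error flags iterations where the leapfrog integrator diverged,
# a classic symptom of a mis-tuned step size.
for c in chains
    n_div = sum(c[:numerical_error])
    acc = mean(c[:acceptance_rate])
    n_div > 0 && println("divergences = $(Int(n_div)), mean acceptance = $(round(acc; digits = 2))")
end
```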

harisorgn commented 1 year ago

It doesn't look like a bug to me. Both the step size and the number of steps are essential parameters for HMC sampling, and, depending on the geometry of the posterior, inference results can be sensitive to changing them. That's why NUTS is generally recommended; it's a sampler from the same family, building upon HMC.

torfjelde commented 1 year ago

In this example, the highest initial value for error I get from 30 chains is 6.6, which is 13 standard deviations away from the prior

When using HMC, we don't use the prior for initialization; instead, we sample from Uniform(-2, 2) in the unconstrained space, which is a very common approach (e.g. https://mc-stan.org/docs/reference-manual/initialization.html#random-initial-values). This means that getting something like 6.6 or higher for a positively constrained variable occurs with much greater probability than under the prior:

1 - cdf(Uniform(-2, 2), log(6.6))

is roughly ~0.03, i.e. a ~3% probability of this occurring.
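
That tail probability can be recomputed by hand without Distributions, since the Uniform(-2, 2) density is constant at 1/4:

```julia
# P(U > log(6.6)) for U ~ Uniform(-2, 2), computed directly:
threshold = log(6.6)            # ≈ 1.887, the unconstrained value mapping to 6.6
p = (2 - threshold) / 4         # tail length times the density 1/4
println(round(p; digits = 4))   # 0.0282, i.e. ~3%
```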

For reference, a quick way to initialize your sampler using the prior is to just do:

rand(Vector, m)

and pass this to init_params in the sample call, i.e.

sample(m, HMC(0.001, 10), 100, init_params = rand(Vector, m))

As a side-note, if you're wondering if there's something strange going on with the initial values, one immediate thing you can do to inspect things a bit more is to manually step through the sampling procedure.

rng = Random.Xoshiro(42)
# We have to wrap the sampler in `Sampler` because that's what the `step` function expects.
spl = DynamicPPL.Sampler(HMC(0.001, 10))

# Take the initial step.
transition, state = AbstractMCMC.step(rng, m, spl);
# Extract the parameters.
DynamicPPL.OrderedDict(Turing.Inference.getparams(m, transition))

If I run this, I get the following output

OrderedCollections.OrderedDict{AbstractPPL.VarName, Float64} with 16 entries:
  r_mu        => 1.07888
  r_std_pop   => 0.830055
  r_std_indiv => 0.913634
  error       => 2.2527
  r_pop[1]    => 0.500225
  r_pop[2]    => 1.02133
  r[1]        => 0.647186
  r[2]        => 3.36457
  r[3]        => 5.55017
  r[4]        => 3.80987
  r[5]        => 1.9428
  r[6]        => 0.234109
  r[7]        => 0.384531
  r[8]        => 1.97316
  r[9]        => 2.34884
  r[10]       => 0.263924

From here you can also manually step through the sampling process:

# Now we just provide the previous state
transition_next, state_next = AbstractMCMC.step(rng, m, spl, state)
torfjelde commented 1 year ago

But yes, as @harisorgn says, always use NUTS :) Alternatively, you can use HMCDA, which is HMC with step-size adaptation.
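
For reference, a hedged sketch of HMCDA usage (the constructor arguments follow the example in the Turing docs: number of adaptation steps, target acceptance rate, and target trajectory length):

```julia
using Turing

# m is the model defined earlier in the thread.
# HMCDA adapts the leapfrog step size over 200 warm-up iterations,
# targeting a 0.65 acceptance rate with trajectory length 0.3.
chain = sample(m, HMCDA(200, 0.65, 0.3), 1000)
```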

tiemvanderdeure commented 1 year ago

Thank you for such an elaborate explanation @torfjelde. There are many good reasons to use NUTS over HMC, it seems.