TuringLang / Turing.jl

Bayesian inference with probabilistic programming.
https://turinglang.org

Parameter values initialization does not speed up NUTS #1509

Closed ClaudMor closed 2 years ago

ClaudMor commented 3 years ago

Hello,

This question is related to this issue, but since it's slightly different I opened a new one. If that was wrong, please let me know.

Thanks to your PR, init_params now seems to work. The problem is that when I set init_params to reasonable values, NUTS does not seem to speed up (I'm not even sure whether it should).

As an example, I will use a model from the tutorials.

Let's first find reasonable parameter values by running NUTS once:


# Imports and reproducibility
using Turing, DifferentialEquations, Distributions
using MCMCChains, Plots, StatsPlots
using Random

Random.seed!(14);

# Model definition

function lotka_volterra(du, u, p, t)
  x, y = u
  α, β, γ, δ = p
  du[1] = (α - β*y)x # dx/dt
  du[2] = (δ*x - γ)y # dy/dt
end
p = [1.5, 1.0, 3.0, 1.0]
u0 = [1.0,1.0]
prob1 = ODEProblem(lotka_volterra,u0,(0.0,10.0),p)
sol = solve(prob1,Tsit5())

# generate synthetic data

sol1 = solve(prob1,Tsit5(),saveat=0.1)
odedata = Array(sol1) + 0.8 * randn(size(Array(sol1)))
plot(sol1, alpha = 0.3, legend = false); scatter!(sol1.t, odedata')

# define priors and Turing model

priors = [truncated(Normal(1.5, 0.5), 0.5, 2.5),
          truncated(Normal(1.2, 0.5), 0, 2),
          truncated(Normal(3.0, 0.5), 1, 4),
          truncated(Normal(1.0, 0.5), 0, 2)]

@model function fitlv2(data, prob1, priors) 
    σ ~ InverseGamma(2, 3)
    p ~ arraydist(priors)  #priors

    prob = remake(prob1, p=p)
    predicted = solve(prob,Tsit5(),saveat=0.1)

    for i = 1:length(predicted)
        data[:,i] ~ MvNormal(predicted[i], σ)
    end
end

model1= fitlv2(odedata, prob1, priors)

Next, we sample one chain of 10,000 iterations:

chain1 = @time sample(model1, NUTS(0.45), 10000, progress = true) 
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)
└ @ AdvancedHMC C:\Users\claud\.julia\packages\AdvancedHMC\MIxdK\src\hamiltonian.jl:47
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)
└ @ AdvancedHMC C:\Users\claud\.julia\packages\AdvancedHMC\MIxdK\src\hamiltonian.jl:47
┌ Info: Found initial step size
│   ϵ = 0.00625
└ @ Turing.Inference C:\Users\claud\.julia\packages\Turing\O1Pn0\src\inference\hmc.jl:195
Sampling:   1%|█                                        |  ETA: 0:00:45
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)
└ @ AdvancedHMC C:\Users\claud\.julia\packages\AdvancedHMC\MIxdK\src\hamiltonian.jl:47
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)
└ @ AdvancedHMC C:\Users\claud\.julia\packages\AdvancedHMC\MIxdK\src\hamiltonian.jl:47

(I sometimes get a lot of these warnings, so I'll omit the rest of the output.)
Chains MCMC chain (10000×17×1 Array{Float64,3}):

Iterations        = 1:10000
Thinning interval = 1
Chains            = 1
Samples per chain = 10000
parameters        = p[1], p[2], p[3], p[4], σ
internals         = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat 
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64 

        p[1]    1.5593    0.0519     0.0005    0.0034   157.5147    1.0125
        p[2]    1.0944    0.0526     0.0005    0.0027   260.4492    1.0069
        p[3]    2.8713    0.1367     0.0014    0.0091   149.1489    1.0130
        p[4]    0.9356    0.0491     0.0005    0.0032   153.1249    1.0133
           σ    0.8191    0.0390     0.0004    0.0029    90.4143    1.0031

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

        p[1]    1.4592    1.5257    1.5597    1.5916    1.6629
        p[2]    0.9947    1.0588    1.0934    1.1283    1.2040
        p[3]    2.6195    2.7795    2.8644    2.9541    3.1642
        p[4]    0.8455    0.9029    0.9330    0.9658    1.0366
           σ    0.7521    0.7923    0.8169    0.8423    0.9039

And it took 29.285691 seconds (226.37 M allocations: 17.156 GiB, 8.58% gc time).

You can plot the chain:

plot(chain1)

If I initialize the same sampler with the parameter values found in this run, I get:

# parameter values initialization: [σ, p], in the order the parameters appear in the model
const initials = [0.8191, [1.5593, 1.0944, 2.8713, 0.9356]]

# sample with initial parameter values
chain2 = @time sample(model1, NUTS(0.45), 10000,  progress = true, init_params = initials) 
┌ Info: Found initial step size
│   ϵ = 0.025
└ @ Turing.Inference C:\Users\claud\.julia\packages\Turing\O1Pn0\src\inference\hmc.jl:195
Sampling: 100%|█████████████████████████████████████████| Time: 0:00:31
 Chains MCMC chain (10000×17×1 Array{Float64,3}):

Iterations        = 1:10000
Thinning interval = 1
Chains            = 1
Samples per chain = 10000
parameters        = p[1], p[2], p[3], p[4], σ
internals         = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat 
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64 

        p[1]    1.5557    0.0497     0.0005    0.0029   199.8191    1.0044
        p[2]    1.0904    0.0524     0.0005    0.0028   303.4473    1.0039
        p[3]    2.8795    0.1307     0.0013    0.0076   198.7363    1.0050
        p[4]    0.9385    0.0466     0.0005    0.0027   198.9049    1.0042
           σ    0.8115    0.0416     0.0004    0.0023   254.6818    1.0009

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

        p[1]    1.4704    1.5191    1.5521    1.5867    1.6649
        p[2]    1.0020    1.0526    1.0865    1.1229    1.2036
        p[3]    2.6126    2.7898    2.8834    2.9747    3.1233
        p[4]    0.8425    0.9077    0.9393    0.9712    1.0271
           σ    0.7372    0.7834    0.8096    0.8380    0.8996

And it took 31.805427 seconds (268.14 M allocations: 20.313 GiB, 8.99% gc time).

I observe the same behaviour with more complex underlying DifferentialEquations.jl models (which require more time to calibrate), while keeping the same Turing.jl model.

Should I expect parameter initialization to actually speed up the process?

Side question: roughly how many model parameters should induce one to use ADVI instead of NUTS?

Thanks very much for your attention.

devmotion commented 3 years ago

I would not expect the initialization to have an impact on the computation time - the same number of calculations has to be performed in both cases. However, I would assume that the step size and the number of accepted steps could/should be affected by a different initialization, which could result in, e.g., better mixing. The higher ESS values in the run with a custom initialization support this intuition.
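For what it's worth, a rough way to compare the two runs is to look at the adaptation statistics stored in the chains' internals (a sketch, assuming chain1 and chain2 from above are still in scope):

using MCMCChains, Statistics

# average acceptance rate and step size of each run
mean(chain1[:acceptance_rate]), mean(chain2[:acceptance_rate])
mean(chain1[:step_size]), mean(chain2[:step_size])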

devmotion commented 3 years ago

BTW, in some sense the improved mixing leads to a relative speed-up: to achieve the same effective sample size with a random initialization you would need a longer chain, which of course would take more time to sample.
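Using the timings and ESS values reported above, this relative speed-up can be quantified as effective samples per second, e.g. for σ:

# ESS per second for σ, taken from the summaries and timings above
ess_per_sec_random = 90.4143 / 29.285691    # ≈ 3.1 (random initialization)
ess_per_sec_init   = 254.6818 / 31.805427   # ≈ 8.0 (custom initialization)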

ClaudMor commented 3 years ago

Thank you very much, @devmotion.

I know it's a bit off topic, but do you have any advice on the approximate minimum number of model parameters at which one should use ADVI instead of NUTS?

Here, a value of 12 is cited, but I'm not sure whether it refers to the number of parameters or to something else.

devmotion commented 3 years ago

I don't have any good heuristics here. As far as I can tell, the number 12 in the documentation is completely arbitrary and refers to the number of samples in the Markov chain. The main point there seems to be that MCMC methods provide exactness guarantees as the number of samples goes to infinity, but the number of samples required for a good approximation of the posterior (e.g. in the sense that the estimate of the expectation of some functional is reasonably close to the true value) might be prohibitively large. Therefore one might sometimes prefer an approximate method such as VI over MCMC methods.
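For reference, a minimal ADVI sketch for the same model (the ADVI settings below are illustrative, not recommendations):

using Turing

# ADVI(samples_per_step, max_iters); both values are illustrative
advi = ADVI(10, 1000)
q = vi(model1, advi)          # fit a mean-field variational approximation
vi_samples = rand(q, 1000)    # draw samples from the fitted approximation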