TuringLang / Turing.jl

Bayesian inference with probabilistic programming.
https://turinglang.org
MIT License

Wishart priors resulting in `PosDefException: matrix is not positive definite; Cholesky factorization failed` #2188

Open ipozdeev opened 7 months ago

ipozdeev commented 7 months ago

Julia v1.10.2, Turing v0.30.7, Distributions v0.25.107

Using Wishart priors causes several functions to throw the above error, which does not make sense to me; e.g., in this MWE, both maximisation (MAP) and HMC sampling fail:

using Turing, MCMCChains
using Statistics, LinearAlgebra, PDMats
using Optim

# parameter of the Wishart prior
A = Matrix{Float64}(I, 3, 3);
isposdef(A)  # true
ishermitian(A)  # true

@model function demo(x)
    _A ~ Wishart(5, A)
    _x_mu = sum(_A)
    x ~ Normal(_x_mu, 1)
end

# condition model on single obs
demo_model = demo(1.0);

map_estimate = optimize(demo_model, MAP());       # error
chain = sample(demo_model, HMC(0.05, 10), 1000);  # error

chain = sample(demo_model, MH(), 1000);           # no error

MAP() throws an error, as does sampling with HMC and NUTS, but not with MH.
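A likely reason MH is unaffected: its default proposal draws `_A` directly from the prior, so every proposed matrix is an exact Wishart sample, whereas gradient-based methods (HMC, NUTS, LBFGS for MAP) work in an unconstrained space and reconstruct the matrix through a bijector, where floating-point round-off can push it marginally outside the positive-definite cone. A minimal, non-Turing illustration of how small that margin can be:

```julia
using LinearAlgebra

# A matrix that is "almost" the identity: one pivot dips below zero
# by a single machine epsilon, as round-off in a reconstruction can cause.
M = Matrix{Float64}(I, 3, 3)
M[1, 1] = -eps()

isposdef(M)   # false
cholesky(M)   # throws PosDefException: matrix is not positive definite
```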

yebai commented 5 months ago

Likely related to https://github.com/TuringLang/Bijectors.jl/pull/313

yebai commented 5 months ago

cc @sethaxen @torfjelde

yebai commented 5 months ago

The above example now works with the most recent Bijectors.jl release. A small step size is useful for stable optimisation or HMC sampling:

julia> chain = sample(demo_model, HMC(0.01, 10), 10000); 
Sampling 100%|████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:00

I am not sure how to specify step size for optimisation. @mhauru is that possible?

yebai commented 5 months ago

These numerical errors can be explicitly caught and used to tell the inference backend to reject the proposal.


julia> @model function demo(x)
           try
               _A ~ Wishart(5, A)
               _x_mu = sum(_A)
               x ~ Normal(_x_mu, 1)
           catch e
               if e isa PosDefException
                   Turing.@addlogprob! -Inf  # reject this proposal
               else
                   rethrow(e)  # don't silently swallow unrelated errors
               end
           end
       end

julia> chain = sample(demo(1), NUTS(), 1000)
┌ Info: Found initial step size
└   ϵ = 0.0125
Sampling 100%|████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:00
Chains MCMC chain (1000×21×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 0.19 seconds
Compute duration  = 0.19 seconds
parameters        = _A[1, 1], _A[2, 1], _A[3, 1], _A[1, 2], _A[2, 2], _A[3, 2], _A[1, 3], _A[2, 3], _A[3, 3]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

    _A[1, 1]    3.6504    2.3821    0.0935   546.6195   372.3050    1.0003     2907.5507
    _A[2, 1]   -1.4789    1.7435    0.0766   565.4885   588.1340    1.0066     3007.9174
    _A[3, 1]   -1.5504    1.7307    0.1053   254.4582   458.9476    1.0133     1353.5010
    _A[1, 2]   -1.4789    1.7435    0.0766   565.4885   588.1340    1.0066     3007.9174
    _A[2, 2]    3.5547    2.2049    0.0896   549.1068   560.1294    1.0012     2920.7810
    _A[3, 2]   -1.4334    1.6267    0.0781   466.9718   495.1435    0.9991     2483.8925
    _A[1, 3]   -1.5504    1.7307    0.1053   254.4582   458.9476    1.0133     1353.5010
    _A[2, 3]   -1.4334    1.6267    0.0781   466.9718   495.1435    0.9991     2483.8925
    _A[3, 3]    3.5786    2.2005    0.1018   556.3600   493.4034    1.0026     2959.3615

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

    _A[1, 1]    0.5643    1.8723    3.1353    4.8230    9.5504
    _A[2, 1]   -5.6961   -2.3031   -1.2180   -0.3144    1.1457
    _A[3, 1]   -5.8425   -2.3965   -1.2686   -0.4025    1.0603
    _A[1, 2]   -5.6961   -2.3031   -1.2180   -0.3144    1.1457
    _A[2, 2]    0.5728    1.9093    3.1736    4.6654    8.8305
    _A[3, 2]   -5.1991   -2.3551   -1.2366   -0.3020    1.2692
    _A[1, 3]   -5.8425   -2.3965   -1.2686   -0.4025    1.0603
    _A[2, 3]   -5.1991   -2.3551   -1.2366   -0.3020    1.2692
    _A[3, 3]    0.7432    1.9730    3.1545    4.6346    9.2545
mhauru commented 5 months ago

I am not sure how to specify step size for optimisation. @mhauru is that possible?

Depends on the optimisation algorithm, but if the algorithm has a notion of a step size, then usually yes. The default algorithm is LBFGS, which first finds a direction to go in and then does a line search along that direction to figure out how far to go, so there isn't a fixed step size. You can, however, set an initial guess for the step size like this:

optimize(demo_model, MAP(), Optim.LBFGS(;alphaguess=0.01));

That seems to help avoid the loss-of-positivity errors in this case.

sethaxen commented 5 months ago

@yebai I don't think anything I did to fix the correlation bijectors would have fixed this. I'm not at a computer right now, but I imagine the problem is similar: the Jacobian computation goes through a Cholesky decomposition, which is wasteful and can randomly fail due to floating-point errors. The solution is to make the same fix for the covariance bijector.

The other place a PosDefException would be raised randomly is if one used the Wishart matrix as the covariance of an MvNormal. The solution there is the same as for LKJ: add a WishartCholesky to Distributions.jl and a VecCovBijector to Bijectors.jl. Same goes for InverseWishart.
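Until such a WishartCholesky exists, one pattern that avoids factorising a dense covariance matrix inside the model is to parameterise the Cholesky factor directly via LKJCholesky. This is only a sketch under stated assumptions: it replaces the Wishart prior with an LKJ correlation prior plus separate scale priors, so it is a different model, not an equivalent one, and the model name and priors below are illustrative:

```julia
using Turing, LinearAlgebra, PDMats

# Hypothetical 3-dimensional model: sample scales and a correlation
# Cholesky factor separately, so no dense covariance matrix is ever
# Cholesky-factorised and no PosDefException can be thrown there.
@model function mvn_demo(y)
    σ ~ filldist(truncated(Normal(0, 1); lower=0), 3)  # per-dimension scales
    F ~ LKJCholesky(3, 1.0)                            # correlation Cholesky factor
    Σ_L = LowerTriangular(σ .* F.L)                    # Cholesky factor of Diag(σ) * C * Diag(σ)
    y ~ MvNormal(zeros(3), PDMat(Cholesky(Σ_L)))       # covariance built from its factor
end
```

The key point is that `MvNormal` receives a `PDMat` constructed from an existing `Cholesky` object, so positive definiteness holds by construction rather than being re-verified numerically.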

yebai commented 5 months ago

Thanks, @sethaxen, for the clarification. For now, users can be referred to https://github.com/TuringLang/Turing.jl/issues/2188#issuecomment-2149755180 until numerically more stable alternatives are implemented. We should also update the docs to include a guide on using try-catch blocks to handle numerical exceptions.