TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/

CorrBijector makes posterior improper #228

Open sethaxen opened 2 years ago

sethaxen commented 2 years ago

An $n \times n$ correlation matrix has ${n \choose 2} = \frac{n (n-1)}{2}$ degrees of freedom. This is the same as the number of elements in a strict upper triangular $n \times n$ matrix. The CorrBijector works by mapping from the correlation matrix first to its unique upper Cholesky factor and then to a strictly upper triangular matrix of unconstrained entries.

The trouble is that in the unconstrained space we now have $n \times n$ parameters, of which ${n+1 \choose 2} = \frac{n(n+1)}{2}$ have no impact on the log density. These extra parameters have an implicit improper uniform prior on the reals, which makes the posterior distribution in unconstrained space improper. Because these parameters have infinite variance, HMC will learn this during adaptation and their values will explode. I don't know if this will have any negative impact on sampling.
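For reference, the mismatch can be seen directly from the bijector itself. A minimal sketch, assuming the matrix-valued CorrBijector that bijector(::LKJ) returns at the time of writing:

    using Bijectors

    d = LKJ(3, 2.0)
    R = rand(d)          # 3×3 correlation matrix: n(n-1)/2 = 3 degrees of freedom
    b = bijector(d)      # CorrBijector
    y = b(R)             # unconstrained representation, still a 3×3 matrix

    size(y)              # (3, 3): 9 unconstrained parameters
    3 * (3 - 1) ÷ 2      # 3 of them determine R (strict upper triangle)
    3 * (3 + 1) ÷ 2      # 6 of them (diagonal and lower triangle) are ignored by the inverse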

In this demo, we sample a 3×3 correlation matrix from an LKJ prior with η = 2 (and no likelihood).

julia> using Turing, Random

julia> @model function model(n, η)
           R ~ LKJ(n, η)
       end;

julia> mod = model(3, 2.0);

julia> Random.seed!(50);

julia> chns = sample(mod, NUTS(0.99), 1_000; save_state=true)
┌ Info: Found initial step size
└   ϵ = 0.8
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)
└ @ AdvancedHMC ~/.julia/packages/AdvancedHMC/51xgc/src/hamiltonian.jl:47
Sampling 100%|█████████████████████████████████████████████████████████████████████████| Time: 0:00:11
Chains MCMC chain (1000×21×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 12.02 seconds
Compute duration  = 12.02 seconds
parameters        = R[1,1], R[2,1], R[3,1], R[1,2], R[2,2], R[3,2], R[1,3], R[2,3], R[3,3]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat   ess_per_sec 
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64       Float64 

      R[1,1]    1.0000    0.0000     0.0000    0.0000         NaN       NaN           NaN
      R[2,1]   -0.0086    0.4056     0.0128    0.0131   1010.4052    0.9993       84.0673
      R[3,1]   -0.0081    0.4084     0.0129    0.0113   1034.8774    1.0000       86.1035
      R[1,2]   -0.0086    0.4056     0.0128    0.0131   1010.4052    0.9993       84.0673
      R[2,2]    1.0000    0.0000     0.0000    0.0000   1030.1887    0.9990       85.7133
      R[3,2]   -0.0156    0.4045     0.0128    0.0139    849.6362    1.0033       70.6911
      R[1,3]   -0.0081    0.4084     0.0129    0.0113   1034.8774    1.0000       86.1035
      R[2,3]   -0.0156    0.4045     0.0128    0.0139    849.6362    1.0033       70.6911
      R[3,3]    1.0000    0.0000     0.0000    0.0000   1005.9869    0.9990       83.6997

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

      R[1,1]    1.0000    1.0000    1.0000    1.0000    1.0000
      R[2,1]   -0.7567   -0.3053   -0.0089    0.2914    0.7302
      R[3,1]   -0.7980   -0.2985   -0.0014    0.2902    0.7525
      R[1,2]   -0.7567   -0.3053   -0.0089    0.2914    0.7302
      R[2,2]    1.0000    1.0000    1.0000    1.0000    1.0000
      R[3,2]   -0.7391   -0.3212   -0.0194    0.3046    0.7256
      R[1,3]   -0.7980   -0.2985   -0.0014    0.2902    0.7525
      R[2,3]   -0.7391   -0.3212   -0.0194    0.3046    0.7256
      R[3,3]    1.0000    1.0000    1.0000    1.0000    1.0000

julia> chns.info.samplerstate.hamiltonian.metric.M⁻¹
9-element Vector{Float64}:
 1.8098190471067061e21
 7.440167311295848e19
 2.0026621564801238e21
 0.24698526419696235
 1.8270061628986394e21
 6.521608457727148e20
 0.27103163446341116
 0.43223828394884933
 3.2173213058057444e20

Note the number of parameters. We should have 3*(3-1)/2 = 3 degrees of freedom, but instead we have 3*3 = 9 parameters. And note that 3*(3+1)/2 = 6 of them have adapted variances of ~1e20.
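For illustration, we can map the adapted variances back to matrix positions. A sketch, assuming the unconstrained matrix is stored in column-major vec order (an assumption on my part; it may not match DynamicPPL's internal ordering):

    M⁻¹ = chns.info.samplerstate.hamiltonian.metric.M⁻¹
    for (k, v) in enumerate(M⁻¹)
        i, j = Tuple(CartesianIndices((3, 3))[k])
        println("y[", i, ",", j, "]  adapted variance ≈ ", v)
    end
    # Under this ordering, the ~1e20 variances fall on the diagonal and lower
    # triangle (i ≥ j), i.e. the entries that never affect the log density.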

There are several ways to solve this, none of which seems to be possible in Bijectors right now.

sethaxen commented 2 years ago

I don't know if this will have any negative impact on sampling.

After further thought, yes, this is detrimental to sampling with NUTS, because it makes it very unlikely that a u-turn will be encountered, so the maximum tree depth of 10 is usually saturated:

julia> mean(==(10), chns[:tree_depth])
0.961

So 96% of transitions saturated the maximum tree depth (roughly 2^10 leapfrog steps each). Compare with the same model in Stan:

julia> using CmdStan, MCMCChains

julia> code = """
       data {
         int<lower=1> n;
         real<lower=0> eta;
       }
       parameters {
         corr_matrix[n] R;
       }
       model {
         R ~ lkj_corr(eta);
       }
       """;

julia> stanmodel = Stanmodel(name="lkj", model=code, output_format=:mcmcchains);
...

julia> _, chns, _ = stan(stanmodel, Dict("n" => 3, "eta" => 2), "./");
...

Inference for Stan model: lkj_model
4 chains: each with iter=(1000,1000,1000,1000); warmup=(0,0,0,0); thin=(1,1,1,1); 4000 iterations saved.

Warmup took (0.013, 0.014, 0.014, 0.014) seconds, 0.055 seconds total
Sampling took (0.027, 0.026, 0.029, 0.026) seconds, 0.11 seconds total

                    Mean     MCSE   StdDev     5%       50%   95%    N_Eff  N_Eff/s    R_hat

lp__            -1.7e+00  3.4e-02  1.4e+00   -4.5  -1.3e+00 -0.20  1.7e+03  1.6e+04  1.0e+00
accept_stat__       0.91  1.8e-03     0.12   0.64      0.95   1.0  4.7e+03  4.4e+04  1.0e+00
stepsize__          0.80  5.6e-02    0.080   0.67      0.85  0.88  2.0e+00  1.9e+01  2.8e+13
treedepth__          2.2  9.1e-02     0.53    1.0       2.0   3.0  3.4e+01  3.1e+02  1.0e+00
n_leapfrog__         4.5  3.5e-01      2.0    3.0       3.0   7.0  3.4e+01  3.1e+02  1.0e+00
divergent__         0.00      nan     0.00   0.00      0.00  0.00      nan      nan      nan
energy__             3.2  5.0e-02      1.9   0.87       2.8   6.8  1.5e+03  1.3e+04  1.0e+00

R[1,1]           1.0e+00      nan  6.7e-16    1.0   1.0e+00   1.0      nan      nan      nan
R[1,2]           4.0e-03  7.0e-03  4.1e-01  -0.67  -3.2e-03  0.68  3.5e+03  3.2e+04  1.0e+00
R[1,3]           2.9e-03  6.6e-03  4.1e-01  -0.68   2.2e-03  0.67  3.9e+03  3.6e+04  1.0e+00
R[2,1]           4.0e-03  7.0e-03  4.1e-01  -0.67  -3.2e-03  0.68  3.5e+03  3.2e+04  1.0e+00
R[2,2]           1.0e+00      nan  6.7e-16    1.0   1.0e+00   1.0      nan      nan      nan
R[2,3]          -1.7e-02  7.1e-03  4.1e-01  -0.69  -1.5e-02  0.66  3.3e+03  3.1e+04  1.0e+00
R[3,1]           2.9e-03  6.6e-03  4.1e-01  -0.68   2.2e-03  0.67  3.9e+03  3.6e+04  1.0e+00
R[3,2]          -1.7e-02  7.1e-03  4.1e-01  -0.69  -1.5e-02  0.66  3.3e+03  3.1e+04  1.0e+00
R[3,3]           1.0e+00      nan  6.7e-16    1.0   1.0e+00   1.0      nan      nan      nan

Samples were drawn using hmc with nuts.
For each parameter, N_Eff is a crude measure of effective sample size,
and R_hat is the potential scale reduction factor on split chains (at 
convergence, R_hat=1).

julia> mean(==(10), chns[:treedepth__])
0.0

julia> mean(chns[:treedepth__])
2.203

No transitions hit the max tree depth, and on average each transition took only ~2.2 tree doublings (~4.5 leapfrog steps).

sethaxen commented 2 years ago

As demonstrated in https://discourse.julialang.org/t/case-study-speeding-up-a-logistic-regression-with-rhs-prior-turing-vs-numpyro-any-tricks-im-missing/87681/34, this impropriety also introduces post-warmup numerical error.

A quick-and-dirty hack to get this working for Turing users might be to edit https://github.com/TuringLang/Bijectors.jl/blob/b20471252c01dd4832e06aa80045046483f3804e/src/bijectors/corr.jl#L83-L95 to add

    # unnormalized standard-normal log-density on the unused entries (diagonal and lower triangle)
    for j in 1:K, i in j:K
        result -= y[i, j]^2 / 2
    end

This puts a standard normal prior on the extra DOFs. The downside is that it technically lies about the logdetjac (but so does the current implementation, which lies about the transform being bijective), and the logdetjac of the inverse function will disagree. Still, it's a band-aid that will work until Bijectors can support inputs and outputs of different dimensions.
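As a self-contained illustration (a hypothetical helper, not part of Bijectors), the added term is just the unnormalized standard-normal log-density summed over the unused entries:

    # Hypothetical helper: unnormalized standard-normal log-density over the
    # unused entries of the unconstrained matrix y (diagonal and lower triangle).
    function extra_dof_logpdf(y::AbstractMatrix)
        K = size(y, 1)
        s = zero(eltype(y))
        for j in 1:K, i in j:K   # i ≥ j
            s -= y[i, j]^2 / 2
        end
        return s
    end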

torfjelde commented 1 year ago

Okay, so combining https://github.com/TuringLang/Bijectors.jl/pull/246 and https://github.com/TuringLang/DynamicPPL.jl/pull/462, this now works:

julia> using Turing, Random
[ Info: Precompiling Turing [fce5fe82-541a-59a6-adf8-730c64b5f9a0]

julia> @model function model(n, η)
           R ~ LKJ(n, η)
       end;

julia> mod = model(3, 2.0);

julia> Random.seed!(50);

julia> chns = sample(mod, NUTS(0.99), 1_000; save_state=true)
┌ Info: Found initial step size
└   ϵ = 1.6
Sampling 100%|██████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:09
Chains MCMC chain (1000×21×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 11.07 seconds
Compute duration  = 11.07 seconds
parameters        = R[1,1], R[2,1], R[3,1], R[1,2], R[2,2], R[3,2], R[1,3], R[2,3], R[3,3]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat   ess_per_sec 
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64       Float64 

      R[1,1]    1.0000    0.0000     0.0000    0.0000        NaN       NaN           NaN
      R[2,1]   -0.0274    0.4129     0.0131    0.0153   773.7963    0.9994       69.8877
      R[3,1]    0.0037    0.4011     0.0127    0.0147   642.2541    0.9992       58.0071
      R[1,2]   -0.0274    0.4129     0.0131    0.0153   773.7963    0.9994       69.8877
      R[2,2]    1.0000    0.0000     0.0000    0.0000   996.8495    0.9990       90.0334
      R[3,2]    0.0077    0.4061     0.0128    0.0157   670.0528    1.0039       60.5178
      R[1,3]    0.0037    0.4011     0.0127    0.0147   642.2541    0.9992       58.0071
      R[2,3]    0.0077    0.4061     0.0128    0.0157   670.0528    1.0039       60.5178
      R[3,3]    1.0000    0.0000     0.0000    0.0000   672.7942    0.9990       60.7654

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

      R[1,1]    1.0000    1.0000    1.0000    1.0000    1.0000
      R[2,1]   -0.7647   -0.3269   -0.0286    0.2606    0.7454
      R[3,1]   -0.7393   -0.3006   -0.0072    0.3158    0.7480
      R[1,2]   -0.7647   -0.3269   -0.0286    0.2606    0.7454
      R[2,2]    1.0000    1.0000    1.0000    1.0000    1.0000
      R[3,2]   -0.7260   -0.2919    0.0068    0.2974    0.7714
      R[1,3]   -0.7393   -0.3006   -0.0072    0.3158    0.7480
      R[2,3]   -0.7260   -0.2919    0.0068    0.2974    0.7714
      R[3,3]    1.0000    1.0000    1.0000    1.0000    1.0000

Though I'm confused as to why MCMCChains reports a non-NaN ESS for some of the diagonal elements (I've checked the underlying values and they are indeed fixed to 1).

sethaxen commented 1 year ago

Okay, so combining #246 and TuringLang/DynamicPPL.jl#462, this now works:

Awesome! Can you also check the estimated metric?

Though I'm confused as to why MCMCChains reports a non-NaN ESS for some of the diagonal elements (I've checked the underlying values and they are indeed fixed to 1).

If they're even a little offset from 1, then a NaN won't be returned, e.g.

julia> using MCMCDiagnosticTools

julia> x = ones(1_000, 1, 10);

julia> MCMCDiagnosticTools.ess_rhat(x)[1]
10-element Vector{Float64}:
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN

julia> MCMCDiagnosticTools.ess_rhat(x .+ eps() .* randn.())[1]
10-element Vector{Float64}:
  955.6941303601072
  713.5937319411215
 1092.1539947505667
  903.5876039869821
  925.0133494392358
  894.2614035388227
  793.3188327568754
 1058.2749400295972
  988.8133813971027
 1096.6615583111022

In #246, the final R matrix is generated from the Cholesky factor through matrix multiplication, so the diagonals may not be exactly 1.

An interesting question is what ess_rhat should do when given an array of identical numbers. ArviZ returns an ESS equal to the sample size; posterior returns NA. I tend towards returning NaN, since if chains are completely stalled all values will be identical, and we don't want to return a high ESS. But this creates the odd behavior that as the posterior variance goes to 0 we first approach the sample size, then suddenly return NaN.
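To make the NaN convention concrete, a sketch (a hypothetical wrapper, not an MCMCDiagnosticTools API):

    using MCMCDiagnosticTools

    # Hypothetical wrapper: force ESS = NaN for parameters whose draws are all identical,
    # regardless of what ess_rhat would otherwise return.
    function ess_constant_is_nan(x::AbstractArray{<:Real,3})
        ess = MCMCDiagnosticTools.ess_rhat(x)[1]
        for k in axes(x, 3)
            xk = view(x, :, :, k)
            all(==(first(xk)), xk) && (ess[k] = NaN)
        end
        return ess
    end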

torfjelde commented 1 year ago

Awesome! Can you also check the estimated metric?

Here it is:

julia> chns.info.samplerstate.hamiltonian.metric.M⁻¹
3-element Vector{Float64}:
 0.25406253280310975
 0.22661105092259304
 0.2953807290308163

The underlying tracked variables now live in the lower-dimensional subspace btw, if that's what you were wondering.

If they're even a little offset from 1, then a NaN won't be returned, e.g.

Yeah figured it had something to do with that, but was then surprised to see one of them not being affected by this :shrug:

An interesting question is what should ess_rhat do when given a matrix of identical numbers.

Tbh it's going to be quite difficult to figure out which variables are "actually" sampled and which aren't, and hence customizing the resulting behavior in Chains is probably something that we have to defer until quite far into the future :confused:

sethaxen commented 1 year ago

If they're even a little offset from 1, then a NaN won't be returned, e.g.

Yeah figured it had something to do with that, but was then surprised to see one of them not being affected by this shrug

Ah, that happens because the first element won't have any numerical error. It's computed as r1'r1 where r1 is the first column of the upper factor, whose first element is explicitly set to 1 and whose other elements are explicitly set to 0.
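To see this concretely, a toy sketch with a hand-picked upper factor (not Bijectors code):

    using LinearAlgebra

    # Toy upper Cholesky factor of a 3×3 correlation matrix (columns have unit norm).
    U = [1.0  0.3         0.5;
         0.0  sqrt(0.91)  0.1;
         0.0  0.0         sqrt(0.74)]

    R = U'U           # reconstruct the correlation matrix by multiplication

    R[1, 1] == 1.0    # true: the first column of U is exactly [1, 0, 0]
    R[2, 2]           # involves a sum of squares, so it can miss 1.0 by a rounding error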

An interesting question is what should ess_rhat do when given a matrix of identical numbers.

Tbh it's going to be quite difficult to figure out which variables are "actually" sampled and which aren't, and hence customizing the resulting behavior in Chains is probably something that we have to defer until quite far into the future confused

FWIW I think that's a separate issue from what MCMCDiagnosticTools should do when it gets an array of identical values. But this is a tangent for this issue.

torfjelde commented 1 year ago

Ah, that happens because the first element won't have any numerical error. It's computed as r1'r1 where r1 is the first column of the upper factor, whose first element is explicitly set to 1 and whose other elements are explicitly set to 0.

Aaah true! Thanks!