sethaxen opened this issue 2 years ago
I don't know if this will have any negative impact on sampling.
After further thought, yes, this is detrimental to sampling with NUTS, because it makes it very unlikely that a u-turn will be encountered, so the maximum tree depth of 10 is usually saturated:
julia> mean(==(10), chns[:tree_depth])
0.961
So 96% of transitions saturated the maximum tree depth (taking 1024 leapfrog steps each). Compare with the same model in Stan:
julia> using CmdStan, MCMCChains
julia> code = """
data {
int<lower=1> n;
real<lower=0> eta;
}
parameters {
corr_matrix[n] R;
}
model {
R ~ lkj_corr(eta);
}
""";
julia> stanmodel = Stanmodel(name="lkj", model=code, output_format=:mcmcchains);
...
julia> _, chns, _ = stan(stanmodel, Dict("n" => 3, "eta" => 2), "./");
...
Inference for Stan model: lkj_model
4 chains: each with iter=(1000,1000,1000,1000); warmup=(0,0,0,0); thin=(1,1,1,1); 4000 iterations saved.
Warmup took (0.013, 0.014, 0.014, 0.014) seconds, 0.055 seconds total
Sampling took (0.027, 0.026, 0.029, 0.026) seconds, 0.11 seconds total
Mean MCSE StdDev 5% 50% 95% N_Eff N_Eff/s R_hat
lp__ -1.7e+00 3.4e-02 1.4e+00 -4.5 -1.3e+00 -0.20 1.7e+03 1.6e+04 1.0e+00
accept_stat__ 0.91 1.8e-03 0.12 0.64 0.95 1.0 4.7e+03 4.4e+04 1.0e+00
stepsize__ 0.80 5.6e-02 0.080 0.67 0.85 0.88 2.0e+00 1.9e+01 2.8e+13
treedepth__ 2.2 9.1e-02 0.53 1.0 2.0 3.0 3.4e+01 3.1e+02 1.0e+00
n_leapfrog__ 4.5 3.5e-01 2.0 3.0 3.0 7.0 3.4e+01 3.1e+02 1.0e+00
divergent__ 0.00 nan 0.00 0.00 0.00 0.00 nan nan nan
energy__ 3.2 5.0e-02 1.9 0.87 2.8 6.8 1.5e+03 1.3e+04 1.0e+00
R[1,1] 1.0e+00 nan 6.7e-16 1.0 1.0e+00 1.0 nan nan nan
R[1,2] 4.0e-03 7.0e-03 4.1e-01 -0.67 -3.2e-03 0.68 3.5e+03 3.2e+04 1.0e+00
R[1,3] 2.9e-03 6.6e-03 4.1e-01 -0.68 2.2e-03 0.67 3.9e+03 3.6e+04 1.0e+00
R[2,1] 4.0e-03 7.0e-03 4.1e-01 -0.67 -3.2e-03 0.68 3.5e+03 3.2e+04 1.0e+00
R[2,2] 1.0e+00 nan 6.7e-16 1.0 1.0e+00 1.0 nan nan nan
R[2,3] -1.7e-02 7.1e-03 4.1e-01 -0.69 -1.5e-02 0.66 3.3e+03 3.1e+04 1.0e+00
R[3,1] 2.9e-03 6.6e-03 4.1e-01 -0.68 2.2e-03 0.67 3.9e+03 3.6e+04 1.0e+00
R[3,2] -1.7e-02 7.1e-03 4.1e-01 -0.69 -1.5e-02 0.66 3.3e+03 3.1e+04 1.0e+00
R[3,3] 1.0e+00 nan 6.7e-16 1.0 1.0e+00 1.0 nan nan nan
Samples were drawn using hmc with nuts.
For each parameter, N_Eff is a crude measure of effective sample size,
and R_hat is the potential scale reduction factor on split chains (at
convergence, R_hat=1).
julia> mean(==(10), chns[:treedepth__])
0.0
julia> mean(chns[:treedepth__])
2.203
No transitions hit the max tree depth; the average tree depth was only 2.2, corresponding to a mean of 4.5 leapfrog steps per transition.
As demonstrated in https://discourse.julialang.org/t/case-study-speeding-up-a-logistic-regression-with-rhs-prior-turing-vs-numpyro-any-tricks-im-missing/87681/34, this impropriety also introduces post-warmup numerical error.
A quick-and-dirty hack to get this working for Turing users might be to edit https://github.com/TuringLang/Bijectors.jl/blob/b20471252c01dd4832e06aa80045046483f3804e/src/bijectors/corr.jl#L83-L95 to add:
# standard-normal log-density (up to an additive constant) on the entries the
# transform never uses: the diagonal and strict lower triangle of y
for j in 1:K, i in j:K
    result -= y[i, j]^2 / 2
end
This puts a standard normal prior on the extra DOFs. The downside is that it technically lies about the logdetjac (though so does the current implementation, which pretends the transform is bijective), and the logdetjac of the inverse function will disagree. Still, it's a band-aid that will work until Bijectors can support inputs and outputs of different dimensions.
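Written as a standalone helper, the added term is just the following (the name and signature here are hypothetical, not anything in Bijectors):

function extra_dof_logprior(y::AbstractMatrix{<:Real})
    K = size(y, 1)
    result = float(zero(eltype(y)))
    # standard-normal log-density (up to an additive constant) on the
    # diagonal and strict lower triangle, which the inverse transform ignores
    for j in 1:K, i in j:K
        result -= y[i, j]^2 / 2
    end
    return result
end

Adding this quantity to the value returned by logabsdetjac(::Inverse{CorrBijector}, y) is exactly what the edit above does.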
Okay, so combining https://github.com/TuringLang/Bijectors.jl/pull/246 and https://github.com/TuringLang/DynamicPPL.jl/pull/462, this now works:
julia> using Turing, Random
[ Info: Precompiling Turing [fce5fe82-541a-59a6-adf8-730c64b5f9a0]
julia> @model function model(n, η)
R ~ LKJ(n, η)
end;
julia> mod = model(3, 2.0);
julia> Random.seed!(50);
julia> chns = sample(mod, NUTS(0.99), 1_000; save_state=true)
┌ Info: Found initial step size
└ ϵ = 1.6
Sampling 100%|██████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:09
Chains MCMC chain (1000×21×1 Array{Float64, 3}):
Iterations = 501:1:1500
Number of chains = 1
Samples per chain = 1000
Wall duration = 11.07 seconds
Compute duration = 11.07 seconds
parameters = R[1,1], R[2,1], R[3,1], R[1,2], R[2,2], R[3,2], R[1,3], R[2,3], R[3,3]
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std naive_se mcse ess rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
R[1,1] 1.0000 0.0000 0.0000 0.0000 NaN NaN NaN
R[2,1] -0.0274 0.4129 0.0131 0.0153 773.7963 0.9994 69.8877
R[3,1] 0.0037 0.4011 0.0127 0.0147 642.2541 0.9992 58.0071
R[1,2] -0.0274 0.4129 0.0131 0.0153 773.7963 0.9994 69.8877
R[2,2] 1.0000 0.0000 0.0000 0.0000 996.8495 0.9990 90.0334
R[3,2] 0.0077 0.4061 0.0128 0.0157 670.0528 1.0039 60.5178
R[1,3] 0.0037 0.4011 0.0127 0.0147 642.2541 0.9992 58.0071
R[2,3] 0.0077 0.4061 0.0128 0.0157 670.0528 1.0039 60.5178
R[3,3] 1.0000 0.0000 0.0000 0.0000 672.7942 0.9990 60.7654
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
R[1,1] 1.0000 1.0000 1.0000 1.0000 1.0000
R[2,1] -0.7647 -0.3269 -0.0286 0.2606 0.7454
R[3,1] -0.7393 -0.3006 -0.0072 0.3158 0.7480
R[1,2] -0.7647 -0.3269 -0.0286 0.2606 0.7454
R[2,2] 1.0000 1.0000 1.0000 1.0000 1.0000
R[3,2] -0.7260 -0.2919 0.0068 0.2974 0.7714
R[1,3] -0.7393 -0.3006 -0.0072 0.3158 0.7480
R[2,3] -0.7260 -0.2919 0.0068 0.2974 0.7714
R[3,3] 1.0000 1.0000 1.0000 1.0000 1.0000
Though I'm confused as to why MCMCChains reports a non-NaN ESS for some of the diagonal elements (I've checked the underlying values, and they are indeed fixed to 1).
Okay, so combining #246 and TuringLang/DynamicPPL.jl#462, this now works:
Awesome! Can you also check the estimated metric?
Though I'm confused as to why MCMCChains reports a non-NaN ESS for some of the diagonal elements (I've checked the underlying values, and they are indeed fixed to 1).
If they're even a little offset from 1, then a NaN won't be returned, e.g.
julia> using MCMCDiagnosticTools
julia> x = ones(1_000, 1, 10);
julia> MCMCDiagnosticTools.ess_rhat(x)[1]
10-element Vector{Float64}:
NaN
NaN
NaN
NaN
NaN
NaN
NaN
NaN
NaN
NaN
julia> MCMCDiagnosticTools.ess_rhat(x .+ eps() .* randn.())[1]
10-element Vector{Float64}:
955.6941303601072
713.5937319411215
1092.1539947505667
903.5876039869821
925.0133494392358
894.2614035388227
793.3188327568754
1058.2749400295972
988.8133813971027
1096.6615583111022
In #246, the final R matrix is generated from the Cholesky factor through matrix multiplication, so the diagonals may not be exactly 1.
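For illustration (this snippet is not from #246): the squared norm of a computed unit-norm column picks up floating-point rounding, so a reconstructed diagonal entry need not be exactly 1.

using LinearAlgebra
w = normalize(randn(3))   # stand-in for a computed column of the upper factor
w' * w                    # rounding can leave this an ulp or two away from 1.0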
An interesting question is what should ess_rhat do when given a matrix of identical numbers. ArviZ returns an ESS equivalent to the sample size; posterior returns NA. I tend towards returning NaN, since if chains are completely stalled, all values will be identical, but we don't want to return a high ESS. However, this creates the odd behavior that as the posterior variance goes to 0, we first approach the sample size, then suddenly return NaN.
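For concreteness, a quick illustration of that discontinuity using the same ess_rhat method as above (exact ESS values will vary with the random draws):

using MCMCDiagnosticTools

x = randn(1_000, 1, 1)                               # (draws, chains, parameters)

# tiny but nonzero variance: ESS is roughly the sample size
MCMCDiagnosticTools.ess_rhat(1.0 .+ 1e-12 .* x)[1]

# exactly zero variance: the result jumps to NaN
MCMCDiagnosticTools.ess_rhat(fill(1.0, 1_000, 1, 1))[1]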
Awesome! Can you also check the estimated metric?
Here it is:
julia> chns.info.samplerstate.hamiltonian.metric.M⁻¹
3-element Vector{Float64}:
0.25406253280310975
0.22661105092259304
0.2953807290308163
The underlying tracked variables are now the subspace btw, if that's what you were wondering.
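As a sanity check (using the same field path as above), the metric dimension matches the number of free parameters:

# 3 metric entries == 3*(3-1)/2 free parameters for a 3×3 correlation matrix
length(chns.info.samplerstate.hamiltonian.metric.M⁻¹) == 3 * (3 - 1) ÷ 2   # true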
If they're even a little offset from 1, then a NaN won't be returned, e.g.
Yeah figured it had something to do with that, but was then surprised to see one of them not being affected by this :shrug:
An interesting question is what should ess_rhat do when given a matrix of identical numbers.
Tbh it's going to be quite difficult to figure out which variables are "actually" sampled and which aren't, and hence customizing the resulting behavior in Chains is probably something that we have to defer quite far into the future :confused:
If they're even a little offset from 1, then a NaN won't be returned, e.g.
Yeah figured it had something to do with that, but was then surprised to see one of them not being affected by this :shrug:
Ah, that happens because the first element won't have any numerical error. It's computed as r1'r1, where r1 is the first column of the upper factor, whose first element is explicitly set to 1 and whose other elements are explicitly set to 0.
An interesting question is what should ess_rhat do when given a matrix of identical numbers.
Tbh it's going to be quite difficult to figure out which variables are "actually" sampled and which aren't, and hence customizing the resulting behavior in Chains is probably something that we have to defer quite far into the future :confused:
FWIW I think this is an unrelated issue to what MCMCDiagnosticTools should do when it gets a scalar array. But this is a tangent for this issue.
Ah, that happens because the first element won't have any numerical error. It's computed as r1'r1 where r1 is the first column of the upper factor, whose first element is explicitly set to 1 and whose other elements are explicitly set to 0.
Aaah true! Thanks!
An $n \times n$ correlation matrix has ${n \choose 2} = \frac{n (n-1)}{2}$ degrees of freedom, the same as the number of elements in a strictly upper triangular $n \times n$ matrix. The CorrBijector works by mapping from the correlation matrix first to its unique upper Cholesky factor and then to a strictly upper triangular matrix of unconstrained entries.

The trouble is that in the unconstrained space we now have $n \times n$ parameters, of which ${n+1 \choose 2} = \frac{n(n+1)}{2}$ have no impact on the log density. These extra parameters have an implicit improper uniform prior on the reals, which makes the posterior distribution in unconstrained space improper. Because these parameters have infinite variance, HMC will learn this during adaptation, and they will explode in value. I don't know if this will have any negative impact on sampling.
In this demo, we're sampling the uniform distribution on the correlation matrices.
Note the number of parameters. We should have 3*(3-1)/2 = 3 DOFs, but instead we have 3*3 = 9. And note that 3*(3+1)/2 = 6 of the degrees of freedom have adapted variances of ~1e20.

There are several ways to solve this, none of which seem to be possible in Bijectors right now:
- Modify logabsdetjac to contain a prior term for the extra parameters. However, when I tried this, it seemed to have no effect, since I guess logabsdetjac(b::CorrBijector, X::AbstractArray{<:AbstractMatrix{<:Real}}) is being called instead of logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}). When mapping from X to y, these extra parameters are all set to 0, so we have no way of setting this prior.
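For reference, here is a minimal, self-contained sketch of the right-sized alternative: a map from exactly n*(n-1)/2 unconstrained values to the upper Cholesky factor of a correlation matrix via tanh-transformed canonical partial correlations. This is not Bijectors' API and not the code in the PRs linked above, just an illustration that no extra DOFs are needed:

function corr_cholesky_from_vec(z::AbstractVector{<:Real}, n::Int)
    @assert length(z) == n * (n - 1) ÷ 2
    T = float(eltype(z))
    W = zeros(T, n, n)
    k = 0
    for j in 1:n
        remaining = one(T)            # squared-norm budget left in column j
        for i in 1:(j - 1)
            k += 1
            cpc = tanh(z[k])          # canonical partial correlation in (-1, 1)
            W[i, j] = cpc * sqrt(remaining)
            remaining *= 1 - cpc^2
        end
        W[j, j] = sqrt(remaining)     # each column ends up with unit norm
    end
    return W                          # R = W'W is then a correlation matrix
end

z = randn(3)                          # n = 3 ⇒ exactly 3 unconstrained values
W = corr_cholesky_from_vec(z, 3)
R = W' * W                            # unit diagonal (up to rounding), off-diagonals in (-1, 1)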