TuringLang / ParetoSmooth.jl

An implementation of PSIS algorithms in Julia.
http://turinglang.org/ParetoSmooth.jl/
MIT License

LOO Comparison Function #13

ParadaCarleton closed this issue 3 years ago

ParadaCarleton commented 3 years ago

We should provide a function that compares two models, similar to loo_compare. This should be pretty easy to build.

ParadaCarleton commented 3 years ago

@goedman I believe you already had something similar implemented in StatsModelComparisons, care to make a pull request adding it?

goedman commented 3 years ago

I'll have a look.

goedman commented 3 years ago

The comments I posted in the closed PR are better placed in this issue, so I have reposted them here.

The compare() method I wrote for StatisticalRethinking produces something like this:

julia> df_psis = compare([m5_1s, m5_2s, m5_3s], :psis)
3×8 DataFrame
 Row │ models  PSIS     lppd     SE       dPSIS    dSE      pPSIS    weight  
     │ String  Float64  Float64  Float64  Float64  Float64  Float64  Float64 
─────┼───────────────────────────────────────────────────────────────────────
   1 │ m5.1s     125.4   118.52    12.39      0.0     0.0      3.67     0.66
   2 │ m5.3s     126.7   118.09    12.44      1.3     0.7      4.62     0.34
   3 │ m5.2s     138.8   133.37     9.68     13.4     8.86     2.92     0.0
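The weight column appears to follow from the dPSIS column. Because PSIS here is on the deviance scale (lower is better), Akaike-style weights w_i ∝ exp(-dPSIS_i / 2) reproduce 0.66, 0.34 and 0.0 for the three models. The sketch below assumes this is how `compare` computes them; `psis_weights` is a hypothetical helper, not part of ParetoSmooth or StatisticalRethinking:

```julia
# Hypothetical helper: Akaike-style weights from PSIS differences on the
# deviance scale (lower PSIS is better), i.e. w_i ∝ exp(-dPSIS_i / 2).
function psis_weights(dpsis::AbstractVector{<:Real})
    # Shift so the best model contributes exp(0) = 1 (avoids underflow).
    z = exp.(-0.5 .* (dpsis .- minimum(dpsis)))
    return z ./ sum(z)
end

w = psis_weights([0.0, 1.3, 13.4])   # dPSIS column of the table above
println(round.(w; digits=2))          # [0.66, 0.34, 0.0]
```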

My first question: is this (based, of course, on ParetoSmooth's loo() method) what you are looking for?

I used DataFrames and would need to convert to the AxisKeys.jl format, which is a good exercise for me as I only recently looked into AxisKeys.jl. I would also need to obtain the data from the PsisLoo object.

You also mention support for Stan.jl. We could indeed add StanSample.jl to the test Project.toml and install Stan's cmdstan in the CI script. Technically simple to do and robust.

The above three models (m5.1s, m5.3s and m5.2s) investigate the association of the median age at marriage (A) and the marriage rate (M) with the divorce rate (D) across the 50 US states. Here:

m5.1s: D ~ A
m5.3s: D ~ A + M
m5.2s: D ~ M

For example, for m5.3s the Stan Language program is:

stan5_3 = "
data {
  int N;
  vector[N] D;
  vector[N] M;
  vector[N] A;
}
parameters {
  real a;
  real bA;
  real bM;
  real<lower=0> sigma;
}
transformed parameters {
    vector[N] mu;
    mu = a + bA * A + bM * M;
}
model {
  a ~ normal( 0 , 0.2 );
  bA ~ normal( 0 , 0.5 );
  bM ~ normal( 0 , 0.5 );
  sigma ~ exponential( 1 );
  D ~ normal( mu , sigma );
}
generated quantities{
    vector[N] log_lik;
    for (i in 1:N)
        log_lik[i] = normal_lpdf(D[i] | mu[i], sigma);
}
";

I then extract the log_lik matrix as before, e.g. for the Stan Language model above (stan5_3):

using CSV, DataFrames, StanSample, ParetoSmooth
using StatsBase: zscore
using StatisticalRethinking: sr_datadir

# Either add WaffleDivorce.csv data to the test directory or simulate the association.
df = CSV.read(sr_datadir("WaffleDivorce.csv"), DataFrame);
# Standardize the variables (zscore() in place of StatisticalRethinking's scale!()).
df.Divorce_s = zscore(df.Divorce)
df.Marriage_s = zscore(df.Marriage)
df.MedianAgeMarriage_s = zscore(df.MedianAgeMarriage)
data = (N=size(df, 1), D=df.Divorce_s, A=df.MedianAgeMarriage_s, M=df.Marriage_s)

m5_3s = SampleModel("m5.3s", stan5_3)
rc5_3s = stan_sample(m5_3s; data)

if success(rc5_3s)
    st5_3s = read_samples(m5_3s; output_format=:table);
    log_lik = matrix(st5_3s, "log_lik")
    # Reshape to the [data point, sample, chain] layout that loo() expects.
    ll = reshape(Matrix(log_lik'), data.N, m5_3s.method.num_samples, m5_3s.n_chains[1]);
    m5_3s_loo = ParetoSmooth.loo(ll)
end
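The `reshape` in the last step relies on how the draws are laid out. A small self-contained check with made-up numbers, assuming (as the code above does) that the `log_lik` matrix has one row per draw, with the chains stacked one after another, and one column per observation:

```julia
# Toy sizes: 3 observations, 4 samples per chain, 2 chains (made-up numbers).
N, nsamples, nchains = 3, 4, 2
# Assumed layout: (nsamples * nchains) draws × N observations, chain 1 first.
draws = reshape(collect(1.0:(nsamples * nchains * N)), nsamples * nchains, N)
ll = reshape(Matrix(draws'), N, nsamples, nchains)

# Observation j, sample s, chain c must come from draw row (c - 1) * nsamples + s.
@assert all(ll[j, s, c] == draws[(c - 1) * nsamples + s, j]
            for j in 1:N, s in 1:nsamples, c in 1:nchains)
println(size(ll))  # (3, 4, 2)
```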

My second question is how strongly you would like to demonstrate Stan.jl versus using the equivalent Turing models with Chris' additions, possibly simulating the association. We can still document, at a high level, how to use this package with Stan.
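If the CSV file isn't shipped with the tests, the association could be simulated instead. A minimal sketch using only the standard library; the coefficients (intercept 0, slope about -0.57 for A, residual SD about 0.82) are borrowed from the m5.1 posterior summaries posted later in this thread, M is simulated as independent noise purely for simplicity, and `simulate_waffle` is a hypothetical helper:

```julia
using Random, Statistics

# Standardize a vector (same idea as zscore from StatsBase).
standardize(x) = (x .- mean(x)) ./ std(x)

# Hypothetical helper: simulate standardized A and M plus D for n "states",
# with D depending on A only (assumed coefficients, see the lead-in).
function simulate_waffle(n::Integer = 50; a = 0.0, bA = -0.57, σ = 0.82,
                         rng = Random.default_rng())
    A = standardize(randn(rng, n))
    M = standardize(randn(rng, n))              # assumption: unrelated to D
    # Var(D) ≈ 0.57^2 + 0.82^2 ≈ 1, so D is already roughly standardized.
    D = a .+ bA .* A .+ σ .* randn(rng, n)
    return (N = n, D = D, A = A, M = M)         # same fields as `data` above
end

data = simulate_waffle(50; rng = MersenneTwister(1))
```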

ParadaCarleton commented 3 years ago

@goedman This would be perfect, thanks! The documentation here should mostly be about this package, rather than Stan or Turing; I suspect more people will be using this package with Turing, though, so I think a tutorial should use Turing models. Since it looks like you use lppd, I'll add a column with that today. Feel free to modify how the psis_loo function works to return more useful values if you want.

goedman commented 3 years ago

Hi Chris (@itsdfish),

Trying to test the initial version of a loo_compare function.

With Stan I get:

   8 │ a               -0.000148878    0.00167179  0.100784    -0.167386    0.00125757    0.160293  3634.3
   9 │ bA              -0.567067       0.00191658  0.114245    -0.756436   -0.567159     -0.371178  3553.2   ⋯
  10 │ sigma            0.820792       0.00142859  0.0850291    0.69417     0.813606      0.969412  3542.59

3-element Vector{PsisLoo}:
 ┌ Warning: Some Pareto k values are very high (>0.7), indicating that PSIS has failed to approximate the true distribution.
└ @ ParetoSmooth ~/.julia/packages/ParetoSmooth/uBHCv/src/LooStructs.jl:78
┌─────────────┬────────┬──────────┬───────┬─────────┐
│             │  total │ se_total │  mean │ se_mean │
├─────────────┼────────┼──────────┼───────┼─────────┤
│   loo_score │ -63.17 │     6.64 │ -1.26 │    0.13 │
│ naive_score │ -59.21 │     4.89 │ -1.18 │    0.10 │
│     overfit │   3.96 │     2.04 │  0.08 │    0.04 │
└─────────────┴────────┴──────────┴───────┴─────────┘

┌───────┬────────┬───────┬───────┬───────┬─────────┐
│       │   PSIS │    SE │ ΔPSIS │   ΔSE │ weights │
├───────┼────────┼───────┼───────┼───────┼─────────┤
│ m5.1s │ -63.17 │  6.64 │  0.00 │  0.00 │    0.69 │
│ m5.3s │ -63.98 │  6.58 │ -0.81 │ -0.07 │    0.31 │
│ m5.2s │ -69.70 │  4.95 │ -6.53 │ -1.69 │    0.00 │
└───────┴────────┴───────┴───────┴───────┴─────────┘

I don't trust the ΔSE values yet.

Not sure what I'm doing wrong with Turing though; would you mind taking a quick look?

using Turing, MCMCChains
using CSV, DataFrames
using StatsBase: zscore
using ParetoSmooth
using StatisticalRethinking: sr_datadir

df = CSV.read(sr_datadir("WaffleDivorce.csv"), DataFrame)
df.D = zscore(df.Divorce)
df.M = zscore(df.Marriage)
df.A = zscore(df.MedianAgeMarriage)
data = (D=df.D, A=df.A)

function compute_loglike(μ, data)
    return logpdf(Normal(μ, 1), data)
end

@model function m5_1_A(A, D)
    a ~ Normal(0, 0.2)
    bA ~ Normal(0, 0.5)
    σ ~ Exponential(1)
    μ = a .+ bA .* A    # linear predictor
    D ~ MvNormal(μ, σ)
end

chn5_1_A = sample(m5_1_A(df.A, df.D), NUTS(1000, .9), MCMCThreads(), 1000, 4)
chn5_1_A |> display
pw_lls = pointwise_log_likelihoods(m5_1_A(df.A, df.D), chn5_1_A)
size(pw_lls) |> display
psis_loo_output = psis_loo(m5_1_A(df.A, df.D), chn5_1_A)
psis_output = psis(compute_loglike, chn5_1_A, data)

returns:

Chains MCMC chain (1000×15×4 Array{Float64, 3}):

Iterations        = 1001:1:2000
Number of chains  = 4
Samples per chain = 1000
Wall duration     = 2.45 seconds
Compute duration  = 8.46 seconds
parameters        = a, σ, bA
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat   ess_per_sec 
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64       Float64 

           a    0.0005    0.0998     0.0016    0.0014   3482.4509    0.9998      411.5886
          bA   -0.5642    0.1109     0.0018    0.0019   3568.9684    0.9999      421.8140
           σ    0.8210    0.0848     0.0013    0.0011   3478.2494    0.9995      411.0920

┌ Info: **Important Note:** The posterior log-likelihood must be computed with a `for` loop inside a
│ Turing model; broadcasting will result in all observations being treated as if they are a
└ single point. 

(1, 1000, 4)

and:

julia> psis_loo_output
┌ Warning: Some Pareto k values are very high (>0.7), indicating that PSIS has failed to approximate the true distribution.
└ @ ParetoSmooth ~/.julia/packages/ParetoSmooth/uBHCv/src/LooStructs.jl:78
┌─────────────┬────────┬──────────┬────────┬─────────┐
│             │  total │ se_total │   mean │ se_mean │
├─────────────┼────────┼──────────┼────────┼─────────┤
│   loo_score │ -62.24 │      NaN │ -62.24 │     NaN │
│ naive_score │ -60.39 │      NaN │ -60.39 │     NaN │
│     overfit │   1.85 │      NaN │   1.85 │     NaN │
└─────────────┴────────┴──────────┴────────┴─────────┘

The final check in the Turing test script also fails.

itsdfish commented 3 years ago

@goedman,

I think the first problem can be traced back to a warning that you overlooked:

┌ Info: **Important Note:** The posterior log-likelihood must be computed with a `for` loop inside a
│ Turing model; broadcasting will result in all observations being treated as if they are a
└ single point. 

Unfortunately, Turing does not save pointwise log likelihoods when broadcasting is used in the model likelihood. Instead, they are summed across observations. You can see this problem in the size of the first dimension of pw_lls, which is 1 instead of 50:

julia> pw_lls
1×1000×4 Array{Float64, 3}:

The solution unfortunately is to use a for loop in the Turing model or to redefine the likelihood (e.g. compute_loglike in your example). See this issue for details.

psis_output = psis(compute_loglike, chn5_1_A, data) fails because compute_loglike does not match the definition in your Turing model (i.e. it does not take the right parameters). You can fix that by redefining compute_loglike so that it returns the log likelihood of a single data point given the appropriate parameters.
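To make the shape requirement concrete: PSIS-LOO needs one log-likelihood value per observation and per posterior draw, i.e. an N × draws array (50 rows here), not the single summed value per draw that the broadcast `MvNormal` statement records. A minimal sketch for the m5_1_A likelihood using only the standard library; `normal_logpdf` and `pointwise_loglik` are hypothetical helpers, and the draws of a, bA and σ are assumed to have already been extracted from the chain into plain vectors:

```julia
# Log density of Normal(μ, σ) at x, written out to avoid extra dependencies.
normal_logpdf(x, μ, σ) = -log(σ) - 0.5 * log(2π) - 0.5 * ((x - μ) / σ)^2

# Pointwise log-likelihood for D[i] ~ Normal(a + bA * A[i], σ):
# rows are observations, columns are posterior draws.
function pointwise_loglik(a::AbstractVector, bA::AbstractVector, σ::AbstractVector,
                          A::AbstractVector, D::AbstractVector)
    ll = Matrix{Float64}(undef, length(D), length(a))
    for s in eachindex(a), i in eachindex(D)
        ll[i, s] = normal_logpdf(D[i], a[s] + bA[s] * A[i], σ[s])
    end
    return ll
end

A = [-1.0, 0.0, 1.0]
D = [0.5, 0.0, -0.5]
ll = pointwise_loglik([0.0, 0.1], [-0.5, -0.6], [0.8, 0.9], A, D)
size(ll)          # (3, 2): 3 observations × 2 draws
sum(ll; dims=1)   # one value per draw, i.e. the summed value a broadcast likelihood keeps
```

With 4 chains of 1000 draws each, the N × 4000 matrix would then be reshaped to N × 1000 × 4 before being passed on.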

ParadaCarleton commented 3 years ago

If Turing is updated to the newest version, the error described in the warning should be fixed. The thread linked above is about a separate issue (.~ broadcasts over columns rather than rows, even though in tables each row is usually IID).

goedman commented 3 years ago

Thanks guys,

I'm now getting:

┌───────┬────────┬───────┬───────┬───────┬────────┐
│       │   PSIS │    SE │ ΔPSIS │   ΔSE │ weight │
├───────┼────────┼───────┼───────┼───────┼────────┤
│ m5_1t │ -63.00 │  6.55 │  0.00 │  0.00 │   0.67 │
│ m5_3t │ -63.69 │  6.44 │ -0.69 │ -0.10 │   0.33 │
│ m5_2t │ -69.68 │  4.96 │ -6.68 │ -1.59 │   0.00 │
└───────┴────────┴───────┴───────┴───────┴────────┘

vs. above Stan results:


┌───────┬────────┬───────┬───────┬───────┬─────────┐
│       │   PSIS │    SE │ ΔPSIS │   ΔSE │ weights │
├───────┼────────┼───────┼───────┼───────┼─────────┤
│ m5.1s │ -63.17 │  6.64 │  0.00 │  0.00 │    0.69 │
│ m5.3s │ -63.98 │  6.58 │ -0.81 │ -0.07 │    0.31 │
│ m5.2s │ -69.70 │  4.95 │ -6.53 │ -1.69 │    0.00 │
└───────┴────────┴───────┴───────┴───────┴─────────┘

Edit: This looks pretty reasonable. I'm a bit surprised by occasionally wild jumps in the Turing estimates if I don't use Random.seed!() so I'm experimenting a bit more with that. I'm also looking into the ΔSE which in Statistical Rethinking is computed quite differently.

ParadaCarleton commented 3 years ago

Yep, the ΔSE should be computed differently: at the moment, it looks like you're computing the difference of the standard errors, rather than the standard error of the differences.

It's worth noting that the PR I just merged should include additional information, like the in-sample score (lpd), which should make it easier to provide all of the information in the Statistical Rethinking comparison method. (It might break some of your code a bit, unfortunately; sorry about that 😅)

ParadaCarleton commented 3 years ago

Standard errors should be fairly easy to compute -- it's just the regular formula for the standard error of a mean. (Take the pointwise difference in LOO scores, use std to get the standard deviation, and then divide by sqrt(n) for the average score. Getting from the average to the total score is just a scaling transformation where you multiply by n, so the standard error also gets scaled by n.)
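The recipe above translates almost line for line. A sketch with a hypothetical `se_diff_total` helper that takes two vectors of pointwise LOO scores for the same data points; because those scores are usually correlated across models, the SE of the differences can be much smaller than either model's own SE:

```julia
using Statistics

# Hypothetical helper: standard error of the total difference in LOO scores
# between models a and b, computed from their pointwise scores.
function se_diff_total(pointwise_a::AbstractVector{<:Real},
                       pointwise_b::AbstractVector{<:Real})
    n = length(pointwise_a)
    d = pointwise_a .- pointwise_b   # pointwise differences
    se_mean = std(d) / sqrt(n)       # SE of the average difference
    return n * se_mean               # total = n × average, so the SE scales by n
end

a = [-1.0, -2.0, -1.5, -0.5]
b = a .+ [0.1, -0.1, 0.1, -0.1]
se_diff_total(a, b)   # ≈ 0.231, i.e. sqrt(4) * std(a .- b)
```

Note that n · std(d) / √n equals √n · std(d); computing the variance with `corrected=false` (dividing by n rather than n − 1) changes the result only slightly for 50 data points.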

goedman commented 3 years ago

In StatsModelComparisons.jl I used:

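    using Statistics  # for var

    # loos[i]: pointwise LOO scores of model i, best model first
    dse = zeros(nmodels)
    for i in 2:nmodels
        diff = loos[1] - loos[i]  # pointwise differences vs. the top-ranked model
        dse[i] = √(length(loos[i]) * var(diff; corrected=false))  # √n × sd of the differences
    end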

I think that matches your description and gives:

┌───────┬────────┬───────┬───────┬───────┬────────┐
│       │   PSIS │    SE │ ΔPSIS │   ΔSE │ weight │
├───────┼────────┼───────┼───────┼───────┼────────┤
│ m5_1t │ -63.00 │  6.55 │  0.00 │  0.00 │   0.67 │
│ m5_3t │ -63.69 │  6.44 │ -0.69 │  0.42 │   0.33 │
│ m5_2t │ -69.68 │  4.96 │ -6.68 │  4.74 │   0.00 │
└───────┴────────┴───────┴───────┴───────┴────────┘

Do you have a preference for the column order? I was thinking: PSIS ΔPSIS SE ΔSE lpd weight

ParadaCarleton commented 3 years ago

@goedman I would follow the headers used in the loo object itself to keep some consistency, with the exception that the weights for each model should be added as the final column. I'd avoid the Δ, which can be annoying to type out for some people who aren't using VSCode.

For the comparison object, I think it's better to only have the comparisons, rather than the LOO-CV scores themselves, since otherwise it encourages people to try and interpret the raw scores (which aren't meaningful).
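A comparison-only table of this kind needs little more than each model's total LOO estimate: the difference from the best model plus a weight. A minimal sketch on the elpd scale (higher is better), with weights ∝ exp(Δelpd); `compare_totals` is a hypothetical helper, the weighting scheme is an assumption (it does reproduce the weight columns posted earlier in this thread), and a d_SE column would additionally need the pointwise scores:

```julia
# Hypothetical helper: differences from the best model and weights ∝ exp(Δelpd),
# given each model's name and total elpd (the loo_est "total").
function compare_totals(names::AbstractVector{Symbol}, totals::AbstractVector{<:Real})
    order = sortperm(totals; rev=true)      # best (highest elpd) first
    Δ = totals[order] .- totals[order[1]]   # ≤ 0; the best model gets 0
    z = exp.(Δ)
    return (model = names[order], d_elpd = Δ, weight = z ./ sum(z))
end

res = compare_totals([:m5_1t, :m5_2t, :m5_3t], [-63.00, -69.68, -63.69])
# res.model == [:m5_1t, :m5_3t, :m5_2t]; res.d_elpd ≈ [0.0, -0.69, -6.68]
# round.(res.weight; digits=2) == [0.67, 0.33, 0.0]
```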

goedman commented 3 years ago

You mean with LooCompare defined as:

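using ParetoSmooth: PsisLoo
using AxisKeys: KeyedArray

struct LooCompare
    psis::Vector{PsisLoo}   # one PSIS-LOO result per model
    table::KeyedArray       # comparison table (differences and weights)
end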

loo = loo_compare([pw_lls5_1t, pw_lls5_2t, pw_lls5_3t]; model_names=[:m5_1t, :m5_2t, :m5_3t])

something like this:

julia> loo.psis
3-element Vector{PsisLoo}:
 ┌ Warning: Some Pareto k values are very high (>0.7), indicating that PSIS has failed to approximate the true distribution.
└ @ ParetoSmooth ~/.julia/dev/ParetoSmooth/src/LooStructs.jl:96
Results of PSIS-LOO-CV with 4000 Monte Carlo samples and 50 data points.
┌───────────┬────────┬──────────┬───────┬─────────┐
│           │  total │ se_total │  mean │ se_mean │
├───────────┼────────┼──────────┼───────┼─────────┤
│   loo_est │ -63.00 │     6.55 │ -1.26 │    0.13 │
│ naive_est │ -59.22 │     4.89 │ -1.18 │    0.10 │
│   overfit │   3.78 │     1.93 │  0.08 │    0.04 │
└───────────┴────────┴──────────┴───────┴─────────┘

 Results of PSIS-LOO-CV with 4000 Monte Carlo samples and 50 data points.
┌───────────┬────────┬──────────┬───────┬─────────┐
│           │  total │ se_total │  mean │ se_mean │
├───────────┼────────┼──────────┼───────┼─────────┤
│   loo_est │ -69.68 │     4.96 │ -1.39 │    0.10 │
│ naive_est │ -66.70 │     4.16 │ -1.33 │    0.08 │
│   overfit │   2.98 │     0.92 │  0.06 │    0.02 │
└───────────┴────────┴──────────┴───────┴─────────┘

 [ Info: Some Pareto k values are slightly high (>0.5); some pointwise estimates may be slow to converge or have high variance.
Results of PSIS-LOO-CV with 4000 Monte Carlo samples and 50 data points.
┌───────────┬────────┬──────────┬───────┬─────────┐
│           │  total │ se_total │  mean │ se_mean │
├───────────┼────────┼──────────┼───────┼─────────┤
│   loo_est │ -63.69 │     6.44 │ -1.27 │    0.13 │
│ naive_est │ -59.03 │     4.68 │ -1.18 │    0.09 │
│   overfit │   4.66 │     1.90 │  0.09 │    0.04 │
└───────────┴────────┴──────────┴───────┴─────────┘

julia> loo
┌───────┬────────┬───────┬────────┐
│       │ d_PSIS │  d_SE │ weight │
├───────┼────────┼───────┼────────┤
│ m5_1t │   0.00 │  0.00 │   0.67 │
│ m5_3t │  -0.69 │  0.42 │   0.33 │
│ m5_2t │  -6.68 │  4.74 │   0.00 │
└───────┴────────┴───────┴────────┘

goedman commented 3 years ago

Just a few thoughts/observations:

  1. Should the last Pareto k info message "slightly high (>0.5)" also be a warning?

  2. In setting up the test environment I always get:

    Julia> include("/Users/rob/.julia/dev/ParetoSmooth/test/loo_compare_test.jl");
    ┌ Warning: Error requiring `MCMCChains` from `ParetoSmooth`
    │   exception =
    │    cannot assign a value to variable MCMCChains.MCMCChains from module ParetoSmooth

    There are definitely advantages to this setup with a Project.toml in the test directory (after adding ParetoSmooth to that Project.toml you can activate the test environment, run scripts and inspect results), but I don't think it is fully production-ready yet. The scripts do run after the above message.

  3. Currently we can't run the ParetoSmooth tests on Julia 1.7 & 1.8; when using Stan, I get an error in AxisKeys.jl. I'll try to create an MWE for that. It happens while printing out the PsisLoo object.

  4. I think the Manifest.toml in ParetoSmooth uses the old format:

    (ParetoSmooth) pkg> up
    Updating registry at `~/.julia/registries/General`
    Updating git-repo `https://github.com/JuliaRegistries/General.git`
    ┌ Warning: The active manifest file at `/Users/rob/.julia/dev/ParetoSmooth/Manifest.toml` has an old format that is being maintained.
    │ To update to the new format run `Pkg.upgrade_manifest()` which will upgrade the format without re-resolving.

ParadaCarleton commented 3 years ago

Implemented!