TuringLang / ParetoSmooth.jl

An implementation of PSIS algorithms in Julia.
http://turinglang.org/ParetoSmooth.jl/
MIT License

Make ParetoSmooth "Just Work" With Turing and Stan #9

Closed ParadaCarleton closed 3 years ago

ParadaCarleton commented 3 years ago

This would require methods accepting MCMCChains objects and Stanfit objects; they should probably take the object, extract an array of log-likelihood values, and then call the base method for arrays on it. @itsdfish I believe you were working on this; would you be willing to handle it? See also this.
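A minimal sketch of the "extract the array, then call the base method" pattern described above. The names `summarize_loglik` and `SamplerOutput` are hypothetical stand-ins, not the ParetoSmooth, MCMCChains, or Stanfit APIs; the array method here only reports dimensions where a real one would run PSIS.

```julia
# Stand-in for the package's array method; a real version would run PSIS.
# Expects log-likelihoods laid out as (data points, draws, chains).
function summarize_loglik(log_lik::AbstractArray{<:Real,3})
    n_points, n_draws, n_chains = size(log_lik)
    return (points = n_points, draws = n_draws, chains = n_chains)
end

# Hypothetical sampler output type; NOT the MCMCChains or Stanfit API.
struct SamplerOutput
    log_lik::Array{Float64,3}
end

# Convenience method: extract the array, then defer to the array method.
summarize_loglik(out::SamplerOutput) = summarize_loglik(out.log_lik)

out = SamplerOutput(randn(5, 100, 4))
result = summarize_loglik(out)
```

With this layout, supporting a new sampler only needs one small method that knows how to pull the array out of that sampler's object.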

itsdfish commented 3 years ago

Yeah. I can do that. Last week I added a temporary function in Rob's repo. Would you like me to make a pull request for this repo?

goedman commented 3 years ago

For testing purposes, doing a PR is fine, but I would strongly (very strongly?) suggest not making Stan, DynamicHMC or Turing a dependency of this package. I think it would be sufficient to pass in an array of [params, draws, chains] or a possible permutation of that.

itsdfish commented 3 years ago

Hi Rob,

I think your suggestion makes sense. In fact, I do not think there is a reason to use DynamicHMC or Turing in the tests either. As long as ParetoSmooth computes the quantities correctly, we can use random data to make sure the methods work. For example, in your test here, the chain can be filled with random samples instead of being returned by Turing. That would eliminate the need to use any sampling program. It would only require MCMCChains.
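As a rough sketch of the random-data idea, the test input could be generated with the standard library alone. The dimensions and the seed below are arbitrary assumptions, and the values are not meant to resemble any real model's log-likelihoods.

```julia
using Random

# Seeded generator so the test data is reproducible.
rng = MersenneTwister(1234)

# Assumed layout: (data points, draws per chain, chains).
n_points, n_draws, n_chains = 50, 1000, 4

# Keep the values non-positive so they resemble log-probabilities.
log_lik = -abs.(randn(rng, n_points, n_draws, n_chains))

# The array can then be handed to the method under test, e.g. psis(log_lik).
```

This only checks that the methods accept the input and run; checking the numerical results would still need reference values from a trusted implementation.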

I'm not familiar with Stanfit objects, but I think a similar approach would apply.

ParadaCarleton commented 3 years ago

Requires.jl should let us add methods dealing with inputs from Turing and Stan without making either one a dependency; I think we should have methods dealing with inputs from those two without the need for users to do anything except call the method on their MCMCChains/Stanfit object. Would you be ok with making a PR for this @itsdfish?

goedman commented 3 years ago

Requires.jl is definitely one way to go. An alternative is maybe including a good Turing example in the TestSet? Currently, in the cars_turing.jl example, pointwise_loglikes() returns a simple array: points, samples, chains.

As far as StanSample is concerned, if log_lik is computed in the generated quantities block of the Stan language program, there are a few ways to obtain a log_lik matrix. Right now I've mostly used nt = read_samples(model; output_format=:namedtuple). As this output_format is the default, after:

    nt = read_samples(model)

nt.log_lik will hold the log_lik matrix. In this example that is 50 points x 4000 samples, which needs a reshape to separate the chains:

    ll = reshape(nt_cars.log_lik, 50, 1000, 4);
    cars_psis = psis(ll);
    cars_loo = ParetoSmooth.loo(ll)
    cars_loo.estimates |> display
    cars_loo.pointwise |> display

    if isdefined(Main, :StatisticalRethinking)
        pk_plot(cars_loo.psis_object.pareto_k)
        savefig(joinpath(ProjDir, "pareto_k_plot.png"))
        pk_plot(pk)
        savefig(joinpath(ProjDir, "pk_plot.png"))
        closeall()
    end
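The reshape in this example only separates the chains correctly if the 4000 draws are stacked chain after chain (all draws of chain 1, then all draws of chain 2, and so on). That ordering is an assumption here, not something checked against StanSample. A small sketch with toy dimensions shows what the reshape does under that assumption:

```julia
n_points, n_draws, n_chains = 3, 4, 2

# Label every stacked draw (column) with the chain it is assumed to come from:
# columns 1:4 belong to chain 1, columns 5:8 to chain 2.
chain_of_column = repeat(1:n_chains, inner = n_draws)

# points x stacked-draws matrix whose entries are the chain labels (3 x 8).
log_lik = repeat(Float64.(chain_of_column)', n_points, 1)

# Julia arrays are column-major, so consecutive blocks of n_draws columns
# become consecutive slices along the third dimension.
ll = reshape(log_lik, n_points, n_draws, n_chains)
```

If the draws were interleaved across chains instead, the same reshape would silently mix chains, so it is worth confirming the ordering once for a given output format.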

Since StanSample v3.1.0 there is also the output_format=:table:

    # Create a StanTable object; in this example: 504 points, 1000 samples/chain, 4 chains
    st = read_samples(m11_4s; output_format=:table);
    log_lik = matrix(st, "log_lik")
    # log_lik is 4000 x 504; transpose and split the chains:
    n_sam, n_obs = size(log_lik)
    ll = reshape(Matrix(log_lik'), 504, 1000, 4);
    chimpanzees_psis = psis(ll);
    chimpanzees_loo = ParetoSmooth.loo(ll)
    chimpanzees_loo.estimates |> display
    chimpanzees_loo.pointwise |> display
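To make the transpose-then-reshape step concrete, here is a toy version with 3 observations, 4 draws per chain and 2 chains. It assumes, as above, that rows are stacked draws ordered chain after chain; the encoded values are only there to make the index mapping visible.

```julia
n_obs, n_draws, n_chains = 3, 4, 2

# Draws-by-observations matrix (8 x 3).
# Entry (s, i) encodes stacked draw s and observation i as 10s + i.
log_lik = [10.0 * s + i for s in 1:(n_draws * n_chains), i in 1:n_obs]

# Transpose to observations x draws, materialize, then split off the chains.
ll = reshape(Matrix(log_lik'), n_obs, n_draws, n_chains)

# Equivalent formulation using permutedims instead of the adjoint.
ll_alt = reshape(permutedims(log_lik), n_obs, n_draws, n_chains)

# Observation 2, draw 3 of chain 2 is stacked draw 4 + 3 = 7.
value = ll[2, 3, 2]
```

Reading the sizes from `size(log_lik)` rather than hard-coding 504 would make the same lines reusable across models.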

For StanSample v4 I plan to also support a KeyedArray for the parameters/points, draws/samples, chains.

ParadaCarleton commented 3 years ago

Sounds good; I'd still definitely prefer to have a method to handle MCMCChains without any human intervention, though. (It's worth noting that, even without Requires.jl, just adding MCMCChains would not necessitate installing all of Turing. In fact, it's already a dependency, since I'm relying on it for ESS calculations. That being said, ESS calculations are being split off into another package, so using Requires would let us replace it with a smaller package.)

goedman commented 3 years ago

Correct, MCMCChains is fine as a dependency.

This is how I can use it once I create an MCMCChains object in StanSample.jl:

    chn = read_samples(cars_stan_model; output_format=:mcmcchains);
    log_lik2 = Matrix(Array(chn)[:, 54:end]');
    ll2 = reshape(log_lik2, 50, 1000, 4);
    psis_ll2 = psis(ll2);

    cars_loo2 = ParetoSmooth.loo(ll2)
    println()
    cars_loo2.estimates |> display
    println()
    cars_loo2.pointwise |> display
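The hard-coded `54:end` column range depends on how many parameters come before log_lik in this particular model. One possible alternative is to select the columns by name. The sketch below uses a plain matrix and a hypothetical vector of parameter names as stand-ins; it does not call the MCMCChains API, and the naming scheme `log_lik.1`, `log_lik.2`, ... is an assumption.

```julia
# Hypothetical parameter names; the real ones would come from the chain object.
param_names = ["a", "b", "sigma", "log_lik.1", "log_lik.2", "log_lik.3"]

# Stand-in for Array(chn): stacked draws in rows, parameters in columns (8 x 6).
draws = reshape(collect(1.0:48.0), 8, 6)

# Pick the log_lik columns by name instead of by position.
ll_cols = findall(name -> startswith(name, "log_lik"), param_names)

# Transpose to points x draws and split the chains, as in the example above.
n_points, n_draws, n_chains = length(ll_cols), 4, 2
ll2 = reshape(Matrix(draws[:, ll_cols]'), n_points, n_draws, n_chains)
```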

Both the Cars and the Chimpanzees example now give great results with ParetoSmooth v0.2.0.

Chris, correct me if I am wrong, but I think in Turing the log_liks are not part of the MCMCChains object.

itsdfish commented 3 years ago

@goedman, you are correct. As you can see in my pull request, the pointwise log likelihoods are computed from the Turing model and the Chain object. I also included a method to compute loo from a function or to pass the computed pointwise log likelihoods directly. My hope is that this set of methods covers most use cases.

goedman commented 3 years ago

Thanks Chris, yes, I will take a closer look at your PR.

For StanSample.jl I still prefer the :table approach:

    st11_4s = read_samples(m11_4s; output_format=:table);
    log_lik = matrix(st11_4s, "log_lik")

So, given that most of the StatsModelComparisons.jl functionality is now superseded by ParetoSmooth.jl, I plan to move the WAIC and pk_plot functionality to StatisticalRethinking.jl and deprecate StatsModelComparisons.jl. Of course, it will remain available for at least 1 or 2 more years or until Julia v2 is released. Would you be ok with that?

itsdfish commented 3 years ago

Rob,

That sounds like a reasonable plan.

Assuming log_like is a three-dimensional array of pointwise log likelihoods, you can call this method. Does that work for your use case?