TuringLang / ParetoSmooth.jl

An implementation of PSIS algorithms in Julia.
http://turinglang.org/ParetoSmooth.jl/
MIT License

Helper functions #6

Closed itsdfish closed 3 years ago

itsdfish commented 3 years ago

Hi @ParadaCarleton and @goedman

@ParadaCarleton, thanks again for putting together this package.

Given that I have very limited knowledge of PSIS and LOO, I may not be the best person for this task. Nonetheless, from what I can piece together, the package appears to be missing helper functions for computing pointwise log likelihoods and for computing the LOO value from the PSIS object. Here is my best attempt. Am I on the right track?

using Turing, MCMCChains, Distributions, ParetoSmooth, Random
using StatsFuns

Random.seed!(5574)

data = rand(Normal(0, 1), 30)

@model function model(y)
    μ ~ Normal(0, 1)
    σ ~ truncated(Cauchy(0, 1), 0.0, Inf)

    y .~ Normal(μ, σ)
end

chains = sample(model(data), NUTS(1000, 0.65), MCMCThreads(), 1000, 3)

# method for MCMCChains
function pointwise_loglikes(chain::Chains, data, ll_fun)
    samples = Array(Chains(chain, :parameters).value)
    pointwise_loglikes(samples, data, ll_fun)
end

# generic method for arrays
function pointwise_loglikes(samples::Array{Float64,3}, data, ll_fun)
    n_data = length(data)
    n_samples, n_chains = size(samples)[[1,3]]
    pointwise_lls = fill(0.0, n_data, n_samples, n_chains)
    for c in 1:n_chains 
        for s in 1:n_samples
            for d in 1:n_data
                pointwise_lls[d,s,c] = ll_fun(samples[s,:,c], data[d])
            end
        end
    end
    return pointwise_lls
end

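function compute_loo(psis_output, pointwise_lls)
    n_data, n_samples, n_chains = size(pointwise_lls)
    # psis_output.weights is treated as log weights: adding them to the
    # log likelihoods weights each draw in log space
    lwp = pointwise_lls + psis_output.weights
    # merge samples and chains, then transpose so each column is one data point
    lwpt = reshape(lwp, n_data, n_samples * n_chains)'
    # log-sum-exp over draws gives one LOO value per data point
    loos = vec(logsumexp(lwpt; dims=1))
    return sum(loos)
end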

# compute the pointwise log likelihoods where indices correspond to [data, sample, chain]
pointwise_lls = pointwise_loglikes(chains, data, (p,d)->logpdf(Normal(p...), d))

# compute the psis object
psis_output = psis(pointwise_lls)

# return loo based on Rob's example
loo = compute_loo(psis_output, pointwise_lls)

goedman commented 3 years ago

Thanks Chris, will try to switch to your setup for these tests.

The crucial line in compute_loo() is where I add the log_weights to the log_prob ( lwp += psis_output.weights ), which in log space is the same as multiplying each likelihood by its weight.
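
To make that step concrete, here is a minimal, dependency-free sketch of the same idea: add the log weights to the log likelihoods, then take a log-sum-exp over draws for each data point. It assumes the weights are log weights that are normalized across draws for each data point, which this thread does not confirm for `psis_output.weights`. The names `stable_logsumexp` and `toy_loo`, and the flat data × draws matrix layout, are hypothetical and chosen only for illustration.

```julia
# Numerically stable log(sum(exp.(x))) for a vector x.
function stable_logsumexp(x::AbstractVector{<:Real})
    m = maximum(x)
    return m + log(sum(exp.(x .- m)))
end

# Hypothetical helper: both inputs are n_data × n_draws matrices.
# Assumption: each row of log_weights satisfies sum(exp.(row)) ≈ 1.
function toy_loo(log_liks::AbstractMatrix{<:Real}, log_weights::AbstractMatrix{<:Real})
    size(log_liks) == size(log_weights) || throw(DimensionMismatch("sizes must match"))
    # Adding in log space multiplies each likelihood by its weight.
    weighted = log_liks .+ log_weights
    # One log-sum-exp per data point (row), taken over all draws.
    pointwise = [stable_logsumexp(weighted[i, :]) for i in axes(weighted, 1)]
    return sum(pointwise)
end

# Toy input with uniform weights: two data points, three draws.
log_liks = log.([0.2 0.4 0.6; 0.1 0.3 0.5])
log_weights = fill(log(1 / 3), 2, 3)
loo_value = toy_loo(log_liks, log_weights)
```

With uniform weights the result reduces to the sum over data points of the log of the mean likelihood, which makes for an easy sanity check before swapping in Pareto-smoothed weights.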

goedman commented 3 years ago

Chris, works fine. Any chance you could do the cars example in Turing? If not, I will attempt it early next week.

Your example gives me a loo of 190.

goedman commented 3 years ago

Carlos, you probably understand this much better than I do, but both PSIS and WAIC are meant for comparing models, so a constant offset in the values might be less important?
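
The point about offsets can be written out. Assuming both models are scored on the same data and the offset $c$ is identical for both (an assumption for illustration, not something established in this thread), it cancels in the difference used for comparison:

```latex
\[
\bigl(\widehat{\mathrm{elpd}}_A + c\bigr) - \bigl(\widehat{\mathrm{elpd}}_B + c\bigr)
  = \widehat{\mathrm{elpd}}_A - \widehat{\mathrm{elpd}}_B
\]
```

Here $\widehat{\mathrm{elpd}}_A$ and $\widehat{\mathrm{elpd}}_B$ are placeholder symbols for the two models' estimates. An offset that differs between the models, or a change of scale, would not cancel in this way.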

ParadaCarleton commented 3 years ago

I'm currently adding helper functions to compute LOO values, along with MCSE and other related quantities provided by the R loo package; these will go in another package, since some uses of PSIS are not related to LOO-CV (e.g. improving variational approximations).