TuringLang / Turing.jl

Bayesian inference with probabilistic programming.
https://turinglang.org
MIT License

Retrieving samples of unconstrained variables #1494

Status: Closed (sethaxen closed this 2 years ago)

sethaxen commented 3 years ago

As I understand it, given a Turing Model, we can get a bijector that transforms the potentially constrained variables in the model to some unconstrained latent space. Sampling takes place in this latent space, while samples are returned in a Chains object using the original constrained parameterization. Sometimes it's easier to diagnose sampling issues by working with the variables in the latent space. It would be nice to have a way to take a Model and a corresponding Chains of constrained samples and get back a new Chains containing the corresponding latent (unconstrained) variables.
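For a concrete picture, here is a minimal sketch in plain Julia of what such a transform does for a positive-support variable (the function names `to_unconstrained`/`to_constrained` are illustrative only, not the DynamicPPL API; Turing handles this generally via Bijectors.jl):

```julia
# Illustrative sketch: the constrained <-> unconstrained mapping for a
# variable with support (0, ∞), e.g. the scale of an InverseGamma prior.
# These names are made up for illustration; they are not DynamicPPL functions.
to_unconstrained(x) = log(x)   # (0, ∞) → ℝ
to_constrained(y) = exp(y)     # ℝ → (0, ∞)

y = to_unconstrained(2.0)      # ≈ 0.6931, an unconstrained value
x = to_constrained(y)          # round-trips back to 2.0
```

Samplers like NUTS work with `y`, while the Chains object reports `x`.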

cc @cpfiffer

devmotion commented 3 years ago

The sampling space is not a property of the Model; it is neither encoded nor known at this stage. It is defined by each sampler individually.

cpfiffer commented 3 years ago

Yeah, but this issue speaks to the fact that we don't have a top-level understanding of when and where variables are transformed when they get to the user, which I think we should have.

devmotion commented 3 years ago

Yes, I agree. Every time I implement or fix something that uses these transformations I'm confused, since the transformation happens implicitly when indexing the VarInfo.

devmotion commented 3 years ago

Therefore I am also curious when it is helpful for users to obtain the internally sampled values: you have to know the internals of your sampler to analyze them, whereas samples in the original space are interpretable by knowing only your model.

torfjelde commented 3 years ago

Also, it should definitely be possible to write a function that does what Seth wants. The following works on my end.

A couple of utility functions for creating VarInfo from a MCMCChains.Chains:

import Turing.DynamicPPL

# TODO: make generated for `TypedVarInfo`
# Apply `f!(vi, vn)` to every `VarName` in `vi` that is handled by the sampler `spl`.
function apply_varnames!(f!, vi::DynamicPPL.TypedVarInfo, spl::DynamicPPL.AbstractSampler)
    vns = DynamicPPL._getvns(vi, spl)
    for v in keys(vns)
        for vn in vns[v]
            f!(vi, vn)
        end
    end
end

function varinfo_from_chain!(vi::DynamicPPL.AbstractVarInfo, model, chain, sample_idx = 1, chain_idx = 1)
    # Update the parameters
    DynamicPPL.setval!(vi, chain, sample_idx, chain_idx)

    # Update the logjoint accordingly
    lp = first(chain[sample_idx, :lp, chain_idx])
    DynamicPPL.setlogp!(vi, lp)

    return vi
end

function varinfos_from_chain(model, chain, spl)
    # Construct a template `VarInfo` for the model, then fill a copy of it
    # from each (sample, chain) pair in `chain`.
    vi_base = DynamicPPL.VarInfo(model, DynamicPPL.initialsampler(spl))

    return map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (sample_idx, chain_idx)
        vi = deepcopy(vi_base)
        return varinfo_from_chain!(vi, model, chain, sample_idx, chain_idx)
    end
end

Assuming the chain is already in constrained space, we can then retrieve the unconstrained chain as follows:

function unconstrain(chain::MCMCChains.Chains, model, spl)
    vis = map(varinfos_from_chain(model, chain, spl)) do vi
        # Transform parameters to unconstrained
        DynamicPPL.link!(vi, spl)

        # Make it so that `tonamedtuple` won't transform the variables
        # if we want to convert into a chain again.
        apply_varnames!(vi, spl) do vi, vn
            DynamicPPL.settrans!(vi, false, vn)
        end

        return vi
    end

    # Combine the unconstrained chains into one `MCMCChains.Chains`
    chain_unconstrained = reduce(
        MCMCChains.chainscat, [
            AbstractMCMC.bundle_samples(vis[:, chain_idx], model, spl, nothing, MCMCChains.Chains)
            for chain_idx in MCMCChains.chains(chain)
        ]
    )

    # Combine with internal parameters from original chain
    return hcat(
        MCMCChains.get_sections(chain_unconstrained, :parameters),
        MCMCChains.get_sections(chain, :internals)
    )
end

Example:

using Turing  # brings `@model`, `NUTS`, `sample`, and the distributions into scope

@model function gdemo(xs)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, √s)
    for i in eachindex(xs)
        xs[i] ~ Normal(m, √s)
    end
end

m = gdemo(randn(100) .+ 1);
alg = NUTS(0.65)
c = sample(m, alg, 1000);
c
Chains MCMC chain (1000×14×1 Array{Float64,3}):

Iterations        = 1:1000
Thinning interval = 1
Chains            = 1
Samples per chain = 1000
parameters        = m, s
internals         = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat 
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64 

           m    0.9083    0.0959     0.0030    0.0039   893.4550    1.0015
           s    0.9205    0.1264     0.0040    0.0064   822.6878    0.9990

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

           m    0.7151    0.8442    0.9116    0.9717    1.1076
           s    0.7100    0.8298    0.9098    1.0026    1.2282

And

c_unconstrained = unconstrain(c, m, DynamicPPL.Sampler(alg))
Chains MCMC chain (1000×14×1 Array{Float64,3}):

Iterations        = 1:1000
Thinning interval = 1
Chains            = 1
Samples per chain = 1000
parameters        = m, s
internals         = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat 
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64 

           m    0.9083    0.0959     0.0030    0.0039   893.4550    1.0015
           s   -0.0920    0.1356     0.0043    0.0069   832.0576    0.9990

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

           m    0.7151    0.8442    0.9116    0.9717    1.1076
           s   -0.3425   -0.1865   -0.0946    0.0026    0.2055

which is exactly what we want.
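As a quick sanity check (my own illustration, not from the thread): since `InverseGamma` has support `(0, ∞)`, the unconstrained coordinate for `s` is `log(s)`, and quantiles are equivariant under monotone transforms, so the quantile rows of the two tables should agree after taking logs (the means need not, since the mean does not commute with `log`):

```julia
# Medians of `s` read off the two summary tables above.
s_median = 0.9098                 # constrained chain
s_median_unconstrained = -0.0946  # unconstrained chain

# The log of the constrained median matches the unconstrained median
# (up to rounding in the printed summaries).
@assert isapprox(log(s_median), s_median_unconstrained; atol = 1e-3)
```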

It might also be a good idea to add something to c_unconstrained.info indicating that it's transformed, e.g. using MCMCChains.setinfo(c_unconstrained, (transformed = true,)) or something similar.

yebai commented 2 years ago

Closed in favour of https://github.com/TuringLang/DynamicPPL.jl/issues/94 and https://github.com/TuringLang/DynamicPPL.jl/pull/347