TuringLang / Turing.jl

Bayesian inference with probabilistic programming.
https://turinglang.org
MIT License

Integrating Turing and MarginalLogDensities #2398

Open ElOceanografo opened 5 days ago

ElOceanografo commented 5 days ago

Moving a discussion with @yebai from Slack to here. @PavanChaggar asked if there was a way to do Laplace approximation in Turing, and I gave this little example of how it could be accomplished with MarginalLogDensities.jl:

using MarginalLogDensities
using ReverseDiff, FiniteDiff
using Turing

# simple density function
f(x, p) = -sum(abs2, diff(x))
# define a MarginalLogDensity that integrates out parameters 5 through 100
mld = MarginalLogDensity(f, rand(100), 5:100, (), LaplaceApprox(adtype=AutoReverseDiff()))

@model function MLDTest(mld)
    θ ~ MvNormal(zeros(njoint(mld)), 1.0)
    Turing.@addlogprob! mld(θ)
end

mod = MLDTest(mld)
chn = sample(mod, NUTS(adtype = AutoFiniteDiff()), 100)

This issue is to discuss how this capability might be integrated better into Turing, probably via a package extension. (See also https://github.com/TuringLang/Turing.jl/issues/1382, which I opened before I made MLD). From a user perspective, an interface like this makes sense to me:

@model function Example(y)
    a ~ SomeDist()
    b ~ AnotherDist()
    mu = somefunction(a, b)
    x ~ MvNormal(mu, 1.0)
    y ~ MvNormal(x, 1.0)
end

fullmodel = Example(ydata)
marginalmodel = marginalize(fullmodel, (:x,))
sample(marginalmodel, NUTS(), 1000)
maximum_a_posteriori(marginalmodel, LBFGS())

I think there are two basic ways to implement this:

The other current roadblock is making calls to MarginalLogDensity objects differentiable (https://github.com/ElOceanografo/MarginalLogDensities.jl/issues/34). This is doable, I just need to do it.

torfjelde commented 4 days ago

This looks neat 👀

Can you elaborate a bit on what mathematical model the following line should correspond to?

marginalmodel = marginalize(fullmodel, (:x,))

Would this marginalize out x? And exactly which terms in the posterior are replaced by the Laplace approx? My Laplace is a bit rusty 😬

(And MarginalLogDensities.jl looks really nice btw. If people are interested in model comparison, this will probably be a place we'll point people:))

torfjelde commented 4 days ago

IIUC you effectively want marginalize(...) to result in $$p(y, a, b) = \int p(y, a, b, x) dx = \bigg( \int p(y \mid x) \times p(x \mid a, b) dx \bigg) \times p(a) \times p(b)$$ with the Laplace approx. replacing the integrand on the RHS?

But does this mean that you want to drop both the terms $p(y \mid x)$ and $p(x \mid a, b)$ in the original joint model, and replace those with a simple addlogprob! call? I.e.

@model function MarginalizedExample()
    a ~ SomeDist()
    b ~ AnotherDist()
    Turing.@addlogprob! MarginalLogDensity(...)
end

?

Because that won't be possible with Turing.jl, unfortunately 😕 At least not in a completely automated fashion.

ElOceanografo commented 2 days ago

Yes, your math is correct. If f(u) is a function that returns the log-posterior, and u = [a; b; x], then calling mld = MarginalLogDensity(f, u, i_x, y) makes a callable struct. When you call mld(v) (where v corresponds to [a; b] == u[setdiff(eachindex(u), i_x)]), it will find the mode of $p(x | a, b, y)$ and the curvature at that mode, and use that Hessian to calculate the LA and integrate out $x$.
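For reference, here is a sketch of the standard Laplace approximation being described, writing $g(x) = \log p(y, a, b, x)$ for fixed $a, b$ and letting $n$ be the length of $x$. It assumes $\hat{x}$ is an interior maximum at which the negative Hessian $H$ is positive definite. Since $p(x \mid a, b, y) \propto p(y, a, b, x)$, $\hat{x}$ is exactly the mode of $p(x \mid a, b, y)$, and $H$ is the curvature there:

$$\log p(y, a, b) = \log \int \exp\big(g(x)\big)\, dx \approx g(\hat{x}) + \frac{n}{2} \log(2\pi) - \frac{1}{2} \log \det H, \qquad \hat{x} = \arg\max_x g(x), \qquad H = -\nabla_x^2 g(\hat{x})$$

The approximation is exact when $g$ is quadratic in $x$, i.e. when $p(x \mid a, b, y)$ is Gaussian.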

To do this using @addlogprob!, you'd need to write your own log-density function (as a function of all the parameters and data), wrap it in a MarginalLogDensity to integrate out $x$, and pass that into a Turing model, similar to my first example. That's basically possible now, but requires writing the log-posterior function from scratch.

To do it the other way, starting with a Turing model, my hypothetical marginalize function would do something like:

  1. Convert the DynamicPPL.Model to something that takes a flat parameter vector and returns a log-probability (a LogDensityFunction? An OptimLogDensity? Something else?)
  2. Get the indices of the variables to marginalize within that flat parameter vector
  3. Use 1 and 2 to construct a MarginalLogDensity, or maybe a new TuringMarginalLogDensity type that stores some extra metadata
  4. Pass that to specialized methods for sample, maximum_a_posteriori, etc.
  5. Do any extra bookkeeping needed at the end to report the parameters with their correct names, etc.

All this should be basically possible, no?
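Step 2 in the list above amounts to simple offset bookkeeping. The following is a minimal, hypothetical sketch in plain Julia (flat_indices is not a DynamicPPL or MarginalLogDensities function); it assumes we already know each variable's name and length, in the order in which they are stored in the flat parameter vector.

```julia
# Hypothetical helper (not part of DynamicPPL or MarginalLogDensities):
# given `layout`, a list of `name => length` pairs in the order the variables
# appear in the flat parameter vector, return the flat indices of the
# variables listed in `marginalized`.
function flat_indices(layout, marginalized)
    idx = Int[]
    offset = 0
    for (name, len) in layout
        if name in marginalized
            # This variable occupies the next `len` slots after `offset`.
            append!(idx, (offset + 1):(offset + len))
        end
        # Every variable, marginalized or not, shifts the offset.
        offset += len
    end
    return idx
end

# Three scalars a, b, logc followed by a 21-element vector x.
layout = [:a => 1, :b => 1, :logc => 1, :x => 21]
flat_indices(layout, [:x])   # the indices 4:24, as a Vector{Int}
```

In a real extension, the names and lengths would come from the model's varinfo rather than being written by hand.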

torfjelde commented 2 days ago

So I do believe this is achievable in some models, but the problem is that I think it'll be difficult to verify that the user is actually doing the right thing 😕

Approach to implementing this

What I'm worried about is that both the target variable and the likelihood are absorbed into the MarginalLogDensity, right? So, taking your example,

@model function Example(y)
    a ~ SomeDist()
    b ~ AnotherDist()
    mu = somefunction(a, b)
    x ~ MvNormal(mu, 1.0)
    y ~ MvNormal(x, 1.0)
end

should become

@model function MarginalizedExample()
    a ~ SomeDist()
    b ~ AnotherDist()
    Turing.@addlogprob! MarginalLogDensity(...)
end

Correct?

However, this is not something we can do automatically in Turing.jl. In Turing.jl, your model always looks like the original definition, i.e.

@model function Example(y)
    a ~ SomeDist()
    b ~ AnotherDist()
    mu = somefunction(a, b)
    x ~ MvNormal(mu, 1.0)
    y ~ MvNormal(x, 1.0)
end

but you alter the model by changing the behavior of each of the individual ~ statements; no "global" transformations are allowed. These individual ~ statements are altered by overloading methods such as x, __varinfo__ = DynamicPPL.tilde_assume(...) and y, __varinfo__ = DynamicPPL.tilde_observe(...) corresponding to random and observed variables, respectively (__varinfo__ is the structure we use to keep track of variables internally at runtime).

Hence, to "convert"

x ~ MvNormal(mu, 1.0)
y ~ MvNormal(x, 1.0)

into

Turing.@addlogprob! MarginalLogDensity(...)

automatically, we need to overload tilde_assume for x and tilde_observe for y so that, together, they perform this addlogprob! only once.

We can technically do this. For example, we could have x, ... = tilde_assume(...) do essentially nothing, i.e. just sample from the prior and return 0 for the logpdf, and then move the entire Turing.@addlogprob! MarginalLogDensity(...) to the statement y, ... = tilde_observe(...). In this particular example, this would work just fine.

Is this okay?

However, the question is: how do we know whether this is an okay thing to do?

For example, IIUC, the variable(s) being marginalised need to be a parent of the observation variable in the DAG. For example, in the above case, we cannot marginalise out only a because x is on the path a -> y, and so we would have to also marginalise out x.

The problem is that Turing.jl doesn't (currently) have access to the DAG of the model, and so performing such checks is not really possible atm 😕

ElOceanografo commented 1 day ago

I guess I'm not clear why any of the internal details of the model need apply here...if a Turing model can be turned into a black-box function $p([a, b, x] | y)$ for the purposes of optimization or sampling, then the same function should work in a MarginalLogDensity.

Here's a concrete example of what I'm suggesting:

using Turing
using Distributions
using MarginalLogDensities
import Zygote

a = 0.5
b = 3
c = 2
d = 3
s = collect(-5:0.5:5)
μ = a.*s.^2 .+ b.*s
x = rand(MvNormal(μ, c))
y = rand(MvNormal(x, d))

@model function demo(y, s)
    a ~ Normal(0, 10)
    b ~ Normal(0, 10)
    logc ~ Normal()
    c = exp(logc)
    μ = a.*s.^2 .+ b.*s
    x ~ MvNormal(μ, c)
    y ~ MvNormal(x, 3.0)
end

m = demo(y, s)
ctx = Turing.Optimisation.OptimizationContext(Turing.DefaultContext())
old = Turing.Optimisation.OptimLogDensity(m, ctx)

# marginalize the `x` variables, indices 4:24
mld = MarginalLogDensity((u, p) -> -old(u), randn(24), 4:24, (), LaplaceApprox(adtype=AutoZygote())) 
mld(randn(3))
torfjelde commented 20 hours ago

I guess I'm not clear why any of the internal details of the model need apply here

Ah sorry, that probably wasn't clear from my end 😬

I was thinking of more convoluted scenarios, e.g.

@model function demo()
    s ~ InverseGamma(2, 3)
    x ~ Normal(0, sqrt(s))
    y ~ Normal(0, sqrt(s))

    return (s, x)
end

model = demo() | (y = 1.0,)  # model over `s` and `x`
model()  # => (s = rand(InverseGamma(2, 3)), x = rand(Normal(0, sqrt(s))))

marginalized_model = marginalize(model, @varname(x))  # model is now only over `s`
marginalized_model()  # => (s = rand(InverseGamma(2, 3)), x = argmax(logjoint(model, x)))

where argmax(logjoint(model, x)) represents the MAP of the posterior over x (for the given s). Similarly, log-probability computations such as logjoint would also be altered to include the normalisation constant from the Laplace approx.

Does that make sense? It was a bit unclear to me how to do this + whether someone would actually want to do this.
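To make the "MAP plus normalisation constant" idea concrete, here is a minimal, hypothetical 1-D sketch in plain Julia. laplace_1d is not part of MarginalLogDensities.jl, which handles the multivariate case with proper AD backends and optimizers. The sketch finds the mode with Newton's method using central finite differences and then adds the Gaussian normalisation term; the step size h and the tolerances are arbitrary assumptions.

```julia
# Hypothetical helper (not part of MarginalLogDensities.jl): 1-D Laplace approximation
#   log ∫ exp(g(x)) dx ≈ g(x̂) + 0.5*log(2π) - 0.5*log(-g''(x̂)),
# where x̂ = argmax g is found by Newton's method with central finite differences.
function laplace_1d(g, x0; h=1e-4, maxiter=100, tol=1e-10)
    dg(x)  = (g(x + h) - g(x - h)) / (2h)          # approx. first derivative
    d2g(x) = (g(x + h) - 2g(x) + g(x - h)) / h^2   # approx. second derivative
    x = float(x0)
    for _ in 1:maxiter
        newton_step = dg(x) / d2g(x)   # Newton step towards the stationary point
        x -= newton_step
        abs(newton_step) < tol && break
    end
    curvature = -d2g(x)                # must be positive at a maximum
    curvature > 0 || error("Laplace approximation requires a local maximum")
    # Log of the mode's density plus the Gaussian normalisation constant.
    logZ = g(x) + 0.5 * log(2π) - 0.5 * log(curvature)
    return (mode = x, logZ = logZ)
end

# Gaussian integrand: the Laplace approximation is exact here.
res = laplace_1d(x -> -(x - 1)^2 / 2, 0.0)
```

For this Gaussian integrand, res.mode is (numerically) 1 and res.logZ is log(sqrt(2π)) ≈ 0.919; for non-Gaussian integrands, the result is only an approximation.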

Turing model can be turned into a black-box function

But yes, this would be very easy to do:)

using Turing, MarginalLogDensities, LogDensityProblems, Zygote

"""
    marginalize(model::Model, varnames::Vector; method=LaplaceApprox())

Returns a `MarginalLogDensity` with `varnames` marginalized out from the `model`.
"""
function marginalize(model::DynamicPPL.Model, varnames::Vector; method=LaplaceApprox())
    # Determine the indices for the variables to marginalise out.
    varinfo = DynamicPPL.typed_varinfo(model)
    varindices = DynamicPPL.getranges(varinfo, varnames)
    # Construct the marginal log-density model.
    # TODO(torfjelde): Should link and use optimization context to avoid including Jacobian corrections in the log-density.
    f = Base.Fix1(LogDensityProblems.logdensity, DynamicPPL.LogDensityFunction(model))
    mdl = MarginalLogDensity((u, p) -> f(u), varinfo[:], varindices, (), method)
    return mdl
end

a = 0.5
b = 3
c = 2
d = 3
s = collect(-5:0.5:5)
μ = a.*s.^2 .+ b.*s
x = rand(MvNormal(μ, c))
y = rand(MvNormal(x, d))

@model function demo(y, s)
    a ~ Normal(0, 10)
    b ~ Normal(0, 10)
    logc ~ Normal()
    c = exp(logc)
    μ = a.*s.^2 .+ b.*s
    x ~ MvNormal(μ, c)
    y ~ MvNormal(x, 3.0)
end

model = demo(y, s);
mdl = marginalize(model, [@varname(x)]; method=LaplaceApprox(adtype=AutoZygote()));
mdl(randn(3))

With the "bugfix" below, this currently works like a charm:) Though it should probably use the optimization context + linking, i.e.

function marginalize(model::DynamicPPL.Model, varnames::Vector; method=LaplaceApprox())
    # Determine the indices for the variables to marginalise out.
    varinfo = DynamicPPL.typed_varinfo(model)
    varindices = DynamicPPL.getranges(varinfo, varnames)
    # Construct the marginal log-density model.
    # Use linked `varinfo` so that we're working in unconstrained space and `OptimizationContext` to ensure
    # that the log-abs-det jacobian terms are not included.
    context = Turing.Optimisation.OptimizationContext(DynamicPPL.leafcontext(model.context))
    varinfo_linked = DynamicPPL.link(varinfo, model)
    f = Base.Fix1(LogDensityProblems.logdensity, DynamicPPL.LogDensityFunction(varinfo_linked, model, context))
    # HACK: need the sign-flip here because `OptimizationContext` is a hacky impl :/
    mdl = MarginalLogDensity((u, p) -> -f(u), varinfo_linked[:], varindices, (), method)
    return mdl
end

but that requires depending on Turing.jl (directly or via an extension) instead of just DynamicPPL.jl (the package which defines the modeling syntax, etc.).

At the moment, this requires a "bugfix" to DynamicPPL.jl (I'm using an internal function in a way that isn't quite meant to be used this way currently):

# FIXME(torfjelde): This is an internal function that isn't CURRENTLY meant to be used in this way,
# but we can add the following missing def to make it work.
# It takes in a vector of variable names and returns the corresponding indices.
function DynamicPPL.getranges(varinfo::DynamicPPL.TypedVarInfo, vns::Vector{<:DynamicPPL.VarName})
    # Here we need to keep track of the offset.
    offset = 0
    vns_all = DynamicPPL.keys(varinfo)
    return mapreduce(vcat, vns_all; init=Int[]) do vn
        # First we need to get the range so we can add it to the total offset.
        r = DynamicPPL.getrange(varinfo, vn)
        length_vn = length(r)
        # Then we check if `vn` is one of the variables we're extracting the ranges for.
        # TODO(torfjelde): Maybe use `findall` + `subsumes` instead?
        index = findfirst(isequal(vn), vns)
        return if index === nothing
            # If none exist, we shift the offset and return an empty array.
            offset += length_vn
            Int[]
        else
            # Otherwise, we return the (offset-shifted) range and update the offset.
            r = r .+ offset
            offset += length_vn
            r
        end
    end
end
torfjelde commented 19 hours ago

Btw, it's quite unfortunate that we have to use Zygote.jl here to get 2nd order information 😕 It's really not well-suited for Turing.jl models.

torfjelde commented 19 hours ago

Oh, and if you wanted to use this for sampling, you could just add the following:

#########################
## To sample from this ##
#########################
# 1. Add LogDensityProblems.jl interface to it and we can suddenly use samplers.
LogDensityProblems.logdensity(mdl::MarginalLogDensity, u) = mdl(u)
LogDensityProblems.dimension(mdl::MarginalLogDensity) = length(mdl.iv)
function LogDensityProblems.capabilities(mdl::MarginalLogDensity)
    return LogDensityProblems.LogDensityOrder{0}()
end

# 2. Unfortunately, we have to use the sampler packages explicitly.
using AbstractMCMC, AdvancedMH
spl = AdvancedMH.RWMH(LogDensityProblems.dimension(mdl))
samples = sample(
    mdl, spl, 1000;
    chain_type=MCMCChains.Chains,
    # HACK: this is a dirty way to extract the variable names in a model; it won't work in general.
    # But general methods exist, so we can fix that.
    param_names=setdiff(keys(DynamicPPL.untyped_varinfo(model)), [@varname(x)])
)

Results in

julia> samples = sample(
           mdl, spl, 1000;
           chain_type=MCMCChains.Chains,
           # HACK: this a dirty way to extract the variable names in a model.
           param_names=setdiff(keys(DynamicPPL.untyped_varinfo(model)), [@varname(x)])
       )
Sampling 100%|█████████████████████████████████████████████████████████████| Time: 0:00:03
Chains MCMC chain (1000×4×1 Array{Float64, 3}):

Iterations        = 1:1:1000
Number of chains  = 1
Samples per chain = 1000
parameters        = a, b, logc
internals         = lp

Summary Statistics
  parameters       mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec 
      Symbol    Float64   Float64   Float64    Float64    Float64   Float64       Missing 

           a     0.6280    0.3893    0.0576    40.8026    18.2906    1.4236       missing
           b   -15.2641    4.2187    2.0435     4.2878     4.6395    1.4834       missing
        logc     1.0824    1.1025    0.5967     4.0816     4.4534    1.4062       missing

Quantiles
  parameters       2.5%      25.0%      50.0%      75.0%     97.5% 
      Symbol    Float64    Float64    Float64    Float64   Float64 

           a    -0.6893     0.6643     0.6970     0.6970    1.3287
           b   -17.2110   -17.1433   -17.1433   -16.4625   -2.2291
        logc     0.4505     0.4505     0.4505     1.2029    3.7519
ElOceanografo commented 18 hours ago

Cool, thank you for the clarification, I knew I was missing something! I would not have thought to try using any of the fancy conditioning syntax. Making a Turing model into a MarginalLogDensity is very analogous to making it into an OptimizationProblem, and I don't think users would expect any of the more advanced conditioning stuff to work.

And thanks for the worked example. I was trying to figure out how to get the variable names and indices automatically but couldn't do it myself. This feels like a good outline of a usable extension...does it make more sense as an MLD extension in Turing, or a Turing extension in MLD?

ElOceanografo commented 18 hours ago

Btw, it's quite unfortunate that we have to use Zygote.jl here to get 2nd order information 😕 It's really not well-suited for Turing.jl models.

We shouldn't need to, MarginalLogDensities is totally backend-agnostic (in theory anyway).

torfjelde commented 17 hours ago

I was trying to figure out how to get the variable names and indices automatically but couldn't do it myself. This feels like a good outline of a usable extension...does it make more sense as an MLD extension in Turing, or a Turing extension in MLD?

I think, given that it returns a MLD object, that it makes more sense as an MLD extension maybe?

And yes, I do think this is pretty close to being a functional extension:) If you want, I'm happy to make a PR or you can make a PR and I can support?

The only annoyance is that you end up touching some internals, e.g. OptimizationContext, which means that the extension could technically break even when Turing.jl makes non-breaking releases 😕

We shouldn't need to, MarginalLogDensities is totally backend-agnostic (in theory anyway).

Yeah no I saw this, which is very nice:) It was more a comment on the AD packages; I tried both ForwardDiff.jl and ReverseDiff.jl, but both failed upon hessian computation.

yebai commented 17 hours ago

Thanks, @ElOceanografo; I will keep it as a Turing extension if you like. Please feel free to open a PR.

It is helpful for an ongoing research project for modular inference methods.

ElOceanografo commented 17 hours ago

I think, given that it returns a MLD object, that it makes more sense as an MLD extension maybe?

That's kind of what I was thinking as well. If there are ways to make the interface a bit more robust to changes in those Turing internals that would be great, but MLD is an experimental package with few users at the moment, so I'm okay with a moderate risk of breakage. If you want to make a PR go for it, otherwise I will try to get to it in the next week or so.

Yeah no I saw this, which is very nice:) It was more a comment on the AD packages; I tried both ForwardDiff.jl and ReverseDiff.jl, but both failed upon hessian computation.

Try again with ForwardDiff, there was a bug that just got fixed here: https://github.com/ElOceanografo/MarginalLogDensities.jl/pull/36. But in general yes, things are a bit more brittle in practice than they should be in theory...

torfjelde commented 11 hours ago

If there are ways to make the interface a bit more robust to changes in those Turing internals that would be great,

Yeah, maybe it's better to do it as an extension to Turing.jl given that it accesses Turing.jl internals 🤔 I think the likelihood of us breaking the functionality is greater than you doing so, so it seems sensible to put it in Turing.jl (at least for now).

I'll open a PR with some minor tweaks to the above then 👍 Buuut will likely need some help to get it merged (a bit limited on time these days); maybe you could help add some tests @ElOceanografo ?:)

ElOceanografo commented 7 hours ago

Yeah, definitely. Just ping me when you've got a PR and let me know where you'd like the tests to go.