ElOceanografo opened this issue 5 days ago
This looks neat 👀
Can you elaborate a bit on what mathematical model the following line should correspond to?
```julia
marginalmodel = marginalize(fullmodel, (:x,))
```
Would this marginalize out `x`? And exactly which terms in the posterior are we replacing with the Laplace approx? My Laplace is a bit rusty 😬
(And MarginalLogDensities.jl looks really nice btw. If people are interested in model comparison, this will probably be a place we'll point people:))
IIUC you effectively want `marginalize(...)` to result in
$$p(y, a, b) = \int p(y, a, b, x) dx = \bigg( \int p(y \mid x) \times p(x \mid a, b) dx \bigg) \times p(a) \times p(b)$$
with the Laplace approx. replacing the integrand on the RHS?
But does this mean that you want to drop both the terms $p(y \mid x)$ and $p(x \mid a, b)$ in the original joint model, and replace those with a simple `@addlogprob!` call? I.e.
```julia
@model function MarginalizedExample()
    a ~ SomeDist()
    b ~ AnotherDist()
    Turing.@addlogprob! MarginalLogDensity(...)
end
```
?
Because that won't be possible with Turing.jl unfortunately 😕 At least not in a completely automated fashion.
Yes, your math is correct. If `f(u)` is a function that returns the log-posterior, and `u = [a; b; x]`, then calling `mld = MarginalLogDensity(f, u, i_x, y)` makes a callable struct. When you call `mld(v)` (where `v` corresponds to `[a; b] == u[setdiff(eachindex(u), i_x)]`), it will find the mode of $p(x \mid a, b, y)$ and the curvature at that mode, and use that Hessian to calculate the LA and integrate out $x$.
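To spell out the Laplace part (this is just the textbook form of the approximation, with $n = \dim(x)$):

$$\int p(y \mid x)\, p(x \mid a, b)\, dx \;\approx\; p(y \mid \hat{x})\, p(\hat{x} \mid a, b)\, (2\pi)^{n/2}\, |H|^{-1/2},$$

where $\hat{x}$ is the mode of $p(y \mid x)\, p(x \mid a, b)$ with $a$, $b$, $y$ held fixed, and $H = -\nabla^2_x \log\big[p(y \mid x)\, p(x \mid a, b)\big]$ evaluated at $\hat{x}$. So the whole integral on the RHS gets replaced by that product.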
To do this using `@addlogprob!`, you'd need to write your own log-density function (as a function of all the parameters and data), wrap it in a `MarginalLogDensity` to integrate out $x$, and pass that into a Turing model, similar to my first example. That's basically possible now, but requires writing the log-posterior function from scratch.
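A rough sketch of that "manual" route, using a made-up toy version of the models above (untested; the priors, data, and `f` here are all placeholders I'm inventing for illustration):

```julia
using Turing, Distributions, MarginalLogDensities

s = collect(-5:0.5:5)        # covariate
y_obs = randn(length(s))     # placeholder "observations"
nx = length(s)

# 1. Hand-written joint log-density log p(y, a, b, x), taking the flat vector u = [a; b; x].
function f(u, _)
    a, b = u[1], u[2]
    x = u[3:end]
    μ = a .* s .^ 2 .+ b .* s
    return logpdf(Normal(0, 10), a) + logpdf(Normal(0, 10), b) +
           logpdf(MvNormal(μ, 1.0), x) + logpdf(MvNormal(x, 1.0), y_obs)
end

# 2. Wrap it so the x block (indices 3:nx+2) is integrated out via the Laplace approximation.
mld = MarginalLogDensity(f, randn(2 + nx), 3:(2 + nx), (), LaplaceApprox())

# 3. Feed the marginalized density into a Turing model via @addlogprob!.
#    (Gradient-based samplers would additionally need `mld` itself to be differentiable.)
@model function marginalized_example()
    a ~ Normal(0, 10)
    b ~ Normal(0, 10)
    Turing.@addlogprob! mld([a, b])
end
```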
To do it the other way, starting with a Turing model, my hypothetical `marginalize` function would do something like:

1. Convert the `DynamicPPL.Model` to something that takes a flat parameter vector and returns a log-probability (a `LogDensityFunction`? `OptimLogDensity`? Something else?)
2. Wrap that in a `MarginalLogDensity`, or maybe a new `TuringMarginalLogDensity` type that stores some extra metadata
3. Define new methods for `sample`, `maximum_a_posteriori`, etc.

All this should be basically possible, no?
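In pseudocode, the user-facing flow I'm imagining is basically (none of this exists yet, obviously; names and methods are purely illustrative):

```julia
# Hypothetical interface only.
marginalmodel = marginalize(fullmodel, (:x,))   # wraps a MarginalLogDensity around the model
chain = sample(marginalmodel, NUTS(), 1000)     # sample the remaining parameters with x integrated out
mode  = maximum_a_posteriori(marginalmodel)     # MAP over the non-marginalized parameters
# (Gradient-based samplers like NUTS would additionally need the marginal density to be differentiable.)
```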
So I do believe this is achievable in some models, but the problem is that I think it'll be difficult to verify that the user is actually doing the right thing :confused:
What I'm worried about is that both the target variable and the likelihood are absorbed into the `MarginalLogDensity`, right? So, taking your example,
```julia
@model function Example(y)
    a ~ SomeDist()
    b ~ AnotherDist()
    mu = somefunction(a, b)
    x ~ MvNormal(mu, 1.0)
    y ~ MvNormal(x, 1.0)
end
```
should become
```julia
@model function MarginalizedExample()
    a ~ SomeDist()
    b ~ AnotherDist()
    Turing.@addlogprob! MarginalLogDensity(...)
end
```
Correct?
However, this is not something we can do automatically in Turing.jl. In Turing.jl, your model always looks like the original definition, i.e.
```julia
@model function Example(y)
    a ~ SomeDist()
    b ~ AnotherDist()
    mu = somefunction(a, b)
    x ~ MvNormal(mu, 1.0)
    y ~ MvNormal(x, 1.0)
end
```
but you alter the model by changing the behavior of each of the individual `~` statements; no "global" transformations are allowed. These individual `~` statements are altered by overloading methods such as `x, __varinfo__ = DynamicPPL.tilde_assume(...)` and `y, __varinfo__ = DynamicPPL.tilde_observe(...)`, corresponding to random and observed variables, respectively (`__varinfo__` is the structure we use to keep track of variables internally at runtime).
Hence, to "convert"
x ~ MvNormal(mu, 1.0)
y ~ MvNormal(x, 1.0)
into
Turing.@addlogprob! MarginalLogDensity(...)
automatically, we need overload tilde_assume
for x
and tilde_observe
for y
to only perform this addlogprob!
once.
We can technically do this. For example, we could have `x, ... = tilde_assume(...)` not really do anything, i.e. just sample from the prior and return `0` for the `logpdf`, and then move the entire `Turing.@addlogprob! MarginalLogDensity(...)` to the `y, ... = tilde_observe(...)` statement. This would, in this particular example, work just fine.
However, the question is: how do we know whether this is an okay thing to do?
For example, IIUC, the variable(s) being marginalised need to be a parent of the observation variable in the DAG: in the above case, we cannot marginalise out only `a`, because `x` is on the path `a -> y`, and so we would have to also marginalise out `x`.
The problem is that Turing.jl doesn't currently have access to the DAG of the model, and so performing such checks is not really possible atm :confused:
I guess I'm not clear why any of the internal details of the model need apply here... if a Turing model can be turned into a black-box function $p([a, b, x] \mid y)$ for the purposes of optimization or sampling, then the same function should work in a `MarginalLogDensity`.
Here's a concrete example of what I'm suggesting:
```julia
using Turing
using Distributions
using MarginalLogDensities
import Zygote

a = 0.5
b = 3
c = 2
d = 3
s = collect(-5:0.5:5)
μ = a.*s.^2 .+ b.*s
x = rand(MvNormal(μ, c))
y = rand(MvNormal(x, d))

@model function demo(y, s)
    a ~ Normal(0, 10)
    b ~ Normal(0, 10)
    logc ~ Normal()
    c = exp(logc)
    μ = a.*s.^2 .+ b.*s
    x ~ MvNormal(μ, c)
    y ~ MvNormal(x, 3.0)
end
m = demo(y, s)

ctx = Turing.Optimisation.OptimizationContext(Turing.DefaultContext())
old = Turing.Optimisation.OptimLogDensity(m, ctx)

# marginalize the `x` variables, indices 4:24
mld = MarginalLogDensity((u, p) -> -old(u), randn(24), 4:24, (), LaplaceApprox(adtype=AutoZygote()))
mld(randn(3))
```
> I guess I'm not clear why any of the internal details of the model need apply here
Ah sorry, that probably wasn't clear from my end :grimacing:
I was thinking of more convoluted scenarios, e.g.
```julia
@model function demo()
    s ~ InverseGamma(2, 3)
    x ~ Normal(0, sqrt(s))
    y ~ Normal(0, sqrt(s))
    return (s, x)
end

model = demo() | (y = 1.0,)  # model over `s` and `x`
model()  # => (s = rand(InverseGamma(2, 3)), x = rand(Normal(0, sqrt(s))))

marginalized_model = marginalize(model, @varname(x))  # model is now only over `s`
marginalized_model()  # => (s = rand(InverseGamma(2, 3)), x = argmax(logjoint(model, x)))
```
where `argmax(logjoint(model, x))` represents the MAP of the posterior. Similarly, log-probability computations, e.g. `logjoint`, would also be altered to include the normalisation constant from the Laplace approx.
Does that make sense? It was a bit unclear to me how to do this + whether someone would actually want to do this.
> Turing model can be turned into a black-box function
But yes, this would be very easy to do:)
```julia
using Turing, MarginalLogDensities, LogDensityProblems, Zygote

"""
    marginalize(model::Model, varnames::Vector; method=LaplaceApprox())

Returns a `MarginalLogDensity` with `varnames` marginalized out from the `model`.
"""
function marginalize(model::DynamicPPL.Model, varnames::Vector; method=LaplaceApprox())
    # Determine the indices for the variables to marginalise out.
    varinfo = DynamicPPL.typed_varinfo(model)
    varindices = DynamicPPL.getranges(varinfo, varnames)
    # Construct the marginal log-density model.
    # TODO(torfjelde): Should link and use the optimization context to avoid including
    # Jacobian corrections in the log-density.
    f = Base.Fix1(LogDensityProblems.logdensity, DynamicPPL.LogDensityFunction(model))
    mdl = MarginalLogDensity((u, p) -> f(u), varinfo[:], varindices, (), method)
    return mdl
end

a = 0.5
b = 3
c = 2
d = 3
s = collect(-5:0.5:5)
μ = a.*s.^2 .+ b.*s
x = rand(MvNormal(μ, c))
y = rand(MvNormal(x, d))

@model function demo(y, s)
    a ~ Normal(0, 10)
    b ~ Normal(0, 10)
    logc ~ Normal()
    c = exp(logc)
    μ = a.*s.^2 .+ b.*s
    x ~ MvNormal(μ, c)
    y ~ MvNormal(x, 3.0)
end
model = demo(y, s);

mdl = marginalize(model, [@varname(x)]; method=LaplaceApprox(adtype=AutoZygote()));
mdl(randn(3))
```
With the "bugfix" below, this currently works like a charm:) Though it should probably use the optimization context + linking, i.e.
```julia
function marginalize(model::DynamicPPL.Model, varnames::Vector; method=LaplaceApprox())
    # Determine the indices for the variables to marginalise out.
    varinfo = DynamicPPL.typed_varinfo(model)
    varindices = DynamicPPL.getranges(varinfo, varnames)
    # Construct the marginal log-density model.
    # Use a linked `varinfo` so that we're working in unconstrained space, and `OptimizationContext`
    # to ensure that the log-abs-det-Jacobian terms are not included.
    context = Turing.Optimisation.OptimizationContext(DynamicPPL.leafcontext(model.context))
    varinfo_linked = DynamicPPL.link(varinfo, model)
    f = Base.Fix1(LogDensityProblems.logdensity, DynamicPPL.LogDensityFunction(varinfo_linked, model, context))
    # HACK: need the sign flip here because `OptimizationContext` is a hacky impl :/
    mdl = MarginalLogDensity((u, p) -> -f(u), varinfo_linked[:], varindices, (), method)
    return mdl
end
```
but that requires depending on / making an extension for Turing.jl instead of just DynamicPPL.jl (the package which defines the modeling syntax, etc.).
At the moment, this requires a "bugfix" to DynamicPPL.jl (I'm using an internal function in a way it isn't currently meant to be used):
```julia
# FIXME(torfjelde): This is an internal function that isn't CURRENTLY meant to be used in this way,
# but we can add the following missing def to make it work.
# It takes in a vector of variable names and returns the corresponding indices.
function DynamicPPL.getranges(varinfo::DynamicPPL.TypedVarInfo, vns::Vector{<:DynamicPPL.VarName})
    # Here we need to keep track of the offset.
    offset = 0
    vns_all = DynamicPPL.keys(varinfo)
    return mapreduce(vcat, vns_all; init=Int[]) do vn
        # First we need to get the range so we can add it to the total offset.
        r = DynamicPPL.getrange(varinfo, vn)
        length_vn = length(r)
        # Then we check if `vn` is one of the variables we're extracting the ranges for.
        # TODO(torfjelde): Maybe use `findall` + `subsumes` instead?
        index = findfirst(isequal(vn), vns)
        return if index === nothing
            # If not, we shift the offset and return an empty array.
            offset += length_vn
            Int[]
        else
            # Otherwise, we return the (offset) range and update the offset.
            r = r .+ offset
            offset += length_vn
            r
        end
    end
end
```
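For what it's worth, with the `demo(y, s)` model from above (3 scalar parameters followed by the 21-element `x`, assuming the flat-vector ordering `a, b, logc, x`), I'd expect this to pick out the `x` indices, i.e. 4 through 24:

```julia
# Illustrative check of the helper above.
varinfo = DynamicPPL.typed_varinfo(model)
DynamicPPL.getranges(varinfo, [@varname(x)])  # => indices 4 through 24
```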
Btw, it's quite unfortunate that we have to use Zygote.jl here to get 2nd order information :confused: It's really not well-suited for Turing.jl models.
Oh, and if you wanted to use this for sampling, you could just add the following:
```julia
#########################
## To sample from this ##
#########################
# 1. Add the LogDensityProblems.jl interface to it and we can suddenly use samplers.
LogDensityProblems.logdensity(mdl::MarginalLogDensity, u) = mdl(u)
LogDensityProblems.dimension(mdl::MarginalLogDensity) = length(mdl.iv)
function LogDensityProblems.capabilities(mdl::MarginalLogDensity)
    return LogDensityProblems.LogDensityOrder{0}()
end

# 2. Unfortunately, we have to use the sampler packages explicitly.
using AbstractMCMC, AdvancedMH
spl = AdvancedMH.RWMH(LogDensityProblems.dimension(mdl))
samples = sample(
    mdl, spl, 1000;
    chain_type=MCMCChains.Chains,
    # HACK: this is a dirty way to extract the variable names in a model; it won't work in general.
    # But general methods exist, so we can fix that.
    param_names=setdiff(keys(DynamicPPL.untyped_varinfo(model)), [@varname(x)])
)
```
Results in
```julia
julia> samples = sample(
           mdl, spl, 1000;
           chain_type=MCMCChains.Chains,
           # HACK: this a dirty way to extract the variable names in a model.
           param_names=setdiff(keys(DynamicPPL.untyped_varinfo(model)), [@varname(x)])
       )
Sampling 100%|█████████████████████████████████████████████████████████████| Time: 0:00:03
Chains MCMC chain (1000×4×1 Array{Float64, 3}):

Iterations        = 1:1:1000
Number of chains  = 1
Samples per chain = 1000
parameters        = a, b, logc
internals         = lp

Summary Statistics
  parameters       mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
      Symbol    Float64   Float64   Float64    Float64    Float64   Float64       Missing

           a     0.6280    0.3893    0.0576    40.8026    18.2906    1.4236       missing
           b   -15.2641    4.2187    2.0435     4.2878     4.6395    1.4834       missing
        logc     1.0824    1.1025    0.5967     4.0816     4.4534    1.4062       missing

Quantiles
  parameters       2.5%      25.0%      50.0%      75.0%      97.5%
      Symbol    Float64    Float64    Float64    Float64    Float64

           a    -0.6893     0.6643     0.6970     0.6970     1.3287
           b   -17.2110   -17.1433   -17.1433   -16.4625    -2.2291
        logc     0.4505     0.4505     0.4505     1.2029     3.7519
```
Cool, thank you for the clarification, I knew I was missing something! I would not have thought to try using any of the fancy conditioning syntax. Making a Turing model into a `MarginalLogDensity` is very analogous to making it into an `OptimizationProblem`, and I don't think users would expect any of the more advanced conditioning stuff to work.
And thanks for the worked example. I was trying to figure out how to get the variable names and indices automatically but couldn't do it myself. This feels like a good outline of a usable extension...does it make more sense as an MLD extension in Turing, or a Turing extension in MLD?
> Btw, it's quite unfortunate that we have to use Zygote.jl here to get 2nd order information 😕 It's really not well-suited for Turing.jl models.
We shouldn't need to, MarginalLogDensities is totally backend-agnostic (in theory anyway).
> I was trying to figure out how to get the variable names and indices automatically but couldn't do it myself. This feels like a good outline of a usable extension... does it make more sense as an MLD extension in Turing, or a Turing extension in MLD?
I think, given that it returns a MLD object, that it makes more sense as an MLD extension maybe?
And yes, I do think this is pretty close to being a functional extension:) If you want, I'm happy to make a PR or you can make a PR and I can support?
The only annoyance is that you end up touching some internals, e.g. `OptimizationContext`, which means that the extension could technically break even when Turing.jl makes non-breaking releases :confused:
> We shouldn't need to, MarginalLogDensities is totally backend-agnostic (in theory anyway).
Yeah no I saw this, which is very nice:) It was more a comment on the AD packages; I tried both ForwardDiff.jl and ReverseDiff.jl, but both failed upon hessian computation.
Thanks, @ElOceanografo; I will keep it as a Turing extension if you like. Please feel free to open a PR.
It is helpful for an ongoing research project on modular inference methods.
> I think, given that it returns a MLD object, that it makes more sense as an MLD extension maybe?
That's kind of what I was thinking as well. If there are ways to make the interface a bit more robust to changes in those Turing internals that would be great, but MLD is an experimental package with few users at the moment, so I'm okay with a moderate risk of breakage. If you want to make a PR go for it, otherwise I will try to get to it in the next week or so.
> Yeah no I saw this, which is very nice:) It was more a comment on the AD packages; I tried both ForwardDiff.jl and ReverseDiff.jl, but both failed upon hessian computation.
Try again with ForwardDiff, there was a bug that just got fixed here: https://github.com/ElOceanografo/MarginalLogDensities.jl/pull/36. But in general yes, things are a bit more brittle in practice than they should be in theory...
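i.e. with the prototype above it should just be a matter of swapping the keyword (whether the Hessian actually goes through is another question, of course):

```julia
# Same `marginalize` prototype as above, just asking the Laplace approximation to use
# ForwardDiff instead of Zygote (AutoForwardDiff comes from ADTypes.jl, like the AutoZygote used earlier).
mdl_fd = marginalize(model, [@varname(x)]; method=LaplaceApprox(adtype=AutoForwardDiff()))
mdl_fd(randn(3))
```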
> If there are ways to make the interface a bit more robust to changes in those Turing internals that would be great,
Yeah, maybe it's better to do it as an extension to Turing.jl given that it accesses Turing.jl internals :thinking: I think the likelihood of us breaking the functionality is greater than you doing so, so it seems sensible to put it in Turing.jl (at least for now).
I'll open a PR with some minor tweaks to the above then :+1: Buuut will likely need some help to get it merged (a bit limited on time these days); maybe you could help add some tests @ElOceanografo ?:)
Yeah, definitely. Just ping me when you've got a PR and let me know where you'd like the tests to go.
Moving a discussion with @yebai from Slack to here. @PavanChaggar asked if there was a way to do Laplace approximation in Turing, and I gave this little example of how it could be accomplished with MarginalLogDensities.jl:
This issue is to discuss how this capability might be integrated better into Turing, probably via a package extension. (See also https://github.com/TuringLang/Turing.jl/issues/1382, which I opened before I made MLD). From a user perspective, an interface like this makes sense to me:
I think there are two basic ways to implement this:

1. `marginalize` constructs a new, marginalized `DynamicPPL.Model`, or
2. a `MarginalLogDensities.MarginalLogDensity`, with new methods for `sample`, `maximum_a_posteriori`, and `maximum_likelihood` defined for it.

I'm not very familiar with Turing's internals, so happy to be corrected if there are other approaches that make more sense.

The other current roadblock is making calls to `MarginalLogDensity` objects differentiable (https://github.com/ElOceanografo/MarginalLogDensities.jl/issues/34). This is doable, I just need to do it.