aesara-devs / aeppl

Tools for an Aesara-based PPL.
https://aeppl.readthedocs.io
MIT License

Marginalize over latent variables #21

Open ricardoV94 opened 3 years ago

ricardoV94 commented 3 years ago

I wonder whether it would be possible to rewrite the logp graphs to marginalize over finite discrete variables, indicated by the user (not necessarily all that are in the graph).

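import aesara.tensor as at
from aeppl import joint_logprob

x_rv = at.random.bernoulli(0.7)
y_rv = at.random.normal([0, 1], [1, 1])[x_rv]
y = y_rv.type()

# marginalize is the proposed keyword; it does not exist yet
logp = joint_logprob(y_rv, {y_rv: y}, marginalize={x_rv})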

The resulting logp would be log[p(y_rv=y | x_rv=0) * p(x_rv=0) + p(y_rv=y | x_rv=1) * p(x_rv=1)].
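To give a concrete feel for what the rewritten graph has to compute, here is a minimal plain-Python sketch (not aeppl code; the helper names are made up) that evaluates that marginal logp for the model above, with p(x_rv=1) = 0.7 and component means 0 and 1. The sum over the two states is done in log space with log-sum-exp, the usual way to keep such sums numerically stable.

```python
import math


def normal_logpdf(y, mu, sigma):
    """Log-density of Normal(mu, sigma) evaluated at y."""
    return -0.5 * math.log(2 * math.pi * sigma**2) - (y - mu) ** 2 / (2 * sigma**2)


def logsumexp(terms):
    """Numerically stable log(sum(exp(t) for t in terms))."""
    m = max(terms)
    return m + math.log(sum(math.exp(t - m) for t in terms))


def marginal_logp(y, p=0.7, mus=(0.0, 1.0), sigmas=(1.0, 1.0)):
    """log p(y) for y ~ Normal(mus[x], sigmas[x]) with x ~ Bernoulli(p) summed out."""
    log_px = (math.log1p(-p), math.log(p))  # log p(x=0), log p(x=1)
    terms = [log_px[x] + normal_logpdf(y, mus[x], sigmas[x]) for x in (0, 1)]
    return logsumexp(terms)


print(marginal_logp(0.5))
```

At y = 0.5 both components have the same density, so the result reduces to normal_logpdf(0.5, 0.0, 1.0) whatever the mixing weight.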

This is straightforward(ish) if the marginalization happens just above the requested variable (e.g., y_rv), but gets more complicated if it happens at the top of a deeper graph.

rlouf commented 3 years ago

I'd be happy to have a look at this once I am done with NUTS in aehmc.

ricardoV94 commented 2 years ago

I have been thinking a bit about how we could implement marginalization of variables in aeppl. The obvious thing is that we won't be implementing all possible symbolic marginalizations (or distributions) out there, so we need to have an API that is easy to interact with.

My current idea is that aeppl figures out what RVs have to be marginalized (e.g., because they lack a Value variable, or are given a marginal flag value in the rv_values dict) and wraps those in a pseudo Op Marginal that has no logprob defined for it by default.

mu_rv = Normal(0, 1)
y_rv = Normal(mu_rv, 1) 
y_vv = y_rv.clone()

factorized_joint_logprob({y_rv: y_vv})  # mu is not given a value

# aesara rewrites intermediately as 
mu_rv = Marginal(Normal(0, 1))
y_rv = Normal(mu_rv, 1)

This can easily be matched by a PatternSub rewrite, for instance:

PatternSub(
  (Normal, "rng1", "size", "dtype", (Marginal, (Normal, "rng2", "size", "dtype", "mu", "sigma2")), "sigma1"),
  (Normal, "rng1", "size", "dtype", "mu", (sqrt, (add, (sqr, "sigma1"), (sqr, "sigma2")))),
)
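The replacement in this PatternSub relies on the Normal-Normal convolution identity: if y | m ~ Normal(m, sigma1) and m ~ Normal(mu, sigma2), then y ~ Normal(mu, sqrt(sigma1^2 + sigma2^2)). A quick numerical check in plain Python (trapezoid-rule integration over m; all names are illustrative, nothing here is aeppl API) compares the integral against the closed form the rewrite produces.

```python
import math


def normal_pdf(x, mu, sigma):
    """Density of Normal(mu, sigma) at x."""
    return math.exp(-((x - mu) ** 2) / (2 * sigma**2)) / (sigma * math.sqrt(2 * math.pi))


def marginal_by_quadrature(y, mu, sigma1, sigma2, n=4001, width=12.0):
    """Integrate Normal(y | m, sigma1) * Normal(m | mu, sigma2) over m with the trapezoid rule."""
    lo = mu - width * sigma2
    h = 2 * width * sigma2 / (n - 1)
    total = 0.0
    for i in range(n):
        m = lo + i * h
        weight = 0.5 if i in (0, n - 1) else 1.0
        total += weight * normal_pdf(y, m, sigma1) * normal_pdf(m, mu, sigma2)
    return total * h


def marginal_closed_form(y, mu, sigma1, sigma2):
    """Density of the PatternSub output Normal(mu, sqrt(sigma1**2 + sigma2**2)) at y."""
    return normal_pdf(y, mu, math.sqrt(sigma1**2 + sigma2**2))


print(marginal_by_quadrature(0.3, 0.0, 1.0, 2.0), marginal_closed_form(0.3, 0.0, 1.0, 2.0))
```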

There are two caveats:

  1. We have to make sure the PatternSub won't match if any other variables in the subgraph are Marginal. For instance the following intermediate graph shouldn't match with the PatternSub above because sigma is also marginalized:
mu_rv = Marginal(Normal(0, 1))
sigma_rv = Marginal(HalfNormal(0, 1))
y_rv = Normal(mu_rv, sigma_rv)
  2. The rewrite should also fail if other variables not present in the PatternSub depend on the ones being marginalized. For instance:
mu_rv = Marginal(Normal(0, 1))
y_rv1 = Normal(mu_rv, 1)
y_rv2 = Normal(mu_rv, 1)

We cannot consider y_rv1 and y_rv2 in isolation because they depend on the same marginalized variable.
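A small numerical illustration of the second caveat, assuming mu ~ Normal(0, 1) and y_i | mu ~ Normal(mu, 1) as in the snippet: integrating mu out of the pair yields a bivariate normal with covariance 1 between y_rv1 and y_rv2, whereas applying the single-child rewrite to each of them separately would yield two independent Normal(0, sqrt(2)) terms. The plain-Python sketch below (illustrative names only) evaluates all three densities.

```python
import math


def normal_pdf(x, mu, sigma):
    """Density of Normal(mu, sigma) at x."""
    return math.exp(-((x - mu) ** 2) / (2 * sigma**2)) / (sigma * math.sqrt(2 * math.pi))


def joint_by_quadrature(y1, y2, n=4001, width=12.0):
    """p(y1, y2) with mu ~ Normal(0, 1) integrated out of y_i | mu ~ Normal(mu, 1)."""
    h = 2 * width / (n - 1)
    total = 0.0
    for i in range(n):
        m = -width + i * h
        weight = 0.5 if i in (0, n - 1) else 1.0
        total += weight * normal_pdf(y1, m, 1.0) * normal_pdf(y2, m, 1.0) * normal_pdf(m, 0.0, 1.0)
    return total * h


def joint_closed_form(y1, y2):
    """Bivariate normal with variances 2 and covariance 1 induced by the shared mu."""
    q = (2 * y1**2 - 2 * y1 * y2 + 2 * y2**2) / 3  # y^T Sigma^{-1} y, Sigma = [[2, 1], [1, 2]]
    return math.exp(-q / 2) / (2 * math.pi * math.sqrt(3))


def product_of_marginals(y1, y2):
    """What rewriting y_rv1 and y_rv2 independently would give: two Normal(0, sqrt(2))."""
    s = math.sqrt(2)
    return normal_pdf(y1, 0.0, s) * normal_pdf(y2, 0.0, s)


print(joint_by_quadrature(1.0, 1.0), joint_closed_form(1.0, 1.0), product_of_marginals(1.0, 1.0))
```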

I think both of these cases can easily be covered by our limited PatternSub by adding specific constraints to the "sigma" variable and the "Marginal" variables. But we might want to offer a more specialized MarginalSub that adds these constraints automatically, and also registers the rewrites in the right database.
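As a rough sketch of what such constraints could check, here is a toy version in plain Python. The graph representation (a dict from names to (op, inputs) pairs) and the can_marginalize helper are hypothetical stand-ins, not Aesara's FunctionGraph or any existing MarginalSub; the point is only the two tests: no other input of the target node is marginalized, and the marginalized variable has no client besides the target.

```python
from typing import Dict, Set, Tuple

# Toy graph: variable name -> (op name, input names). This is a hypothetical
# stand-in for an Aesara FunctionGraph, used only to illustrate the constraints.
Graph = Dict[str, Tuple[str, Tuple[str, ...]]]


def clients(graph: Graph, var: str) -> Set[str]:
    """Names of all nodes that take var as an input."""
    return {name for name, (_, inputs) in graph.items() if var in inputs}


def can_marginalize(graph: Graph, marginalized: Set[str], target: str, var: str) -> bool:
    """Decide whether var can be integrated out of target alone."""
    _, inputs = graph[target]
    # Caveat 1: no other input of the target may itself be marginalized.
    if any(i in marginalized and i != var for i in inputs):
        return False
    # Caveat 2: the marginalized variable must feed the target and nothing else.
    return clients(graph, var) == {target}


simple = {"mu": ("Normal", ("0", "1")), "y": ("Normal", ("mu", "1"))}
both = {
    "mu": ("Normal", ("0", "1")),
    "sigma": ("HalfNormal", ("0", "1")),
    "y": ("Normal", ("mu", "sigma")),
}
shared = {
    "mu": ("Normal", ("0", "1")),
    "y1": ("Normal", ("mu", "1")),
    "y2": ("Normal", ("mu", "1")),
}

print(can_marginalize(simple, {"mu"}, "y", "mu"))
print(can_marginalize(both, {"mu", "sigma"}, "y", "mu"))
print(can_marginalize(shared, {"mu"}, "y1", "mu"))
```

The three example graphs mirror the single-parent case and the two caveats above, and give True, False and False respectively.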

The idea is that it would be easy to extend this outside of aeppl. For instance, PyMC has a DirichletMultinomial distribution, so in a graph like:

a_rv = Dirichlet(np.ones(3), size=n)
y_rv = Multinomial(n=5, p=a_rv)
y_vv = y_rv.clone()

factorized_joint_logprob({y_rv: y_vv})  # a is not given a value

# aesara rewrites intermediately as
a_rv = Marginal(Dirichlet(np.ones(3), size=n))
y_rv = Multinomial(n=5, p=a_rv)

The following straightforward rewrite gives us what we want:

PatternSub(
  (Multinomial, "rng1", "size", "dtype", "n", (Marginal, (Dirichlet, "rng2", "size", "dtype", "alpha"))),
  (DirichletMultinomial, "rng1", "size", "dtype", "n", "alpha"),
)
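For reference, the DirichletMultinomial node that replaces the pattern has a closed-form pmf obtained by integrating p out of Multinomial(n, p) with p ~ Dirichlet(alpha). A plain-Python sketch of that log-pmf (the function names are illustrative, not PyMC's implementation), checked by summing it over every count vector with n = 5 and three categories:

```python
import math
from itertools import product


def dirichlet_multinomial_logpmf(x, alpha):
    """log DirichletMultinomial(x | n=sum(x), alpha): Multinomial(n, p) with
    p ~ Dirichlet(alpha) integrated out in closed form."""
    n, a0 = sum(x), sum(alpha)
    out = math.lgamma(n + 1) + math.lgamma(a0) - math.lgamma(n + a0)
    for xi, ai in zip(x, alpha):
        out += math.lgamma(xi + ai) - math.lgamma(ai) - math.lgamma(xi + 1)
    return out


def count_vectors(n, k):
    """All length-k count vectors that sum to n (the Multinomial support)."""
    return [c for c in product(range(n + 1), repeat=k) if sum(c) == n]


support = count_vectors(5, 3)
total = sum(math.exp(dirichlet_multinomial_logpmf(x, (0.5, 2.0, 3.0))) for x in support)
print(len(support), total)
```

With alpha = ones(3), as in the graph above, every one of the 21 possible count vectors gets probability 1/21.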

And again we need the constraints that "n" and "alpha" are not themselves marginalized variables, and that no other variable depends on a_rv.

Obviously, we can still use local rewrites to replace more complicated marginal graphs, for instance those involving Bernoulli or Categorical variables, which can be generalized to many combinations of RandomVariables in a single function.

brandonwillard commented 2 years ago

There's no need to frame the marginalization functionality in terms of the log-probability function interface (e.g. via a keyword option); it's a coupling that doesn't provide anything constructive at the moment. We can always consider making complex/all-in-one interfaces once we've gotten the requisite functionality worked out independently.

Instead, we should start with a function—call it marginalize for now—that takes a model graph and a list of variable pairs and attempts to marginalize according to those pairs. This is a simple and general enough interface to be useful in more than one way/place.

> The obvious thing is that we won't be implementing all possible symbolic marginalizations (or distributions) out there, so we need to have an API that is easy to interact with.

I'm not sure what our degree of marginalization coverage has to do with an API, aside from the APIs that are involved in the creation and application of marginalizations (i.e. our graph rewriting tools and logic). Really, if an API is needed for anything other than the aforementioned, the problem it solves/addresses needs to be apparent first.

> My current idea is that aeppl figures out what RVs have to be marginalized (e.g., because they lack a Value variable, or are given a marginal flag value in the rv_values dict) and wraps those in a pseudo Op Marginal that has no logprob defined for it by default.

As I mentioned above, let's focus on the basics of marginalizing first, and get to its specific applications and automations afterward.

As an operator (in the mathematical sense), the Marginal you describe doesn't appear to have an existing analog; however, marginalization is represented quite well by an integral operator (since it is just convolution), so we should consider using that as an intermediate representation—assuming we need one for the moment.

What you describe via a PatternSub-like sexpr, i.e. Normal(Marginal(Normal(0, 1)), sigma), could be described using an expectation-like operator as follows: for Y | X ~ Normal(X, sigma) and X ~ Normal(0, 1), the "marginalization" of Y over X can be represented as E_X[Y | X] or E(Normal(X, sigma), X).
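To make the contrast with the Marginal wrapper concrete, here is a toy term language in plain Python where the expectation node names the integrated variable and its distribution explicitly, together with one rewrite rule implementing the Normal-Normal case. Every class and function here is a hypothetical illustration, not an existing Aesara or aeppl Op; note that the rule only fires when the inner sigma does not depend on another variable.

```python
import math
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Var:
    """A random variable referenced by name."""
    name: str


@dataclass(frozen=True)
class Normal:
    mu: "Term"
    sigma: "Term"


@dataclass(frozen=True)
class Expectation:
    """E(body, over): body with the variable over, distributed as over_dist, integrated out."""
    body: "Term"
    over: Var
    over_dist: "Term"


Term = Union[Var, float, Normal, Expectation]


def simplify(term: Term) -> Term:
    """One rewrite: E(Normal(X, s1), X ~ Normal(m, s2)) -> Normal(m, sqrt(s1**2 + s2**2))."""
    if (
        isinstance(term, Expectation)
        and isinstance(term.body, Normal)
        and term.body.mu == term.over
        and isinstance(term.body.sigma, float)
        and isinstance(term.over_dist, Normal)
        and isinstance(term.over_dist.mu, float)
        and isinstance(term.over_dist.sigma, float)
    ):
        return Normal(term.over_dist.mu, math.sqrt(term.body.sigma**2 + term.over_dist.sigma**2))
    return term  # No rule applies: the expectation stays in the graph.


X = Var("X")
print(simplify(Expectation(Normal(X, 2.0), X, Normal(0.0, 1.0))))
print(simplify(Expectation(Normal(X, Var("S")), X, Normal(0.0, 1.0))))
```

The second call returns the expectation unchanged because sigma is itself a variable, so no closed form applies.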

The representation you propose has the "action" of marginalization acting on only one of the terms involved, and that makes the connection between the output of the underlying action and its inputs rather ambiguous.

The PatternSub caveats you mention are also neatly addressed by the use of an Op that models the integral operator, because it requires one to be more specific about exactly what the marginalization is being applied to and the terms that are integrated.

Otherwise, I don't see the motivation for an intermediate form just yet, especially after clarifying how marginalization is represented as a graph operation. As far as implementations go, PatternSubs that match the "marginalizable" graphs directly and replace them with their marginalized forms are about as easy as it gets. All one might need to do is limit/determine to which sub-graphs the PatternSubs will be applied, and, when taking an in-graph operator approach to such filtering, that could only be adequately handled by something with at least as much specificity as the integral operator described above.

Even if we move a few steps forward and consider the case of automatic marginalization, why wouldn't we greedily marginalize? When there are many different ways to apply marginalization, we need to consider a few other things first, like orderings for the distinct marginalization outcomes (e.g. in order to determine which is "better"), how to efficiently produce outcomes per an ordering, etc.; otherwise, we might as well choose one greedy approach and go with that. (N.B. these are things that are probably better addressed in a relational context like miniKanren, or at least a framework that can lazily evaluate large/infinite sets of outcomes in non-trivial ways.)

ricardoV94 commented 2 years ago

I don't think you are proposing something very different from what I had in mind.

I think E(Normal(X, sigma), X) is a neat way to represent marginalization and addresses the caveat of "what is being integrated out?" that I was bringing up.

> The representation you propose has the "action" of marginalization acting on only one of the terms involved,

I am not sure what you mean. The Marginal keyword acts as an identifier of what is being integrated out, so you could very well have multiple Marginal variables in the same graph (say, if both mu and sigma are to be integrated out), and multiple RVs without Marginal that correspond to variables that have input values. Again, I like the idea of the E operator; I'm just not sure what problem you were seeing.

> Instead, we should start with a function—call it marginalize for now—that takes a model graph and a list of variable pairs and attempts to marginalize according to those pairs. This is a simple and general enough interface to be useful in more than one way/place.

What do you mean? How is this different from factorized_joint_logprob? Is it something you call in advance to generate intermediate {rv: vv} pairs, where some of the rvs may be E(rv|...) operators, which are then fed into factorized_joint_logprob?

The only reason I brought up the possibility of extra keywords is to disambiguate what happens to RandomVariables that don't have a value variable in the graph (this is related to the API discussion in #85).

mu_rv = Normal(0, 1)
y_rv = Normal(mu_rv, 1) 
y_vv = y_rv.clone()

factorized_joint_logprob({y_rv: y_vv})  # mu is not given a value

Does this correspond to a logprob graph of y where mu is marginalized out, or one where mu remains a random variable? Right now we would not have a way to distinguish between those types of graphs. One option would be to pass {rv: rv} for the RVs we want to keep and omit entries for the RVs we want to marginalize out (or even pass {rv: "marginalize"} to be extra explicit).
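One way to read these options is as a small classification step over the RVs in the graph. The plain-Python sketch below is hypothetical throughout (the MARGINALIZE sentinel, the RV stand-in class, and the missing policy are illustrations, not aeppl's API): {rv: vv} gives a value, {rv: rv} keeps the RV random, {rv: MARGINALIZE} integrates it out, and an RV with no entry falls back to whichever default is chosen.

```python
MARGINALIZE = "marginalize"  # Hypothetical sentinel value, as suggested above.


class RV:
    """Minimal stand-in for a random variable node."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


def classify(all_rvs, rv_values, missing="marginalize"):
    """Split RVs into (valued, kept, marginalized) under the proposed convention.

    RVs with no entry are the ambiguous case, resolved here by the
    hypothetical missing policy ("marginalize" or "keep").
    """
    valued, kept, marginalized = [], [], []
    for rv in all_rvs:
        value = rv_values.get(rv, MARGINALIZE if missing == "marginalize" else rv)
        if value is rv:
            kept.append(rv)
        elif isinstance(value, str) and value == MARGINALIZE:
            marginalized.append(rv)
        else:
            valued.append(rv)
    return valued, kept, marginalized


mu_rv, y_rv, y_vv = RV("mu_rv"), RV("y_rv"), RV("y_vv")
print(classify([mu_rv, y_rv], {y_rv: y_vv}))
print(classify([mu_rv, y_rv], {y_rv: y_vv}, missing="keep"))
```

The two calls show exactly the ambiguity in question: the same rv_values dict yields different graphs depending on the default for missing entries.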

Anyway, is this question why you wanted to keep the (construction) of the intermediate marginalization representation outside of factorized_joint_logprob?

tmbb commented 2 years ago

I'm resurrecting this issue because it's being proposed as something to work on during this year's Google Summer of Code.

@ricardoV94 have you given any more thought to this lately? Personally, I wonder if aeppl/aesara would be the appropriate abstraction layer at which to implement marginalization of discrete variables in PyMC models. As discussed above, implementing marginalization in aesara itself could have other benefits, but I'm thinking concretely of aesara as a compilation target for PyMC.

Maybe this would be easier to do "closer" to the surface, namely at the level of the model itself. I'm not very familiar with the compilation of PyMC models, and maybe there isn't anything between the PyMC classes and aesara, but I was just wondering. Maybe if one works at a high level, one may be able to reuse "off the shelf" symbolic manipulation libraries?

I admit my total ignorance about the PyMC internals (and aesara itself), so please excuse me if what I'm saying doesn't make much sense.

ricardoV94 commented 2 years ago

@tmbb I don't have the time to give you a more detailed response just yet. I have been playing with this recently. You can find one example in this gist: https://gist.github.com/ricardoV94/0cf8fd0f69a09d7eff0a5b41cb111965

Do note that while aeppl is used by PyMC, it is not a library that is meant to serve PyMC specifically.

rlouf commented 2 years ago

> Personally, I wonder if aeppl/aesara would be the appropriate abstraction layer at which to implement marginalization of discrete variables in PyMC models. As discussed above, implementing marginalization in aesara itself could have other benefits, but I'm thinking concretely of aesara as a compilation target for PyMC.

aesara is way more than just a compilation target for PyMC. In fact, PyMC can be seen as a wrapper around the aesara ecosystem which simplifies some operations. IMO, marginalization of discrete random variables, like anything that transforms the logprob, should happen in aeppl.

tmbb commented 2 years ago

@ricardoV94 I really like the fact that the default "algebraic" operations supported by aesara make it very easy to implement variable substitution, which by itself goes very far. Just adding a possible_values field to the discrete variables (which could be set by something higher-level such as PyMC) would go a long way toward making this usable without too much manual intervention. I'm actually pretty happy with the API of the marginalize function you've written. I just don't like the idea of having to specify the possible states in the function call.

ricardoV94 commented 2 years ago

> I just don't like the idea of having to specify the possible states in the function call.

We definitely don't need to, it can be inferred for the few discrete finite support distributions that we intend to support (+ truncated discrete distributions).

The only thing is that you need to do some extra work to infer the support of, e.g., a categorical variable. This may require some constant-folding operations to find the length of the p vector, or a way to do the marginalization with an Aesara scan instead of a Python loop.
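A minimal sketch of that support inference, in plain Python with hypothetical distribution tags and helper name. Here p is a concrete list, so len(p) is available directly; in an Aesara graph that length is symbolic, which is where constant folding or a scan would come in.

```python
def finite_support(dist, **params):
    """Enumerate the support of a finite discrete distribution (hypothetical helper)."""
    if dist == "bernoulli":
        return range(2)
    if dist == "categorical":
        # Support size comes from the length of the probability vector.
        return range(len(params["p"]))
    if dist == "truncated_discrete":
        lower, upper = params.get("lower"), params.get("upper")
        if lower is None or upper is None:
            raise ValueError("support is unbounded; it cannot be enumerated")
        return range(lower, upper + 1)
    raise NotImplementedError(f"no support rule for {dist!r}")


print(list(finite_support("categorical", p=[0.2, 0.3, 0.5])))
print(list(finite_support("truncated_discrete", lower=2, upper=5)))
```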

Does that make sense?

ricardoV94 commented 2 years ago

By the way @tmbb what GSOC project/ mentors are you working with?

tmbb commented 2 years ago

I'm NOT working on GSOC, and I definitely can't commit to working on it. I only mentioned it to explain why I was re-opening the issue.

On Tue, 28 Jun 2022, 20:10 Ricardo Vieira, @.***> wrote:

> By the way @tmbb what GSOC / mentors are you working with?


ricardoV94 commented 2 years ago

> I'm NOT working on GSOC, and I definitely can't commit to working on it. I only mentioned it to explain why I was re-opening the issue

I assumed so from how you introduced it.

What is your interest then? Just a feature request or something else?

tmbb commented 2 years ago

> What is your interest then? Just a feature request or something else?

I was thinking of working on it, but without the GSOC commitment.

ricardoV94 commented 2 years ago

> > What is your interest then? Just a feature request or something else?
>
> I was thinking of working on it, but without the GSOC commitment.

Sure, that would be very welcome

rlouf commented 2 years ago

I think marginalisation is actually out of scope for AePPL and a better fit for AeMCMC: I don't have an example in mind, but marginalising greedily might block some important rewrite paths. Anyway, marginalising out discrete RVs is a graph transformation that is independent from computing a logprob.

ricardoV94 commented 2 years ago

I don't agree at all. It makes as much sense as allowing aeppl to derive the probability of x = normal + normal or x = normal(normal) where you condition only on x. Just different forms of marginalization.

Why wouldn't that be the realm of a PPL?

rlouf commented 2 years ago

Well, it looks like we have different views of what AePPL is, and in that sense we're both right within our own view. I don't know what you mean by PPL, so the only thing I can do is explain what I believe AePPL is about: a library that takes an Aesara graph with random variables and returns the logprob or joint_logprob for this graph. That's already a lot.

AeMCMC's purpose is different (and now that I think about it, the name is not perfect). It is about encoding RV algebra, so to speak, and using this knowledge to transform the graph that you pass to it into a mathematically equivalent representation of that graph, a representation that you can use to build Gibbs samplers, for instance, though that's only one possible use case. AeMCMC already contains conjugacy relations and location-scale transforms, and a relation to marginalize discrete random variables fits right in. So do the transforms that are currently in AePPL, since they're not universally optimal; I have argued that before.

> When there are many different ways to apply marginalization, we need to consider a few other things first, like orderings for the distinct marginalization outcomes (e.g. in order to determine which is "better"), how to efficiently produce outcomes per an ordering, etc.; otherwise, we might as well choose one greedy approach and go with that. (N.B. these are things that are probably better addressed in a relational context like miniKanren, or at least a framework that can lazily evaluate large/infinite sets of outcomes in non-trivial ways.)

What @brandonwillard mentions in the above comment is typically the kind of question we are dealing with in AeMCMC.

ricardoV94 commented 2 years ago

Yes, and I see that very differently from: give me a logprob expression for a graph of RVs x and y, where I condition on both x and y, or on x alone (so y has to be mathematically marginalized out). I think that should be the realm of AePPL.

It has nothing to do with sampling performance or how you might decompose a model to sample more efficiently which indeed would fit in the realm of AeMCMC.

ricardoV94 commented 2 years ago

> a library that takes an Aesara graph with random variables and returns the logprob or joint_logprob for this graph. That's already a lot.

I think the missing part is that this is not a single one-to-one transformation: it depends on which variables you condition on. Indeed, you have to tell AePPL which {rv: vv} pairs you want the final joint logprob to be conditioned on. Marginalization in this sense is just a matter of giving a different set of {rv: vv} pairs, where some RVs are not present (or however you decide to encode variable marginalization), to obtain a different joint logprob graph from the same underlying RV graph.

rlouf commented 2 years ago

> Yes, and I see that very differently from: give me a logprob expression for a graph of RVs x and y, where I condition on both x and y, or on x alone (so y has to be mathematically marginalized out). I think that should be the realm of AePPL.

You kind of lose generality by marginalising automatically. I think we want to keep the possibility of logprobs with random variables open; I've seen that used in this paper, for instance.

> It has nothing to do with sampling performance or how you might decompose a model to sample more efficiently which indeed would fit in the realm of AeMCMC.

AeMCMC is currently headed towards something bigger than "tools for sampling"; the relations we are implementing are more fundamental than that.

To get out of that conundrum we should transfer the relational stuff to AePPL and only keep the sampler-building logic in AeMCMC. And regarding marginalising more specifically, we would implement the relations as independent functions, and call them from joint_logprob for the RVs the user specifies, just like we do for transforms.

ricardoV94 commented 2 years ago

> You kind of lose generality by marginalising automatically. I think we want to keep the possibility of logprobs with random variables open; I've seen that used in this paper, for instance.

Sure, that's the question being discussed in #85. It's just a matter of giving preference to pure RV nodes or marginalized nodes; one of them has to be the default unless you opt for the verbose option where you must specify rv_value pairs, pure_rvs, and marginalized_rvs.

ricardoV94 commented 2 years ago

> To get out of that conundrum we should transfer the relational stuff to AePPL and only keep the sampler-building logic in AeMCMC.

Definitely. Note that all the rewrites that have been implemented so far are distribution-agnostic, but we do want to include distribution-aware rewrites for the cases where closed-form solutions exist. If that's the kind of work that is already present in AeMCMC, I would invite you to bring it over here.

> And regarding marginalising more specifically, we would implement the relations as independent functions, and call them from joint_logprob for the RVs the user specifies, just like we do for transforms.

I agree.

ricardoV94 commented 2 years ago

> > unless you opt for the verbose option where you must specify rv_value pairs, pure_rvs, and marginalized_rvs
>
> What's wrong with that in a low-ish level framework? I understand why a wrapper like PyMC would be opinionated about this, but we don't have to.

I don't think there is any problem, and it's actually my preference. It's just that we haven't decided on it explicitly yet, and right now the default is allowing RVs to be in the graph (although we don't have very thorough tests, and there are some known issues: #174).

I'd just invite you to discuss that in #85, as I think it's distinct from this issue.

rlouf commented 2 years ago

> Sure, that's the question being discussed in #85. It's just a matter of giving preference to pure RV nodes or marginalized nodes; one of them has to be the default unless you opt for the verbose option where you must specify rv_value pairs, pure_rvs, and marginalized_rvs

I'm with Brandon on this one, why would that be wrong in a low-ish level framework where flexibility should take precedence over friendliness? I understand why a wrapper like PyMC would make opinionated decisions about this, but AePPL is not the place for that in my opinion.

I also see that implementing convolution relations is discussed in this issue, when that would currently make more sense in AeMCMC. I think we need to redefine the scope of both libraries, and again I think this would imply moving the relational stuff (the relations, not the whole automatic rewriting-for-samplers machinery) to AePPL. Then we can use all these rewriting capabilities to transform the logprob should we want to, and perhaps allow users to pass a set of rewrites to joint_logprob. We would keep the flexibility, and also allow the kind of behavior that you describe.

rlouf commented 2 years ago

> Definitely. Note that all the rewrites that have been implemented so far are distribution-agnostic, but we do want to include distribution-aware rewrites for the cases where closed-form solutions exist. If that's the kind of work that is already present in AeMCMC, I would invite you to bring it over here.

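The more we talk about it the more it makes sense to me. This defines a neat hierarchy and clear separation of concerns:

  • Aesara defines the graph IR that includes random variables;
  • AePPL contains knowledge about RVs (convolution, loc-scale transforms, closed-form posteriors, etc), maybe more RVs (including the support constraint logic) that we don't want in Aesara, and ways to compute total and joint logprob. This way wrappers can directly and explicitly apply transformations before applying the logprob transforms.
  • AeMCMC uses the knowledge contained in AePPL, and sampler implementations to build efficient samplers.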

If everyone is on board with that I can start moving stuff around.

ricardoV94 commented 2 years ago

Yeah, I feel that's a good separation. Of course AeMCMC and AePPL will share some of that relational reasoning for their different purposes, but since you need AePPL for AeMCMC and not the other way around, I would put it (the shared reasoning logic) here.

brandonwillard commented 2 years ago

> The more we talk about it the more it makes sense to me. This defines a neat hierarchy and clear separation of concerns:
>
>   • Aesara defines the graph IR that includes random variables;
>   • AePPL contains knowledge about RVs (convolution, loc-scale transforms, closed-form posteriors, etc), maybe more RVs (including the support constraint logic) that we don't want in Aesara, and ways to compute total and joint logprob. This way wrappers can directly and explicitly apply transformations before applying the logprob transforms.
>   • AeMCMC uses the knowledge contained in AePPL, and sampler implementations to build efficient samplers.

To add to/clarify some of that:

AePPL primarily serves to produce log-probabilities for models, and a domain-specific IR is needed in order to expand model coverage. Domain-specific relations (e.g. random variable relationships, operator lifting, etc.) are used as part of the canonicalization of this domain-specific IR, and have the effect of increasing coverage without increasing the "size" of the IR. This mostly explains why we have the rewrites that we do, and it helps tell us which kinds of rewrites aren't appropriate for AePPL.

Naturally, AeMCMC has its own domain-specific IR, and it primarily serves to produce samplers.

Marginalization seems harder to place because it can be used to produce a log-probability, but doesn't fit as an IR choice (or, more accurately, as part of AePPL's equational theory), and this is what makes it less suitable for AePPL. The reason it doesn't fit is that it requires the introduction of another operator (i.e. an expectation/integral/sum), and that other operator doesn't have much of a place in AePPL.

ricardoV94 commented 2 years ago

I don't follow: why doesn't it have a place in AePPL? Because it requires a new operator?

brandonwillard commented 2 years ago

> I don't follow: why doesn't it have a place in AePPL? Because it requires a new operator?

Yes, that's part of it, as well as the fact that the new operator doesn't really serve the goals of AePPL except indirectly.

ricardoV94 commented 2 years ago

So if I want a logp with a marginalized variable I have to go to aemcmc? Why, if that library is concerned with creating samplers?

I would definitely want that feature to be provided by aeppl since I am only interested in logps.

rlouf commented 2 years ago

In the current setting, if we add marginalisation to AeMCMC (which was my first inclination), AePPL will need to call AeMCMC for these rewrites, while AeMCMC already depends on AePPL for the logprob. That's not a good sign design-wise, but I can live with it if we have a very good reason to do so. On the other hand, adding to AePPL only the rewrites that serve its purpose is also not desirable. Hence the proposal to carve out the rewrite part of AeMCMC and only keep the sampler-building part. But maybe the solution is a different library that serves both AePPL's and AeMCMC's purposes?

I think this discussion will run in circles until we have a design document that states clearly the goal of each library and the trade-offs. @ricardoV94 and I have mentioned the problems with the current solution, but maybe we're missing the bigger picture.

brandonwillard commented 2 years ago

So far, I've only attempted to partially "formalize" the relevant considerations, because I believe that's what we should be doing. Part of that process involves establishing definitions or "axioms", and that can be fairly arbitrary, so keep that in mind.

For instance, AePPL's objectives don't need to be exclusively "produce log-probabilities". We can just as well say that AePPL's objectives include specific types of model manipulation Ops/IR, which could encompass marginalization.

Anyway, if we don't attempt to reason things in a similar way, then we'll definitely run in circles, and slowly evolve toward a confusing mess.

brandonwillard commented 2 years ago

Also, AePPL doesn't/wouldn't inherently need to call AeMCMC for anything.

rlouf commented 1 year ago

Going back to this, it would make sense to implement a marginalize operator that acts on AePPL's intermediate representation.

It would then make sense to refactor the internals of AePPL a little, and turn the user-facing joint_logprob into a high-level function that gives a friendly interface for operations on the intermediate representation. In pseudo-code:

def joint_logprob(*rvs, to_marginalize=()):
    # to_marginalize follows *rvs, so it is already keyword-only;
    # the aeppl.internals.* helpers are hypothetical.
    rvs_to_values = aeppl.internals.create_value_variables(rvs)
    ir = aeppl.internals.to_ir(rvs_to_values)
    marginalized_ir = aeppl.internals.marginalize(ir, to_marginalize)
    logdensity = aeppl.internals.disintegrate(marginalized_ir)
    return logdensity

brandonwillard commented 1 year ago

I should probably point out to anyone reading this that expectation operators like the one I described above have been fairly common in symbolic math libraries with support for probability theory, as well as in PPLs implemented well before AePPL and PyMC.

Likewise, long before this discussion, we already had implementations of "automatic integration/convolutions" in the same Aesara/Theano context via symbolic-pymc's and—more recently—AeMCMC's automatic conjugations, for example. Likewise, automatic integration/marginalization and, more specifically, Rao-Blackwellization has been an oft vocalized target of all this work since long before it started. Sadly, the framing of this feature presents the idea without any of that context and differentiation. This wrongly gives the impression that the idea originated here and specifically for use in AePPL, and I believe that severely hurts the interpretation of the idea itself and this discussion.

That said, I was initially trying to make this discussion more instructional via basic critiques of the ways in which we could represent and manipulate generalized marginalization in IR form. Unfortunately, I failed to effectively instigate that line of inquiry and the discussion has/had veered off into somewhat defensive and unconstructive API design areas. I apologize.

From here on, let's designate this issue as a discussion about potential IR for an expectation Op and only that. Issues/Discussions can be opened for other aspects of this functionality, though.