aesara-devs / aeppl

Tools for an Aesara-based PPL.
https://aeppl.readthedocs.io
MIT License

Introduce graph rewrite for mixture sub-graphs defined via `Switch` #154

Closed: larryshamalama closed this 2 years ago

larryshamalama commented 2 years ago

Closes #77.

I am opening this PR to ask for pointers. So far, many lines are copy-pasted from `mixture_replace` in the same file. Hopefully this attempt is on the right track, although it is not working at the moment. The output of a mixture defined by `switch` is showcased in this gist.

rlouf commented 2 years ago

I think the code in the notebook is short enough to include directly in your comment (it's easier to follow when possible). Besides that, do you have a specific question or problem, or is this just a work in progress?

larryshamalama commented 2 years ago

> I think the code in the notebook is short enough to include it directly in your comment (it's easier to follow when possible). Besides that do you have a specific question/problem or is this just work in progress?

Both! I'm writing them up as a code review; I'm just a bit slow.

Edit: Done! And @rlouf, thanks for the suggestion; I included the code in the review below.

larryshamalama commented 2 years ago

Since our last meeting, I have spent time thinking about this problem from a mathematical perspective (among other PyMC-related things). Most of my progress is not yet in code.

Problem

I believe that graphs of nested switches can be thought of as binary trees in which each node has either 0 or 2 children: a node with 0 children is a component (a leaf), and a node with 2 children is another switch statement. To my understanding, the problem then becomes the following: given a binary tree with $k$ switches, $k$ index variables $I_1, \dots, I_k$ and $k + 1$ leaves $A_1, \dots, A_{k+1}$, can we come up with a general function $h: \{0, 1\}^k \rightarrow \{0, \dots, k\}$ such that $h(I_1, \dots, I_k)$ yields the index of the correct leaf $A_j$, as determined by the binary tree of nested switches?
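The closed forms in the three examples below all seem to fit a single recursion. This is a sketch inferred from those examples rather than something stated in the thread, and it assumes (as the examples appear to) that the leaves under a switch's false branch are numbered before those under its true branch. For a switch node $S = \text{switch}(I, T, F)$, where $n_F$ is the number of leaves under $F$:

$$
h(\text{leaf}) = 0, \qquad h(S) = (1 - I)\, h(F) + I \left( n_F + h(T) \right).
$$

For Example 1, $T = \text{switch}(I_2, A, B)$ with $h(T) = I_2$, $F = D$ and $n_F = 1$, so $h = I_1 (1 + I_2) = I_1 + I_1 I_2$.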

Examples

Here are three basic examples that I've written by hand. All of them are graphs of depth 2, but it's a start. Let `srng = at.random.RandomStream(seed=2320)` in all examples below. Each leaf is associated with an integer representing its position in the eventual concatenation of all components, i.e., the leaves of the tree.

Example 1

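```python
import aesara.tensor as at

srng = at.random.RandomStream(seed=2320)

I2_rv = srng.bernoulli(0.5, name="I2")
A_rv = srng.normal(-5, 0.1, name="A")
B_rv = srng.normal(5, 0.1, name="B")

# Inner switch: C = A if I2 == 1 else B
C_rv = at.switch(I2_rv, A_rv, B_rv)
C_rv.name = "C"

I1_rv = srng.bernoulli(0.5, name="I1")
D_rv = srng.normal(10, 0.1, name="D")

# Outer switch: Z = C if I1 == 1 else D
Z_rv = at.switch(I1_rv, C_rv, D_rv)
Z_rv.name = "Z"
```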

*In the figure below, the $I_2 = 1$ label is missing from the C/A edge and the $I_2 = 0$ label is missing from the C/B edge.

[Figure: binary tree of nested switches for Example 1]

Here, $h(I_1, I_2) = I_1 + I_1 I_2$.
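A quick plain-Python check (no Aesara involved) that this $h$ selects the same leaf as the nested switches for every index combination. The leaf order D, B, A is read off from the formula itself, not stated in the thread, and the helper names are illustrative only:

```python
from itertools import product

# Leaf order implied by h for Example 1 (an inference from the formula): D=0, B=1, A=2
leaf_order = ["D", "B", "A"]


def nested_switch(i1, i2):
    """Evaluate Z = switch(I1, C, D) with C = switch(I2, A, B).

    A switch picks its second argument when the condition is 1, its third when 0.
    """
    c = "A" if i2 else "B"
    return c if i1 else "D"


def h(i1, i2):
    return i1 + i1 * i2


for i1, i2 in product([0, 1], repeat=2):
    # The flat index must land on the same leaf the nested switches pick
    assert nested_switch(i1, i2) == leaf_order[h(i1, i2)]
    print(i1, i2, "->", h(i1, i2), leaf_order[h(i1, i2)])
```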

Example 2

```python
I2_rv = srng.bernoulli(0.5, name="I2")
A_rv = srng.normal(-5, 0.1, name="A")
B_rv = srng.normal(5, 0.1, name="B")

# Inner switch: D = A if I2 == 1 else B
D_rv = at.switch(I2_rv, A_rv, B_rv)
D_rv.name = "D"

I1_rv = srng.bernoulli(0.5, name="I1")
# Leaf component C, selected directly when I1 == 1
C_rv = srng.normal(10, 0.1, name="C")

# Outer switch: Z = C if I1 == 1 else D
Z_rv = at.switch(I1_rv, C_rv, D_rv)
Z_rv.name = "Z"
```

[Figure: binary tree of nested switches for Example 2]

Here, $h(I_1, I_2) = 2I_1 + I_2(1 - I_1)$.

Example 3

```python
I2_rv = srng.bernoulli(0.5, name="I2")
A_rv = srng.normal(-5, 0.1, name="A")
B_rv = srng.normal(5, 0.1, name="B")

Z1_rv = at.switch(I2_rv, A_rv, B_rv)
Z1_rv.name = "Z1"

I3_rv = srng.bernoulli(0.5, name="I3")
C_rv = srng.normal(-10, 0.1, name="C")
D_rv = srng.normal(10, 0.1, name="D")

Z2_rv = at.switch(I3_rv, C_rv, D_rv)
Z2_rv.name = "Z2"

I1_rv = srng.bernoulli(0.5, name="I1")
Z_rv = at.switch(I1_rv, Z1_rv, Z2_rv)
Z_rv.name = "Z"
```

[Figure: binary tree of nested switches for Example 3]

Here, $h(I_1, I_2, I_3) = (2I_1 + I_2)I_1 + I_3(1 - I_1)$.
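A plain-Python sketch (no Aesara) of one recursion that the three closed forms seem to share, checked against all of them. The tuple encoding of trees, the helper names, and the false-branch-first leaf ordering are assumptions inferred from the examples, not part of the PR:

```python
from itertools import product

# A tree is a leaf name (str) or a tuple ("switch", index_name, true_branch, false_branch).


def leaves(tree):
    """Leaves in concatenation order: false-branch leaves first, then true-branch leaves."""
    if isinstance(tree, str):
        return [tree]
    _, _, true_b, false_b = tree
    return leaves(false_b) + leaves(true_b)


def h(tree, idx):
    """Flat position of the selected leaf, given a dict mapping index names to 0/1."""
    if isinstance(tree, str):
        return 0
    _, name, true_b, false_b = tree
    i = idx[name]
    return (1 - i) * h(false_b, idx) + i * (len(leaves(false_b)) + h(true_b, idx))


def select(tree, idx):
    """Directly evaluate the nested switches (1 picks the true branch)."""
    if isinstance(tree, str):
        return tree
    _, name, true_b, false_b = tree
    return select(true_b if idx[name] else false_b, idx)


ex1 = ("switch", "I1", ("switch", "I2", "A", "B"), "D")
ex2 = ("switch", "I1", "C", ("switch", "I2", "A", "B"))
ex3 = ("switch", "I1", ("switch", "I2", "A", "B"), ("switch", "I3", "C", "D"))

examples = [
    (ex1, lambda I: I["I1"] + I["I1"] * I["I2"]),
    (ex2, lambda I: 2 * I["I1"] + I["I2"] * (1 - I["I1"])),
    (ex3, lambda I: (2 * I["I1"] + I["I2"]) * I["I1"] + I["I3"] * (1 - I["I1"])),
]

for tree, closed_form in examples:
    order = leaves(tree)
    # Unused index names (e.g. I3 in Examples 1 and 2) are simply ignored
    for bits in product([0, 1], repeat=3):
        I = dict(zip(["I1", "I2", "I3"], bits))
        assert h(tree, I) == closed_form(I)
        assert order[h(tree, I)] == select(tree, I)
```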

Comments

I'd be happy to hear whether this formulation of the problem is on the right track, along with any other general thoughts.

brandonwillard commented 2 years ago

> I believe that graphs of nested switches can be thought as binary trees where each node has 0 or 2 children.

Yes, your high-level scalar binary arithmetic characterization looks like it's on track, but don't forget that switches take arbitrary tensors as arguments, which means that you'll need to extend your characterization to, at least, vector spaces.
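The point about tensor arguments can be probed with a NumPy stand-in (`np.where` in place of `at.switch`, `np.choose` for indexing into the stacked leaves; this is not Aesara code). Assuming elementwise Bernoulli indicators, the scalar $h$ from Example 1 appears to carry over unchanged when applied elementwise:

```python
import numpy as np

rng = np.random.default_rng(2320)
n = 6

# Elementwise indicators and leaf values for Example 1, as length-n vectors
I1 = rng.integers(0, 2, size=n)
I2 = rng.integers(0, 2, size=n)
A = rng.normal(-5, 0.1, size=n)
B = rng.normal(5, 0.1, size=n)
D = rng.normal(10, 0.1, size=n)

# Nested elementwise switches: Z = switch(I1, switch(I2, A, B), D)
Z_switch = np.where(I1, np.where(I2, A, B), D)

# Same values via one elementwise index into the leaves stacked as [D, B, A]
h = I1 + I1 * I2
Z_indexed = np.choose(h, [D, B, A])

assert np.array_equal(Z_switch, Z_indexed)
```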

Regardless, the things that need to be addressed for an implementation often have more to do with the graph traversal and reuse of existing logic.

For instance, it might not be reasonable to walk a graph and attempt to identify nested switch subgraphs and directly implement the binary arithmetic you've outlined here. The end result could easily involve a lot of custom logic and make this feature very difficult to maintain (e.g. see Scan in Aesara).

Ideally, a "local" rewrite of the relevant switches would suffice, and, when those rewrites are applied repeatedly, they would cover all nested switch-containing graph cases as well.

This is why the indexing/`Subtensor`-based approach is so appealing. If we can convert switches into equivalent `Subtensor` graphs (e.g., `switch(a, b, c)` converted to `stack([b, c])[a_new]`), then we can reuse all the existing `Subtensor`-based mixture code and address the nested switch cases via stacks (i.e., `Join` or `MakeVector` `Op`s) and `Subtensor` rewrites, some of which may already exist. This approach has the added benefit that any missing stack/`Subtensor` rewrites we implement for these purposes will then be available for more general use, which would also expand the capabilities of Aesara more broadly.
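A NumPy sketch of the `switch(a, b, c)` to `stack([b, c])[a_new]` equivalence, assuming `a_new = 1 - a` so that a true condition picks position 0 (`b`) of the stack. `switch_as_subtensor` is a hypothetical helper, and `np.where`, `np.stack` and `np.take_along_axis` only stand in for Aesara's `switch`, stacking and indexing; this illustrates the equivalence, not an actual AePPL rewrite:

```python
import numpy as np


def switch_as_subtensor(a, b, c):
    """Hypothetical helper: compute switch(a, b, c) as stack([b, c])[a_new], elementwise.

    Assumes a_new = 1 - a, so a true condition selects position 0 (b) of the stack.
    """
    a, b, c = np.broadcast_arrays(np.asarray(a, dtype=bool), b, c)
    stacked = np.stack([b, c])  # shape (2, *a.shape)
    a_new = 1 - a.astype(np.int64)  # True -> 0 (b), False -> 1 (c)
    # Pick stacked[a_new[idx], idx] for every element position idx
    return np.take_along_axis(stacked, a_new[np.newaxis], axis=0)[0]


rng = np.random.default_rng(0)
a, a2 = rng.integers(0, 2, size=(2, 5))
b, c, d = rng.normal(size=(3, 5))

# A single switch
assert np.array_equal(switch_as_subtensor(a, b, c), np.where(a, b, c))

# Nested switches: rewrite the inner switch first, then the outer one
inner = switch_as_subtensor(a2, b, c)
assert np.array_equal(switch_as_subtensor(a, inner, d), np.where(a, np.where(a2, b, c), d))
```

Rewriting the inner switch first and then the outer one mirrors how a local rewrite, applied repeatedly, could reach nested switches.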

brandonwillard commented 2 years ago

Actually, if I'm understanding this correctly, you could also be attempting to describe the logic for nested mixtures, which might be helpful for rewrites that target our MixtureRV IR more directly.

More specifically, once we identify stack([b, c])[a_new] as a mixture, it's turned into a single MixtureRV node, so nested switches would ultimately look like mixture(a_new, mixture(a_2, b_2, c_2), mixture(a_3, b_3, c_3)), and those could be rewritten much more directly, but the question is "How?"
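A sketch of how such a nested mixture might collapse into one, assuming a hypothetical scalar `mixture(i, *components)` that returns `components[i]` (this is not AePPL's `MixtureRV` API). Unlike `switch`, index 0 picks the first component here. Under that assumption, the nested mixture selects the same component as a single mixture over the concatenated components `[b_2, c_2, b_3, c_3]` with a flattened index:

```python
from itertools import product


def mixture(i, *components):
    """Hypothetical scalar stand-in for a mixture: return component `i` (0-based)."""
    return components[i]


def flat_index(a_new, a_2, a_3):
    """Index into [b_2, c_2, b_3, c_3] for
    mixture(a_new, mixture(a_2, b_2, c_2), mixture(a_3, b_3, c_3))."""
    # a_new = 0 -> first inner mixture (offset 0); a_new = 1 -> second (offset 2)
    return 2 * a_new + (1 - a_new) * a_2 + a_new * a_3


components = ["b_2", "c_2", "b_3", "c_3"]
for a_new, a_2, a_3 in product([0, 1], repeat=3):
    nested = mixture(a_new, mixture(a_2, "b_2", "c_2"), mixture(a_3, "b_3", "c_3"))
    assert nested == mixture(flat_index(a_new, a_2, a_3), *components)
```

This only covers two binary inner mixtures; with more or differently sized inner mixtures, the offsets would be the cumulative component counts.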

codecov[bot] commented 2 years ago

Codecov Report

Base: 94.92% // Head: 94.94% // Increases project coverage by +0.01% :tada:

Coverage data is based on head (a724edb) compared to base (8b298d1). Patch coverage: 96.15% of modified lines in pull request are covered.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #154      +/-   ##
==========================================
+ Coverage   94.92%   94.94%   +0.01%
==========================================
  Files          12       12
  Lines        1852     1878      +26
  Branches      275      280       +5
==========================================
+ Hits         1758     1783      +25
  Misses         53       53
- Partials       41       42       +1
```

| Impacted Files | Coverage Δ | |
|---|---|---|
| aeppl/mixture.py | `97.72% <96.15%> (-0.28%)` | :arrow_down: |


larryshamalama commented 2 years ago

Okay, I see... I was wondering. I just re-polished `test_switch_mixture` in `test_mixture.py`. Should this PR wait until #161 is merged, so we can check whether all tests pass?

brandonwillard commented 2 years ago

I've rebased and squashed this, so you'll need to pull your remote.

larryshamalama commented 2 years ago

> Looks like this needs tests for the test value and name parts.

Thanks, I just added one in `test_compute_test_value`.