larryshamalama closed this 2 years ago.
I think the code in the notebook is short enough to include it directly in your comment (it's easier to follow when possible). Besides that do you have a specific question/problem or is this just work in progress?
Both! I'm writing them up as a code review; I'm just a bit slow.
Edit: Done! And @rlouf, thanks for the suggestion; I included it in the review below.
I spent the time since our last meeting thinking about this problem from a mathematical perspective (and about other PyMC-related things). Most of my progress is not yet in code.
I believe that graphs of nested `switch`es can be thought of as binary trees where each node has 0 or 2 children: a node with 0 children is a component (a leaf), and a node with 2 children is another `switch` statement. To my understanding, the problem becomes the following: given a binary tree with $k$ `switch`es, $k$ index variables $I_1, \dots, I_k$ and $k + 1$ leaves $A_1, \dots, A_{k+1}$, can we come up with a general function $h: \{0, 1\}^k \rightarrow \{0, \dots, k\}$ such that $h(I_1, \dots, I_k)$ yields the position of the correct leaf $A_j$ in the binary tree defined by the nested `switch`es?
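One possible general form of $h$ can be written recursively. This is a sketch under an assumed numbering convention: leaves in the subtree selected by $I = 0$ are numbered before those in the subtree selected by $I = 1$ (this appears to match the hand-derived formulas in the examples below). The notation $T$, $T_0$, $T_1$, $I_r$ and $n(\cdot)$ is introduced here and is not from the original: $T = \operatorname{switch}(I_r, T_1, T_0)$ mirrors `at.switch(cond, ift, iff)`, and $n(T_0)$ is the number of leaves in $T_0$.

$$
h_T(I_1, \dots, I_k) =
\begin{cases}
0, & T \text{ is a leaf}, \\
(1 - I_r)\, h_{T_0}(I_1, \dots, I_k) + I_r \bigl( n(T_0) + h_{T_1}(I_1, \dots, I_k) \bigr), & T = \operatorname{switch}(I_r, T_1, T_0).
\end{cases}
$$

For instance, with $T_0 = D$ and $T_1 = \operatorname{switch}(I_2, A, B)$ (the first example below), this gives $h = I_1 (1 + I_2) = I_1 + I_1 I_2$.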
Here are three basic examples that I've written by hand. All of these are graphs of depth 2, but it's a start. In all examples below, let `srng = at.random.RandomStream(seed=2320)` (with `import aesara.tensor as at`). Each leaf is associated with an integer that represents its position in the eventual concatenation of all components, i.e. of the leaves in the tree.
```python
# Inner switch: C = A if I2 else B
I2_rv = srng.bernoulli(0.5, name="I2")
A_rv = srng.normal(-5, 0.1, name="A")
B_rv = srng.normal(5, 0.1, name="B")
C_rv = at.switch(I2_rv, A_rv, B_rv)
C_rv.name = "C"

# Outer switch: Z = C if I1 else D
I1_rv = srng.bernoulli(0.5, name="I1")
D_rv = srng.normal(10, 0.1, name="D")
Z_rv = at.switch(I1_rv, C_rv, D_rv)
Z_rv.name = "Z"
```
*In the tree diagram for this example, the $I_2 = 1$ label is missing on the C/A edge and the $I_2 = 0$ label is missing on the C/B edge.
Here, $h(I_1, I_2) = I_1 + I_1 I_2$.
```python
# Inner switch: D = A if I2 else B
I2_rv = srng.bernoulli(0.5, name="I2")
A_rv = srng.normal(-5, 0.1, name="A")
B_rv = srng.normal(5, 0.1, name="B")
D_rv = at.switch(I2_rv, A_rv, B_rv)
D_rv.name = "D"

# Outer switch: Z = C if I1 else D (C is a leaf here)
I1_rv = srng.bernoulli(0.5, name="I1")
C_rv = srng.normal(10, 0.1, name="C")
Z_rv = at.switch(I1_rv, C_rv, D_rv)
Z_rv.name = "Z"
```
Here, $h(I_1, I_2) = 2I_1 + I_2(1 - I_1)$.
```python
# First inner switch: Z1 = A if I2 else B
I2_rv = srng.bernoulli(0.5, name="I2")
A_rv = srng.normal(-5, 0.1, name="A")
B_rv = srng.normal(5, 0.1, name="B")
Z1_rv = at.switch(I2_rv, A_rv, B_rv)
Z1_rv.name = "Z1"

# Second inner switch: Z2 = C if I3 else D
I3_rv = srng.bernoulli(0.5, name="I3")
C_rv = srng.normal(-10, 0.1, name="C")
D_rv = srng.normal(10, 0.1, name="D")
Z2_rv = at.switch(I3_rv, C_rv, D_rv)
Z2_rv.name = "Z2"

# Outer switch: Z = Z1 if I1 else Z2
I1_rv = srng.bernoulli(0.5, name="I1")
Z_rv = at.switch(I1_rv, Z1_rv, Z2_rv)
Z_rv.name = "Z"
```
Here, $h(I_1, I_2, I_3) = (2I_1 + I_2)I_1 + I_3(1 - I_1)$.
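The three hand-derived formulas can be checked by brute force. The sketch below is plain Python, not Aesara: the tuple tree encoding and the helpers `select`, `leaf_order` and `h` are hypothetical. It assumes leaves are numbered with the $I = 0$ (`if_false`) subtree first, which is the convention the formulas above appear to use, and it compares each formula, plus a generic recursive `h`, against the position of the leaf that the nested `switch`es actually select.

```python
from itertools import product

# Hypothetical tree encoding (not Aesara): a leaf is its name (str);
# a switch node is ("switch", index_name, if_true, if_false),
# mirroring at.switch(cond, ift, iff).


def select(tree, values):
    """Follow the switch semantics down to the selected leaf."""
    while not isinstance(tree, str):
        _, name, if_true, if_false = tree
        tree = if_true if values[name] else if_false
    return tree


def leaf_order(tree):
    """Number the leaves with the I = 0 (if_false) subtree first."""
    if isinstance(tree, str):
        return [tree]
    _, _, if_true, if_false = tree
    return leaf_order(if_false) + leaf_order(if_true)


def h(tree, values):
    """Generic position of the selected leaf, computed recursively."""
    if isinstance(tree, str):
        return 0
    _, name, if_true, if_false = tree
    i = values[name]
    return (1 - i) * h(if_false, values) + i * (
        len(leaf_order(if_false)) + h(if_true, values)
    )


examples = [
    # Example 1: Z = switch(I1, switch(I2, A, B), D)
    (("switch", "I1", ("switch", "I2", "A", "B"), "D"),
     lambda v: v["I1"] + v["I1"] * v["I2"]),
    # Example 2: Z = switch(I1, C, switch(I2, A, B))
    (("switch", "I1", "C", ("switch", "I2", "A", "B")),
     lambda v: 2 * v["I1"] + v["I2"] * (1 - v["I1"])),
    # Example 3: Z = switch(I1, switch(I2, A, B), switch(I3, C, D))
    (("switch", "I1", ("switch", "I2", "A", "B"), ("switch", "I3", "C", "D")),
     lambda v: (2 * v["I1"] + v["I2"]) * v["I1"] + v["I3"] * (1 - v["I1"])),
]

names = ["I1", "I2", "I3"]
for tree, hand_h in examples:
    order = leaf_order(tree)
    for bits in product([0, 1], repeat=len(names)):
        values = dict(zip(names, bits))
        expected = order.index(select(tree, values))
        # Hand-derived formula, generic recursion, and brute force agree
        assert hand_h(values) == expected == h(tree, values)
print("all hand-derived formulas agree")
```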
Happy to hear if this formulation of the problem is on the right track and any other general thoughts.
> I believe that graphs of nested `switch`es can be thought of as binary trees where each node has 0 or 2 children.
Yes, your high-level scalar binary arithmetic characterization looks like it's on track, but don't forget that `switch`es take arbitrary tensors as arguments, which means that you'll need to extend your characterization to, at least, vector spaces.
Regardless, the things that need to be addressed for an implementation often have more to do with the graph traversal and reuse of existing logic.
For instance, it might not be reasonable to walk a graph, attempt to identify nested `switch` subgraphs, and directly implement the binary arithmetic you've outlined here. The end result could easily involve a lot of custom logic and make this feature very difficult to maintain (e.g. see `Scan` in Aesara).
Ideally, a "local" rewrite of the relevant `switch`es would suffice, and, when those rewrites are applied repeatedly, they would cover all nested `switch`-containing graph cases as well.
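To illustrate why a purely local rewrite can be enough, here is a toy sketch in plain Python. The tuple IR and the names `local_switch_to_subtensor`, `rewrite_bottom_up`, `rewrite_to_fixpoint` and `evaluate` are hypothetical and are not Aesara's rewrite machinery. The rewrite only ever looks at a single `switch` node, yet applying it bottom-up until the graph stops changing removes every nested `switch` while preserving the selected value.

```python
from itertools import product

# Hypothetical toy IR (not Aesara): a leaf is a variable name (str);
# an op node is a tuple (op_name, *inputs).


def local_switch_to_subtensor(node):
    """Local rewrite of one node: switch(a, b, c) -> stack([b, c])[1 - a]."""
    if isinstance(node, tuple) and node[0] == "switch":
        _, a, b, c = node
        # switch returns b when a is true, i.e. row 0 of the stack
        return ("getitem", ("stack", b, c), ("one_minus", a))
    return node


def rewrite_bottom_up(node, local_rewrite):
    """Apply `local_rewrite` to every node, inputs before outputs."""
    if isinstance(node, tuple):
        node = (node[0],) + tuple(rewrite_bottom_up(x, local_rewrite) for x in node[1:])
    return local_rewrite(node)


def rewrite_to_fixpoint(node, local_rewrite):
    """Repeat bottom-up passes until the graph stops changing."""
    while True:
        new_node = rewrite_bottom_up(node, local_rewrite)
        if new_node == node:
            return new_node
        node = new_node


def contains_switch(node):
    return isinstance(node, tuple) and (
        node[0] == "switch" or any(contains_switch(x) for x in node[1:])
    )


def evaluate(node, env):
    """Evaluate the toy graph with scalar values taken from `env`."""
    if isinstance(node, str):
        return env[node]
    op, *inputs = node
    vals = [evaluate(x, env) for x in inputs]
    if op == "switch":
        return vals[1] if vals[0] else vals[2]
    if op == "stack":
        return vals
    if op == "getitem":
        return vals[0][vals[1]]
    if op == "one_minus":
        return 1 - vals[0]
    raise ValueError(f"unknown op {op!r}")


# Same shape as the first example above: Z = switch(I1, switch(I2, A, B), D)
nested = ("switch", "I1", ("switch", "I2", "A", "B"), "D")
rewritten = rewrite_to_fixpoint(nested, local_switch_to_subtensor)
assert not contains_switch(rewritten)

for i1, i2 in product([0, 1], repeat=2):
    env = {"I1": i1, "I2": i2, "A": -5.0, "B": 5.0, "D": 10.0}
    assert evaluate(nested, env) == evaluate(rewritten, env)
```

The design point is that the rewrite never needs to know it is inside a nested structure; nesting is handled by the traversal, not by the rewrite itself.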
This is why the indexing/`Subtensor`-based approach is so appealing. If we can convert `switch`es into equivalent `Subtensor` graphs (e.g. `switch(a, b, c)` converted to `stack([b, c])[a_new]`), then we can reuse all the existing `Subtensor`-based mixture code and address the nested `switch` cases via `stack`s (i.e. `Join` or `MakeVector` `Op`s) and `Subtensor` rewrites, some of which may already exist. This approach has the added benefit that any missing `stack`/`Subtensor` rewrites we implement for these purposes will then be available for more general use, which would also expand the capabilities of Aesara more broadly.
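A quick numerical check of the `switch`-to-indexing equivalence, using NumPy stand-ins (`np.where` for `at.switch`, `np.stack` for `at.stack`) rather than Aesara itself. Two assumptions are made here: since `switch` returns `b` where `a` is true, the new index is taken to be `a_new = 1 - a`; and since `a` can be a tensor rather than a scalar, the stacked array is indexed elementwise with an extra `np.arange` index.

```python
import numpy as np

rng = np.random.default_rng(2320)
n = 5

a = rng.integers(0, 2, size=n)        # elementwise 0/1 indicators
b = rng.normal(-5, 0.1, size=n)
c = rng.normal(5, 0.1, size=n)

# NumPy analogue of at.switch(a, b, c)
switched = np.where(a, b, c)

# NumPy analogue of stack([b, c])[a_new]: shape (2, n)
stacked = np.stack([b, c])
a_new = 1 - a                         # a true -> row 0 (b), a false -> row 1 (c)
indexed = stacked[a_new, np.arange(n)]

assert np.array_equal(switched, indexed)
```

For a scalar `a`, the elementwise `np.arange(n)` index is unnecessary and `stacked[1 - a]` alone selects the whole component.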
Actually, if I'm understanding this correctly, you could also be attempting to describe the logic for nested mixtures, which might be helpful for rewrites that target our `MixtureRV` IR more directly.
More specifically, once we identify `stack([b, c])[a_new]` as a mixture, it's turned into a single `MixtureRV` node, so nested `switch`es would ultimately look like `mixture(a_new, mixture(a_2, b_2, c_2), mixture(a_3, b_3, c_3))`, and those could be rewritten much more directly, but the question is "How?"
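One possible direction, sketched with NumPy under stated assumptions: flatten the nested mixture into a single mixture over all four leaves, with one combined index built from the inner indices. Here `mixture` is a hypothetical toy helper, not aeppl's `MixtureRV`, and it assumes `mixture(i, x_0, x_1, ...)` returns the `i`-th component elementwise, as `stack([...])[i]` does. The combined index has the same structure as the hand-derived $h$ for the third example.

```python
import numpy as np


def mixture(i, *components):
    """Hypothetical toy mixture: elementwise, pick components[i[j]][j]."""
    i = np.asarray(i)
    return np.stack(components)[i, np.arange(i.shape[0])]


rng = np.random.default_rng(0)
n = 6
a, a2, a3 = (rng.integers(0, 2, size=n) for _ in range(3))
b2, c2, b3, c3 = (rng.normal(loc, 0.1, size=n) for loc in (-10, -5, 5, 10))

# Nested form: mixture(a, mixture(a2, b2, c2), mixture(a3, b3, c3))
nested = mixture(a, mixture(a2, b2, c2), mixture(a3, b3, c3))

# Flattened form: one stack of all four leaves and one combined index.
# a = 0 selects the first inner mixture (positions 0 and 1),
# a = 1 selects the second inner mixture (positions 2 and 3).
flat_index = (1 - a) * a2 + a * (2 + a3)
flat = mixture(flat_index, b2, c2, b3, c3)

assert np.array_equal(nested, flat)
```

Whether such a flattening is best expressed as a rewrite on `MixtureRV` nodes or on the underlying `stack`/`Subtensor` graph is left open here.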
Base: 94.92% // Head: 94.94% // Increases project coverage by +0.01% :tada:
Coverage data is based on head (a724edb) compared to base (8b298d1). Patch coverage: 96.15% of modified lines in the pull request are covered.
Okay, I see... I was wondering. I just re-polished `test_switch_mixture` in `test_mixture.py`. Should this PR wait until #161 is merged to check whether all tests pass?
I've rebased and squashed this, so you'll need to pull your remote.
Looks like this needs tests for the test value and name parts.
Thanks, I just added one in `test_compute_test_value`.
Closes #77.
Currently, I am opening this PR to ask for pointers. So far, many lines are copy-pasted from `mixture_replace` in the same file. Hopefully this attempt is on the right track, although it is not working at the moment. The output of a mixture defined by `switch` is showcased in this gist.