aesara-devs / aeppl

Tools for an Aesara-based PPL.
https://aeppl.readthedocs.io
MIT License
64 stars 20 forks

Change use of "opt[imize]" to "rewrite" #161

Closed brandonwillard closed 2 years ago

brandonwillard commented 2 years ago

This PR updates our use of "opt[imize]" to "rewrite".

brandonwillard commented 2 years ago

Here's the chain of rewriting that leads to the error seen in CI:

First, the log-probability graph for the categorical:

Check{0 <= p <= 1} [id A] 12
 |Elemwise{switch,no_inplace} [id B] 11
 | |Elemwise{and_,no_inplace} [id C] 10
 | | |TensorConstant{(1, 1, 3) of True} [id D]
 | | |TensorConstant{(1, 1, 3) of True} [id E]
 | |Elemwise{log,no_inplace} [id F] 9
 | | |AdvancedSubtensor [id G] 8
 | |   |InplaceDimShuffle{2,0,1} [id H] 7
 | |   | |Elemwise{true_div,no_inplace} [id I] 2
 | |   |   |<TensorType(float64, (1, 1, 2))> [id J]
 | |   |   |InplaceDimShuffle{0,1,x} [id K] 1
 | |   |     |Sum{axis=[2], acc_dtype=float64} [id L] 0
 | |   |       |<TensorType(float64, (1, 1, 2))> [id J]
 | |   |TensorConstant{[[[0 1 1]]]} [id M]
 | |   |TensorConstant{(1, 1, 1) of 0} [id N]
 | |   |TensorConstant{(1, 1, 1) of 0} [id O]
 | |TensorConstant{(1, 1, 1) of -inf} [id P]
 |All [id Q] 6
 | |Elemwise{ge,no_inplace} [id R] 5
 |   |Elemwise{true_div,no_inplace} [id I] 2
 |   |TensorConstant{(1, 1, 1) of 0.0} [id S]
 |All [id T] 4
   |Elemwise{le,no_inplace} [id U] 3
     |Elemwise{true_div,no_inplace} [id I] 2
     |TensorConstant{(1, 1, 1) of 1.0} [id V]

After a few rewrites, we get the following:

Check{0 <= p <= 1} [id A] 41
 |Alloc [id B] 40
 | |Elemwise{log,no_inplace} [id C] 39
 | | |AdvancedSubtensor [id D] 38
 | |   |InplaceDimShuffle{2,0,1} [id E] 37
 | |   | |Elemwise{true_div,no_inplace} [id F] 2
 | |   |   |<TensorType(float64, (1, 1, 2))> [id G]
 | |   |   |InplaceDimShuffle{0,1,x} [id H] 1
 | |   |     |Sum{axis=[2], acc_dtype=float64} [id I] 0
 | |   |       |<TensorType(float64, (1, 1, 2))> [id G]
 | |   |TensorConstant{[[[0 1 1]]]} [id J]
 | |   |TensorConstant{(1, 1, 1) of 0} [id K]
 | |   |TensorConstant{(1, 1, 1) of 0} [id K]
 | |TensorFromScalar [id L] 36
 | | |ScalarConstant{1} [id M]
 | |Subtensor{int64} [id N] 35
 | | |MakeVector{dtype='int64'} [id O] 13
 | | | |TensorConstant{1} [id P]
 | | | |Elemwise{int_div,no_inplace} [id Q] 7
 | | | | |TensorConstant{1} [id P]
 | | | | |TensorConstant{1} [id P]
 | | | |TensorFromScalar [id R] 12
 | | |   |Assert{msg=Could not broadcast dimensions} [id S] 11
 | | |     |ScalarConstant{3} [id T]
 | | |     |TensorFromScalar [id U] 10
 | | |       |EQ [id V] 9
 | | |         |ScalarFromTensor [id W] 8
 | | |         | |Elemwise{int_div,no_inplace} [id Q] 7
 | | |         |ScalarConstant{3} [id T]
 | | |ScalarConstant{1} [id M]
 | |TensorFromScalar [id X] 34
 |   |Assert{msg=Could not broadcast dimensions} [id Y] 33
 |     |Abs [id Z] 24
 |     | |maximum [id BA] 23
 |     |   |Switch [id BB] 22
 |     |   | |EQ [id BC] 21
 |     |   | | |ScalarFromTensor [id BD] 20
 |     |   | | | |Subtensor{int64} [id BE] 19
 |     |   | | |   |TensorConstant{[1 1 3]} [id BF]
 |     |   | | |   |ScalarConstant{2} [id BG]
 |     |   | | |ScalarConstant{1} [id BH]
 |     |   | |neg [id BI] 16
 |     |   | | |ScalarConstant{1} [id M]
 |     |   | |ScalarFromTensor [id BD] 20
 |     |   |Switch [id BJ] 18
 |     |     |EQ [id BK] 17
 |     |     | |ScalarFromTensor [id BL] 15
 |     |     | | |Subtensor{int64} [id BM] 14
 |     |     | |   |MakeVector{dtype='int64'} [id O] 13
 |     |     | |   |ScalarConstant{2} [id BG]
 |     |     | |ScalarConstant{1} [id BH]
 |     |     |neg [id BI] 16
 |     |     |ScalarFromTensor [id BL] 15
 |     |TensorFromScalar [id BN] 32
 |       |AND [id BO] 31
 |         |OR [id BP] 30
 |         | |EQ [id BQ] 29
 |         | | |Switch [id BB] 22
 |         | | |neg [id BI] 16
 |         | |EQ [id BR] 28
 |         |   |Switch [id BB] 22
 |         |   |Abs [id Z] 24
 |         |OR [id BS] 27
 |           |EQ [id BT] 26
 |           | |Switch [id BJ] 18
 |           | |neg [id BI] 16
 |           |EQ [id BU] 25
 |             |Switch [id BJ] 18
 |             |Abs [id Z] 24
 |All [id BV] 6
 | |Elemwise{ge,no_inplace} [id BW] 5
 |   |Elemwise{true_div,no_inplace} [id F] 2
 |   |TensorConstant{(1, 1, 1) of 0.0} [id BX]
 |All [id BY] 4
   |Elemwise{le,no_inplace} [id BZ] 3
     |Elemwise{true_div,no_inplace} [id F] 2
     |TensorConstant{(1, 1, 1) of 1.0} [id CA]

The Switch replacement is due to local_useless_switch, which replaces the Switch with an Alloc that apparently has bad shape inputs.

The shapes of the switch's inputs are as follows:

[(TensorConstant{1}, TensorConstant{1}, TensorConstant{3}),
 (TensorConstant{1}, Elemwise{int_div,no_inplace}.0, TensorFromScalar.0),
 (TensorConstant{1}, TensorConstant{1}, TensorConstant{1})]
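For reference, the expected output shape of the Switch follows the same broadcasting rule as any elementwise op. A minimal NumPy sketch (treating the symbolic entries of the second shape as the 1s they evaluate to here):

```python
import numpy as np

# Broadcasting the three Switch input shapes together, as an elementwise
# op would, should yield (1, 1, 3):
out_shape = np.broadcast_shapes((1, 1, 3), (1, 1, 1), (1, 1, 1))
print(out_shape)  # (1, 1, 3)
```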

The problem is the shape value of the second input's last dimension (i.e. the TensorFromScalar). Here are that input's shape graphs:

TensorConstant{1} [id A]
Elemwise{int_div,no_inplace} [id B]
 |Elemwise{mul,no_inplace} [id C]
 | |TensorConstant{1} [id D]
 |Elemwise{mul,no_inplace} [id E]
   |TensorConstant{1} [id D]
   |TensorConstant{1} [id D]
TensorFromScalar [id F]
 |Assert{msg=Could not broadcast dimensions} [id G]
   |ScalarConstant{3} [id H]
   |TensorConstant{False} [id I]

In other words, its shape is (1, 1, ?), where the last dimension's shape value should be the 3 in the Assert; however, the Assert's condition is already False, so evaluating that shape would fail.

That input has the following graph:

Elemwise{log,no_inplace} [id A] <TensorType(float64, (1, None, None))>
 |AdvancedSubtensor [id B] <TensorType(float64, (1, None, None))>
   |InplaceDimShuffle{2,0,1} [id C] <TensorType(float64, (None, 1, 1))>
   | |Elemwise{true_div,no_inplace} [id D] <TensorType(float64, (1, 1, None))>
   |   |<TensorType(float64, (1, 1, 2))> [id E] <TensorType(float64, (1, 1, 2))>
   |   |InplaceDimShuffle{0,1,x} [id F] <TensorType(float64, (1, 1, 1))>
   |     |Sum{axis=[2], acc_dtype=float64} [id G] <TensorType(float64, (1, 1))>
   |       |<TensorType(float64, (1, 1, 2))> [id E] <TensorType(float64, (1, 1, 2))>
   |TensorConstant{[[[0 1 1]]]} [id H] <TensorType(int64, (1, 1, 3))>
   |TensorConstant{(1, 1, 1) of 0} [id I] <TensorType(int64, (1, 1, 1))>
   |TensorConstant{(1, 1, 1) of 0} [id I] <TensorType(int64, (1, 1, 1))>

Notice that the true_div node with ID D fails to infer a static output shape of (1, 1, 2); we need to fix that. Regardless, the node at ID C should have a shape of (2, 1, 1), and, since the outer AdvancedSubtensor performs x[i, np.zeros((1, 1, 1)), np.zeros((1, 1, 1))], where x.shape is (2, 1, 1) and i.shape is (1, 1, 3), the output shape should be (1, 1, 3). There appears to be a miscalculation in AdvancedSubtensor.infer_shape.
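As a sanity check, NumPy agrees with that expected output shape when the same indexing pattern is replayed with concrete arrays:

```python
import numpy as np

# x.shape == (2, 1, 1), i.shape == (1, 1, 3), plus two all-zero index arrays.
x = np.arange(2.0).reshape(2, 1, 1)
i = np.array([[[0, 1, 1]]])         # shape (1, 1, 3)
z = np.zeros((1, 1, 1), dtype=int)  # shape (1, 1, 1)

# All three indices are integer arrays, so they broadcast together:
# (1, 1, 3) with (1, 1, 1) and (1, 1, 1) gives an output shape of (1, 1, 3).
print(x[i, z, z].shape)  # (1, 1, 3)
```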

In AdvancedSubtensor.infer_shape, the output shape computed is as follows:

ScalarConstant{1} [id A]
Elemwise{int_div,no_inplace} [id B]
 |Elemwise{mul,no_inplace} [id C]
 | |TensorConstant{1} [id D]
 |Elemwise{mul,no_inplace} [id E]
   |TensorConstant{1} [id F]
   |TensorConstant{1} [id G]
Assert{msg=Could not broadcast dimensions} [id H]
 |ScalarConstant{3} [id I]
 |TensorFromScalar [id J]
   |EQ [id K]
     |ScalarFromTensor [id L]
     | |Elemwise{int_div,no_inplace} [id M]
     |   |Elemwise{mul,no_inplace} [id N]
     |   | |TensorConstant{1} [id O]
     |   |Elemwise{mul,no_inplace} [id P]
     |     |TensorConstant{1} [id Q]
     |     |TensorConstant{1} [id R]
     |ScalarConstant{3} [id I]

and the input shapes are

[(TensorConstant{2}, TensorConstant{1}, TensorConstant{1}),
 (TensorConstant{1}, TensorConstant{1}, TensorConstant{3}),
 (TensorConstant{1}, Elemwise{int_div,no_inplace}.0, TensorConstant{1}),
 (TensorConstant{1}, TensorConstant{1}, Elemwise{int_div,no_inplace}.0)]

The last input's shape graphs are as follows:

TensorConstant{1} [id A]
TensorConstant{1} [id B]
Elemwise{int_div,no_inplace} [id C]
 |Elemwise{mul,no_inplace} [id D]
 | |TensorConstant{1} [id E]
 |Elemwise{mul,no_inplace} [id F]
   |TensorConstant{1} [id G]
   |TensorConstant{1} [id H]

In other words, they're all ones, so what's the problem with that last dimension?
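The EQ-based check in the graphs above compares a dimension's size directly against 3, which is too strict under broadcasting semantics: a size-1 dimension broadcasts against any size, even though it is not equal to it. A small NumPy illustration:

```python
import numpy as np

dim_a, dim_b = 1, 3
# An equality test between the two sizes fails...
print(dim_a == dim_b)  # False
# ...even though broadcasting the two dimensions is perfectly valid:
print(np.broadcast_shapes((dim_a,), (dim_b,)))  # (3,)
```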

It looks like the problem is in indexed_result_shape (i.e. the function used by AdvancedSubtensor.infer_shape to derive the output shape).

brandonwillard commented 2 years ago

OK, the issue is actually in broadcast_shape.
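For context, this is the rule such a helper needs to implement. Below is a minimal Python sketch of NumPy-style shape broadcasting for fully static shapes, purely illustrative and not Aesara's actual broadcast_shape (which must also handle symbolic dimensions):

```python
def broadcast_shape(*shapes):
    """Broadcast integer shape tuples per NumPy's rules (illustrative only)."""
    ndim = max(len(s) for s in shapes)
    # Left-pad each shape with 1s so they all have the same rank.
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    out = []
    for dims in zip(*padded):
        # Size-1 dimensions defer to any other size; distinct non-1 sizes clash.
        non_one = {d for d in dims if d != 1}
        if len(non_one) > 1:
            raise ValueError(f"Could not broadcast dimensions: {dims}")
        out.append(non_one.pop() if non_one else 1)
    return tuple(out)

print(broadcast_shape((2, 1, 1), (1, 1, 3), (1, 1, 1)))  # (2, 1, 3)
```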

brandonwillard commented 2 years ago

https://github.com/aesara-devs/aesara/pull/1138 should fix the issue.

codecov[bot] commented 2 years ago

Codecov Report

Base: 94.92% // Head: 94.92% // No change to project coverage :thumbsup:

Coverage data is based on head (9e681a4) compared to base (74a937d). Patch coverage: 100.00% of modified lines in pull request are covered.

Additional details and impacted files

```diff
@@           Coverage Diff           @@
##             main     #161   +/-  ##
=======================================
  Coverage   94.92%   94.92%
=======================================
  Files          12       12
  Lines        1852     1852
  Branches      275      275
=======================================
  Hits         1758     1758
  Misses         53       53
  Partials       41       41
```

| Impacted Files | Coverage Δ |
|---|---|
| aeppl/cumsum.py | `100.00% <100.00%> (ø)` |
| aeppl/joint_logprob.py | `96.92% <100.00%> (ø)` |
| aeppl/logprob.py | `98.02% <100.00%> (ø)` |
| aeppl/mixture.py | `98.00% <100.00%> (ø)` |
| aeppl/printing.py | `89.68% <100.00%> (ø)` |
| aeppl/rewriting.py | `94.00% <100.00%> (ø)` |
| aeppl/scan.py | `94.73% <100.00%> (ø)` |
| aeppl/tensor.py | `85.71% <100.00%> (ø)` |
| aeppl/truncation.py | `98.27% <100.00%> (ø)` |

:umbrella: View full report at Codecov.