aesara-devs / aeppl

Tools for an Aesara-based PPL.
https://aeppl.readthedocs.io
MIT License

Random variables present in logp graph from `Join` dispatch #149

Closed: larryshamalama closed this issue 2 years ago

larryshamalama commented 2 years ago

With the logprob dispatch for the `CumOp`, there is ongoing work to redefine the Gaussian Random Walk distribution in PyMC as a cumulative sum of random variables and have AePPL automatically retrieve its logp graph (see PR 5814). A restriction in PyMC is that random variables must not be present in the logp graph, as per the lines here. However, the `Join` `Op` dispatch looks into tensor shapes and, as a consequence, does not replace the random variables in the logp graph with their value variable counterparts.
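For context, here is a minimal sketch of the kind of graph involved; the construction and names (`init_rv`, `innov_rv`) are purely illustrative and not the actual PyMC implementation:

import aesara.tensor as at

# An illustrative GRW-style construction: an initial value joined with iid
# innovations along the time axis, then cumulatively summed. Deriving a logp
# for grw_rv goes through both the Join and the CumOp dispatches.
init_rv = at.random.normal(0.0, 1.0, size=(1,))
innov_rv = at.random.normal(0.0, 1.0, size=(10,))
grw_rv = at.cumsum(at.join(0, init_rv, innov_rv))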

One way to address this issue would be to introduce a constant-folding step before passing the base variables' shapes as the `splits_size` argument:

https://github.com/aesara-devs/aeppl/blob/cc78f30b5ed89e5b247b88d697467a76cc1e424e/aeppl/tensor.py#L112-L117

Any thoughts about this approach?

CC @ricardoV94 @brandonwillard

ricardoV94 commented 2 years ago

> as a consequence, does not replace the random variables in the logp graph with their value variable counterparts.

A small point of clarification: these variables don't have values by definition (their join is the thing being valued), so they could never be replaced.

For AePPL users, I also see some advantages in not including the original joined random variables in the logp graph.
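To illustrate the point with a hypothetical setup (not code from this issue): only the joined variable is associated with a value variable, so there is nothing to substitute for the base RVs.

import aesara.tensor as at

# x_rv and y_rv are only ever observed through their join, so only z_rv is
# given a value variable; x_rv and y_rv have none of their own.
x_rv = at.random.normal(0.0, 1.0, size=(3,))
y_rv = at.random.normal(0.0, 1.0, size=(2,))
z_rv = at.join(0, x_rv, y_rv)
z_vv = z_rv.clone()  # the single value variable used when valuing the join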

larryshamalama commented 2 years ago

I tried the following:

# (Imports assumed for the Aesara 2.x version used at the time; exact paths may differ.)
import aesara.tensor as at

from aesara.graph.opt_utils import optimize_graph
from aesara.tensor.basic_opt import topo_constant_folding

# Stack the base variables' sizes along the join axis ...
shapes_stacked = at.stack([base_var.shape[axis] for base_var in base_vars])

# ... and constant-fold them before using them as split sizes.
base_var_shapes = optimize_graph(
    shapes_stacked,
    custom_opt=topo_constant_folding,
)

split_values = at.split(
    value,
    splits_size=base_var_shapes,
    n_splits=len(base_vars),
    axis=axis,
)

but, in PyMC, I still find that random variables are present in the logp graph. For instance, printing the graphs of shapes_stacked and base_var_shapes gives the following.
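(The outputs below were presumably produced with Aesara's debug printer, e.g.:)

import aesara

aesara.dprint(shapes_stacked)
aesara.dprint(base_var_shapes)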

Graph of shapes_stacked

Subtensor{int8} [id A]
 |TensorConstant{(1,) of 1} [id B]
 |ScalarFromTensor [id C]
   |TensorConstant{0} [id D]
Subtensor{int8} [id E]
 |Shape [id F]
 | |normal_rv{0, (0, 0), floatX, False}.1 [id G]
 |   |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x16D3E6580>) [id H]
 |   |TensorConstant{(1,) of 10} [id I]
 |   |TensorConstant{11} [id J]
 |   |TensorConstant{(1,) of 0.0} [id K]
 |   |TensorConstant{(1,) of 1.0} [id L]
 |ScalarFromTensor [id M]
   |TensorConstant{0} [id D]

Graph of base_var_shapes

MakeVector{dtype='int64'} [id A]
 |TensorConstant{1} [id B]
 |InplaceDimShuffle{} [id C]
   |Shape [id D]
     |normal_rv{0, (0, 0), floatX, False}.1 [id E] <- problem is here?
       |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x16D3E6580>) [id F]
       |TensorConstant{(1,) of 10} [id G]
       |TensorConstant{11} [id H]
       |TensorConstant{(1,) of 0.0} [id I]
       |TensorConstant{(1,) of 1.0} [id J]

I believe the problem persists because normal_rv{0, (0, 0), floatX, False}.1 [id E] is still present after constant folding... Is this behaviour expected or reasonable?

ricardoV94 commented 2 years ago

Try adding a `ShapeFeature` to the graph you are constant folding over:

https://github.com/aesara-devs/aesara/blob/7393b7441601eaad98bc0cb494aa8fba2ea4bf6a/aesara/tensor/basic.py#L1452
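For concreteness, a rough sketch of what that might look like, assuming Aesara-2.x import locations and the same `base_vars`/`axis` as in the snippet above (not verbatim AePPL code):

import aesara.tensor as at

from aesara.graph.fg import FunctionGraph
from aesara.graph.opt_utils import optimize_graph  # assumed import location
from aesara.tensor.basic_opt import ShapeFeature, topo_constant_folding  # assumed import location

# Build a FunctionGraph with a ShapeFeature attached so that the shapes of
# the base RVs can be inferred and folded away, instead of keeping the RVs
# themselves in the graph; then constant-fold the stacked shapes as before.
shapes_stacked = at.stack([base_var.shape[axis] for base_var in base_vars])
shape_fgraph = FunctionGraph(
    outputs=[shapes_stacked],
    features=[ShapeFeature()],
    clone=True,
)
optimize_graph(shape_fgraph, custom_opt=topo_constant_folding)
(base_var_shapes,) = shape_fgraph.outputs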

larryshamalama commented 2 years ago

Thanks, it works :) I can open a PR for this.

Edit: While the solution above works, discussion about the presence of random variables in a logp graph is ongoing.

brandonwillard commented 2 years ago

It looks like another pass of canonicalizations (on a `FunctionGraph` with the `ShapeFeature` attached) is needed in PyMC, so I'm closing this for now.
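A rough sketch of what such a pass might look like on PyMC's side, assuming Aesara-2.x import locations and using `logp_graph` as a placeholder for the logp output (illustrative only, not the actual PyMC fix):

from aesara.graph.fg import FunctionGraph
from aesara.graph.opt_utils import optimize_graph  # assumed import location
from aesara.tensor.basic_opt import ShapeFeature   # assumed import location

# Attach a ShapeFeature and run the "canonicalize" rewrites over the logp
# graph so that leftover shape-of-RV subgraphs can be simplified away.
fgraph = FunctionGraph(outputs=[logp_graph], features=[ShapeFeature()], clone=True)
optimize_graph(fgraph, include=("canonicalize",))
(canonical_logp,) = fgraph.outputs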