aesara-devs / aeppl

Tools for an Aesara-based PPL.
https://aeppl.readthedocs.io
MIT License

Make `DimShuffle`d variables measurable #150

Open larryshamalama opened 2 years ago

larryshamalama commented 2 years ago

An example of what should be measurable is the following:

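```python
import aesara.tensor as at
from aeppl import joint_logprob

# Two independent Dirichlet RVs; the last axis (size 3) is the support dimension
x = at.random.dirichlet([[1, 1, 1], [100, 100, 100]], name="x")
# The transpose introduces a DimShuffle on top of the RandomVariable
x_transpose = x.T

# `x_transpose` should be measurable, so this should yield a log-probability graph
joint_logprob({x_transpose: x_transpose.copy()})
```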

The transpose is one example of a DimShuffle under the hood, but there are three categories of operations that make use of DimShuffles:

To my understanding, each of these will need to be handled separately, but probably in a similar fashion.
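For reference, an Aesara `DimShuffle` is described by a `new_order` whose entries are either input axis indices or `"x"` (a new broadcastable axis); input axes left out of `new_order` are dropped, which Aesara only allows for broadcastable (length-1) axes. The NumPy sketch below mimics those semantics with a hypothetical `dimshuffle_np` helper (not an Aesara function), so that transposing, adding, and dropping dimensions can be compared side by side.

```python
import numpy as np


def dimshuffle_np(arr, new_order):
    """Hypothetical NumPy analogue of Aesara's DimShuffle (not an Aesara API).

    Each entry of `new_order` is an input axis index or "x" for a new
    length-1 (broadcastable) axis. Input axes left out of `new_order` are
    dropped, which is only allowed for axes of length 1.
    """
    arr = np.asarray(arr)
    kept = [d for d in new_order if d != "x"]
    dropped = tuple(d for d in range(arr.ndim) if d not in kept)
    if any(arr.shape[d] != 1 for d in dropped):
        raise ValueError("only length-1 axes can be dropped")
    # Drop the length-1 axes; the remaining axes keep their relative order.
    squeezed = arr.squeeze(axis=dropped) if dropped else arr
    remaining = sorted(kept)
    # Permute the remaining axes into the order requested by `new_order`.
    permuted = squeezed.transpose([remaining.index(d) for d in kept])
    # Insert the new broadcastable axes wherever "x" appears.
    out_shape = [1 if d == "x" else arr.shape[d] for d in new_order]
    return permuted.reshape(out_shape)


x = np.arange(6.0).reshape(2, 3)
print(dimshuffle_np(x, (1, 0)).shape)       # (3, 2): a transpose, like x.T
print(dimshuffle_np(x, ("x", 0, 1)).shape)  # (1, 2, 3): a new broadcastable axis
print(dimshuffle_np(x[:1], (1,)).shape)     # (3,): a dropped length-1 axis
```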

CC @ricardoV94 @brandonwillard

ricardoV94 commented 2 years ago

Definitely useful 😉

brandonwillard commented 2 years ago

This should be covered by aesara.tensor.random.opt.local_dimshuffle_rv_lift in a lot of cases, since it would lift the DimShuffle to the arguments of the RandomVariable.

local_dimshuffle_rv_lift may not currently support that distribution, though, in which case the requisite changes might only involve an extension to that rewrite (e.g. one that covers multivariate distributions). (As @ricardoV94 mentioned below, this doesn't apply to Dirichlet distributions, due to the dimension-related constraints involved in the mapping from inputs to output/support.)

Also, I believe there's a clients-based restriction hard-coded into local_dimshuffle_rv_lift; that will probably need to be made configurable, since the concerns it addresses are specific to maintaining the consistency of model graph evaluations (i.e. sampling from a model), and not to our AePPL intermediate forms.

ricardoV94 commented 2 years ago

You cannot lift a dimshuffle through the core dimension of a distribution's parameters, at least not for all distributions.

There is no way to lift dirichlet([[1, 2, 3]]).T because dirichlet([[1], [2], [3]]) is not even valid.

And in the first example it would be valid but have a completely different meaning.
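To see why the transpose in the first example cannot simply be pushed into the parameters, here is a small NumPy sketch (independent of Aesara) using the same alphas. `batched_dirichlet` is a hypothetical helper that, like `DirichletRV`, treats the last axis as the support dimension. Every row of `x` is a simplex, so it is the columns of `x.T` that sum to 1, whereas `dirichlet(alphas.T)` places its simplices along the rows of a (3, 2) array: three 2-category Dirichlets instead of two 3-category ones.

```python
import numpy as np


def batched_dirichlet(alphas, rng):
    """Draw one Dirichlet sample per batch entry; the last axis is the support."""
    alphas = np.asarray(alphas, dtype=float)
    flat = alphas.reshape(-1, alphas.shape[-1])
    draws = np.stack([rng.dirichlet(a) for a in flat])
    return draws.reshape(alphas.shape)


rng = np.random.default_rng(0)
alphas = np.array([[1.0, 1.0, 1.0], [100.0, 100.0, 100.0]])

x = batched_dirichlet(alphas, rng)        # shape (2, 3)
x_T = x.T                                 # shape (3, 2)
naive = batched_dirichlet(alphas.T, rng)  # shape (3, 2), but a different model

# x.T keeps its simplices in the columns...
print(np.allclose(x_T.sum(axis=0), 1.0))     # True
# ...while the naive "lifted" draw has them in the rows.
print(np.allclose(naive.sum(axis=-1), 1.0))  # True
# The rows of x.T sum to 2 in total, so they cannot all be simplices.
print(np.allclose(x_T.sum(axis=-1), 1.0))    # False
```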

You could however extend the lift if the Dimshuffle only affects non-core dimensions
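In terms of a DimShuffle's `new_order`, that condition means the trailing `ndim_supp` core axes must still be the last entries of the output, in their original order. The predicate below is a hypothetical sketch (not part of Aesara or AePPL) that checks only the DimShuffle itself; actually performing the lift would still require mapping the batch part of `new_order` onto each parameter according to that parameter's own core dimensions.

```python
from typing import Sequence, Union

Axis = Union[int, str]  # an input axis index, or "x" for a new broadcastable axis


def only_batch_dims_affected(new_order: Sequence[Axis], ndim: int, ndim_supp: int) -> bool:
    """True if a DimShuffle with `new_order` leaves the core dimensions alone.

    `ndim` is the number of dimensions of the measurable variable and
    `ndim_supp` the number of its trailing core (support) dimensions.
    """
    if ndim_supp == 0:
        # Univariate case: every dimension is a batch dimension.
        return True
    core = list(range(ndim - ndim_supp, ndim))
    # The core axes must appear, in their original order, as the last entries.
    return list(new_order[-ndim_supp:]) == core


# Transposing a (2, 3) Dirichlet moves its support axis: not liftable this way.
print(only_batch_dims_affected((1, 0), ndim=2, ndim_supp=1))       # False
# Swapping the two batch axes of a (2, 4, 3) Dirichlet keeps the support last.
print(only_batch_dims_affected((1, 0, 2), ndim=3, ndim_supp=1))    # True
# Adding a broadcastable batch axis in front is also fine.
print(only_batch_dims_affected(("x", 0, 1), ndim=2, ndim_supp=1))  # True
```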

brandonwillard commented 2 years ago

> You could however extend the lift if the Dimshuffle only affects non-core dimensions

Yes, the current version of local_dimshuffle_rv_lift even takes this into consideration, which is why the multivariate cases are disabled (because they may need special logic or not be possible).

brandonwillard commented 2 years ago

Now that I think about it, there may be a way to make it always possible to lift these operations.

The reason we can't lift the DimShuffle through DirichletRV is that DirichletRV fixes the last dimension of its inputs as the core/support dimension, so that z.T for z = dirichlet(a) results in a sample that has no possible corresponding f(a). However, if we parameterize DirichletRV on its core/support dimension we get z == dirichlet(a, c=-1) and z.T == dirichlet(a.T, c=0).
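The proposed parameterization can be checked numerically with a NumPy sketch of a Dirichlet that takes its core dimension as an argument. `dirichlet_along` and its `axis` argument are hypothetical stand-ins for the `c` parameter suggested above: the function moves the chosen axis to the end, samples with the last axis as the support (as `DirichletRV` does), and moves the support axis back. With identical seeds, transposing a draw made with `axis=-1` reproduces exactly the draw made from the transposed parameters with `axis=0`.

```python
import numpy as np


def dirichlet_along(alphas, axis, rng):
    """Hypothetical Dirichlet whose support (core) dimension is `axis`."""
    alphas = np.moveaxis(np.asarray(alphas, dtype=float), axis, -1)
    flat = alphas.reshape(-1, alphas.shape[-1])
    draws = np.stack([rng.dirichlet(a) for a in flat]).reshape(alphas.shape)
    # Put the support dimension back where the caller asked for it.
    return np.moveaxis(draws, -1, axis)


a = np.array([[1.0, 1.0, 1.0], [100.0, 100.0, 100.0]])

z = dirichlet_along(a, -1, np.random.default_rng(42))     # like dirichlet(a, c=-1)
z_T = dirichlet_along(a.T, 0, np.random.default_rng(42))  # like dirichlet(a.T, c=0)

print(np.allclose(z.T, z_T))  # True
```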

ricardoV94 commented 2 years ago

Why not just DimShuffle the value variables and evaluate them at the original un-DimShuffled measurable variables?

That will also work for measurable variables that are not pure RandomVariables in general.
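A minimal sketch of that idea outside of AePPL: to evaluate the log-density of `x.T` at some value, apply the inverse DimShuffle (for a transpose, another transpose) to the value and evaluate `x`'s own log-density there. `dirichlet_logpdf` and `transposed_logpdf` are hypothetical helpers written for illustration; they are not AePPL functions.

```python
from math import lgamma

import numpy as np

_lgamma = np.vectorize(lgamma)


def dirichlet_logpdf(value, alphas):
    """Log-density of independent Dirichlets; the last axis is the support."""
    value = np.asarray(value, dtype=float)
    alphas = np.asarray(alphas, dtype=float)
    log_norm = _lgamma(alphas.sum(axis=-1)) - _lgamma(alphas).sum(axis=-1)
    return log_norm + ((alphas - 1.0) * np.log(value)).sum(axis=-1)


def transposed_logpdf(value_T, alphas):
    """logp of x.T at `value_T`, where x ~ Dirichlet(alphas) along the last axis.

    The DimShuffle is undone on the value instead of being lifted through x.
    """
    return dirichlet_logpdf(np.asarray(value_T).T, alphas)


alphas = np.array([[1.0, 1.0, 1.0], [100.0, 100.0, 100.0]])
value = np.array([[0.2, 0.3, 0.5], [1 / 3, 1 / 3, 1 / 3]])
print(transposed_logpdf(value.T, alphas))  # one log-density per simplex
```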

brandonwillard commented 2 years ago

> Why not just DimShuffle the value variables and evaluate it at the original untransposed variable?

While that might work for the purpose of constructing a log-probability function—in the same way that incsubtensor_rv_replace does—it won't help with the more general issues produced by DimShuffles.

For instance, DimShuffles seriously complicate the identification of subgraphs, because one needs to account for both their presence and lack thereof when matching. A DimShuffle in "front" of a non-valued, measurable mixture component can prevent AePPL from identifying it as measurable, and a value-variable-based approach wouldn't help.

The problems associated with lifting DimShuffles through RandomVariables are also very related to the problems involved with lifting all other Ops through RandomVariables, so, solutions to the latter can be solutions to the former, and lifting other Ops can have much clearer performance benefits (e.g. https://github.com/aesara-devs/aesara/issues/151).

In other words, DimShuffle lifting is more broadly applicable and beneficial, so almost any reason to improve those capabilities is a good one.

> That will also work for measurable variables that are not pure RandomVariables in general.

Those RandomVariable lifting rewrites can—and should—be refactored to apply to arbitrary measurable variables. Currently, the rewrites only effectively make use of a RandomVariable's size parameter, which can just as well be ignored/empty; otherwise, only the Op's inputs are needed.

When we make the multivariate extensions mentioned here, RandomVariable.ndim_supp and RandomVariable.ndim_params information will probably be needed, but those are properties we could and should be adding to our MeasurableVariable interface anyway. Again, that information is really only necessary for multivariate measurable variables, so those additions wouldn't even affect most/all of our current non-RandomVariable MeasurableVariable classes.
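As an illustration of why that information is useful, here is a hypothetical `SupportInfo` record (the class and its field names are assumptions, loosely modeled on `RandomVariable.ndim_supp` and on per-parameter core dimensions). From it, the core axes and the batch shape of a measurable variable follow directly from the parameter shapes, without relying on a RandomVariable's `size` argument.

```python
from dataclasses import dataclass
from typing import Sequence, Tuple

import numpy as np


@dataclass(frozen=True)
class SupportInfo:
    """Hypothetical dimension metadata for a measurable variable."""

    ndim_supp: int                 # number of trailing core (support) dimensions
    ndims_params: Tuple[int, ...]  # number of core dimensions of each parameter

    def core_axes(self, ndim: int) -> Tuple[int, ...]:
        """Axes of an `ndim`-dimensional output that belong to the support."""
        return tuple(range(ndim - self.ndim_supp, ndim))

    def batch_shape(self, param_shapes: Sequence[Tuple[int, ...]]) -> Tuple[int, ...]:
        """Broadcast the batch (non-core) part of each parameter's shape."""
        batch_parts = [
            shape[: len(shape) - nd] for shape, nd in zip(param_shapes, self.ndims_params)
        ]
        return tuple(np.broadcast_shapes(*batch_parts))


dirichlet_info = SupportInfo(ndim_supp=1, ndims_params=(1,))
normal_info = SupportInfo(ndim_supp=0, ndims_params=(0, 0))

print(dirichlet_info.batch_shape([(2, 3)]))     # (2,)
print(normal_info.batch_shape([(2, 3), (3,)]))  # (2, 3)
print(dirichlet_info.core_axes(2))              # (1,)
```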

ricardoV94 commented 2 years ago

> A DimShuffle in "front" of a non-valued, measurable mixture component can prevent AePPL from identifying it as measurable, and a value-variable-based approach wouldn't help.

I don't see why this would be the case with our current eager IR. A dimshuffle of a measurable component would be converted into a MeasurableDimShuffle eventually (since such a rewrite can always be applied), and then the MeasurableMixture should behave exactly as if the component was any other MeasurableVariable.

This rewrite can always be applied later if it complicates other forms of pattern matching (which I don't think is yet the case).

> The problems associated with lifting DimShuffles through RandomVariables are also very related to the problems involved with lifting all other Ops through RandomVariables, so, solutions to the latter can be solutions to the former, and lifting other Ops can have much clearer performance benefits (e.g. https://github.com/aesara-devs/aesara/issues/151).

I am not sure about the performance benefits of lifting a (non-reducing) operator over a function of multiple inputs. For instance for the DS case you usually would need to apply it to all inputs above the RV, instead of only applying it once to the output of the RV.

> When we make the multivariate extensions mentioned here, RandomVariable.ndim_supp and RandomVariable.ndim_params information will probably be needed, but those are properties we could and should be adding to our MeasurableVariable interface anyway.

This makes a lot of sense. For instance, here I had to guess the ndim_supp of a MeasurableVariable from how many dims were lost, but it would have been much better to have that information at the IR Op level: https://github.com/aesara-devs/aeppl/blob/394dc72a536f137fbeded74a196769baba807dc9/aeppl/tensor.py#L124

brandonwillard commented 2 years ago

> I don't see why this would be the case with our current eager IR. A dimshuffle of a measurable component would be converted into a MeasurableDimShuffle eventually (since such a rewrite can always be applied), and then the MeasurableMixture should behave exactly as if the component was any other MeasurableVariable.

Yes, and then we're locked into the current eager approach. This is the kind of coupling between design decisions that constricts our possibilities going forward. At the same time, MeasurableDimShuffles would continue to complicate rewrites on our AePPL IR, as regular DimShuffles do.

> I am not sure about the performance benefits of lifting a (non-reducing) operator over a function of multiple inputs. For instance for the DS case you usually would need to apply it to all inputs above the RV, instead of only applying it once to the output of the RV.

Just like the example to which I linked, Subtensor lifting can reduce/remove unnecessary sampling. That's just one simple and explicitly performance-enhancing lift. Remember, I wasn't just talking about DimShuffle lifts, but any lifts that are facilitated by figuring out the same underlying dimension-based logic for multivariate RandomVariables. Beyond that, lifting can reduce the number of run-time transposes by bringing Ops closer to constant inputs, prevent copies or poor performance due to array contiguity changes, etc.
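A NumPy sketch of the sampling savings from a Subtensor lift, assuming independent elementwise draws (as for a batch of normals). Indexing after sampling generates one number per parameter entry, while lifting the index onto the parameters generates only the selected entries. The two results are equal in distribution, not draw for draw, because they consume the random stream differently. Both helper names are hypothetical.

```python
import numpy as np


def sample_then_index(mu, sigma, idx, rng):
    """Unlifted: draw every entry, then keep only the selected ones."""
    draws = rng.normal(mu, sigma)  # one draw per parameter entry
    return draws[idx], draws.size


def index_then_sample(mu, sigma, idx, rng):
    """Lifted: index the parameters first, then draw only what is needed."""
    draws = rng.normal(mu[idx], sigma[idx])  # one draw per selected entry
    return draws, draws.size


mu = np.linspace(-1.0, 1.0, 1000)
sigma = np.full(1000, 2.0)
idx = np.array([3, 500, 999])

_, n_full = sample_then_index(mu, sigma, idx, np.random.default_rng(0))
_, n_lifted = index_then_sample(mu, sigma, idx, np.random.default_rng(0))
print(n_full, n_lifted)  # 1000 3
```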

In other words, it's a type of Op "mobility" that we need for performance-related rewriting in general.

If we can do things in a way that opens us up to more features and approaches and also solves our immediate problem(s), then we need to consider that way most seriously.