aesara-devs / aeppl

Tools for an Aesara-based PPL.
https://aeppl.readthedocs.io
MIT License

Refactor the interface to construct logprob graphs #196

Closed rlouf closed 2 years ago

rlouf commented 2 years ago

Exploration following #193. Different approach from #180. Naming conventions follow Koller & Friedman "Probabilistic Graphical Models: Principles and Techniques".

We can now get a dictionary that maps the original random variables of a model to their associated conditional log-probabilities using:

import aeppl

logprobs, values = aeppl.conditional_logprob(x_rv, y_rv, realized={z_rv: z_vv})
print(logprobs.keys())
# [x_rv, y_rv]
print(values)
# [x_vv, y_vv]

conditional_logprob replaces factorized_joint_logprob. I also removed the sum argument to joint_logprob; it is unclear to me what mathematical quantity sum=False corresponds to.
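For intuition — and this is only an illustrative sketch with made-up variables — the joint log-probability under this interface is simply the sum of the conditional log-probabilities:

import aesara.tensor as at

import aeppl

srng = at.random.RandomStream(0)
x_rv = srng.normal(0.0, 1.0, name="x")
y_rv = srng.normal(x_rv, 1.0, name="y")

# Summing the conditional log-probabilities should give the same quantity as
# `aeppl.joint_logprob(x_rv, y_rv)`.
logprobs, values = aeppl.conditional_logprob(x_rv, y_rv)
joint = at.sum([logprob.sum() for logprob in logprobs.values()])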

The choice of the more general keyword realized instead of observed is intentional. The distinction shows up, for instance, when building the logprob for an MCMC step in a Metropolis-within-Gibbs sampler: one needs to condition on the values taken by the variables not targeted by the MCMC step, but we cannot really say that they're "observed".
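To make that concrete, here is a rough sketch (the model and names are purely illustrative) of conditioning on a non-observed variable's current value in such a step, assuming the interface proposed here:

import aesara
import aesara.tensor as at

import aeppl

srng = at.random.RandomStream(0)
x_rv = srng.normal(0.0, 1.0, name="x")
y_rv = srng.normal(x_rv, 1.0, name="y")

# Current value of the variable not targeted by this MCMC step; it is not data,
# but we still condition on it via `realized`.
y_current = at.as_tensor(0.5)

logprob, (x_vv,) = aeppl.joint_logprob(x_rv, realized={y_rv: y_current})
step_logprob_fn = aesara.function([x_vv], logprob)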

Closes #162

codecov[bot] commented 2 years ago

Codecov Report

Base: 95.18% // Head: 95.10% // Decreases project coverage by -0.08% :warning:

Coverage data is based on head (ab43194) compared to base (54a5888). Patch coverage: 100.00% of modified lines in pull request are covered.

:exclamation: Current head ab43194 differs from pull request most recent head 3951ce5. Consider uploading reports for the commit 3951ce5 to get more accurate results

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #196      +/-   ##
==========================================
- Coverage   95.18%   95.10%   -0.09%
==========================================
  Files          12       12
  Lines        1952     1960       +8
  Branches      254      256       +2
==========================================
+ Hits         1858     1864       +6
- Misses         52       54       +2
  Partials       42       42
```

| Impacted Files | Coverage Δ | |
|---|---|---|
| aeppl/joint_logprob.py | `94.66% <100.00%> (+0.54%)` | :arrow_up: |
| aeppl/scan.py | `94.73% <100.00%> (ø)` | |
| aeppl/transforms.py | `96.44% <100.00%> (+0.01%)` | :arrow_up: |
| aeppl/mixture.py | `96.59% <0.00%> (-1.14%)` | :arrow_down: |

:umbrella: View full report at Codecov.

rlouf commented 2 years ago

I think we should also make it as easy as possible to compile a logprob graph, as long as the interface remains as unambiguous and flexible as the current one. Let’s consider a typical-sized model to illustrate the proposal:

import aesara
import aesara.tensor as at

col_1 = at.as_tensor([1., 2., 3., 5., 1])
col_2 = at.as_tensor([1., 2., 3., 5., 1])

srng = at.random.RandomStream(3)

intercept_rv = srng.normal(-4, 1., name="intercept")
beta_1_rv = srng.normal(0, 2.5, name="beta_1")
beta_2_rv = srng.normal(0, 2.5, name="beta_2")

likelihood = at.sigmoid(
    intercept_rv
    + beta_1_rv * col_1
    + beta_2_rv * col_2
)

logit_rv = srng.bernoulli(likelihood, name="logit")

To compile the joint_logprob graph, one currently needs to do the following:

import aeppl

intercept_vv = intercept_rv.clone()
beta_1_vv = beta_1_rv.clone()
beta_2_vv = beta_2_rv.clone()
logit_vv = logit_rv.clone()

logprob = aeppl.joint_logprob({
    intercept_rv: intercept_vv,
    beta_1_rv: beta_1_vv,
    beta_2_rv: beta_2_vv,
    logit_rv: logit_vv,
})

logprob_fn = aesara.function(
    [intercept_vv, beta_1_vv, beta_2_vv, logit_vv], logprob
)

Having to create the value variables manually is cumbersome. We can easily improve this by making joint_logprob create the value variables and return them as a list. Compilation is now straightforward without any loss in generality:

import aeppl

logprob, values = aeppl.joint_logprob(
    intercept_rv,
    beta_1_rv,
    beta_2_rv,
    logit_rv,
)

print(values)
# [intercept_vv, beta_1_vv, beta_2_vv, logit_vv]

logprob_fn = aesara.function(values, logprob)

For conditional_logprob we need a bit more flexibility, since we only need to specify the parents of a given variable when compiling its conditional log-probability, so I suggest returning the value variables in a map indexed by the corresponding random variables:

import aeppl

logprobs, values = aeppl.conditional_logprob(
    intercept_rv,
    beta_1_rv,
    beta_2_rv,
    logit_rv,
)

print(logprobs)
# {intercept_rv: intercept_logprob, beta_1_rv: beta_1_logprob, beta_2_rv: beta_2_logprob, logit_rv: logit_logprob}
print(values)
# {intercept_rv: intercept_vv, beta_1_rv: beta_1_vv, beta_2_rv: beta_2_vv, logit_rv: logit_vv}

logprob_intercept_fn = aesara.function([values[intercept_rv]], logprobs[intercept_rv])
logprob_logit_fn = aesara.function(list(values.values()), logprobs[logit_rv])

We need to consider the scenario where the value that some RVs take is known at construction time. In this case, we know the values that logit_rv takes:

position = at.as_tensor([0, 1, 1, 1, 1])

logprob, values = aeppl.joint_logprob(
    intercept_rv,
    beta_1_rv,
    beta_2_rv,
    realized={logit_rv: position},
)

logprob_fn = aesara.function(values, logprob)

This is a good middle ground between the explicit but cumbersome conditional_logprob({sigma2_rv: sigma2_vv, Y_rv: Y_data}) and the restrictive joint_logprob(sigma2_rv, y_rv), which requires passing the data to the function at every iteration. The term realized is borrowed from measure theory, but we could also use the more widely adopted observed.

Anyway, these changes are fairly simple to implement and this could be done in this PR.

rlouf commented 2 years ago

Now that I have implemented this interface in the last commit, I have a little experience using it, since many tests needed to be updated. Here are my suggestions and remarks.

Always return a list of value variables

First, I think that the second output of joint_logprob and conditional_logprob should be a list in both cases. It is always easier to unpack the tuple than to get elements from a dictionary:

logprob, (intercept_vv, beta_1_vv, beta_2_vv) = aeppl.joint_logprob(
    intercept_rv,
    beta_1_rv,
    beta_2_rv,
    realized={logit_rv: position},
)

This dictionary can always be built:

rvs_to_sample = [intercept_rv, beta_1_rv, beta_2_rv]
logprob, values = aeppl.joint_logprob(*rvs_to_sample, realized={logit_rv: position})
rv_to_values = {rv: vv for rv, vv in zip(rvs_to_sample, values)}

but I expect this need will arise much less often than the need for the convenience that the list output provides.

Single random variable

I don't think we should special case the return values when a single random variable is passed to joint_logprob or conditional_logprob:

logprob, (x_vv,) = joint_logprob(x_rv)

logprobs, (x_vv,) = conditional_logprob(x_rv)
print(logprobs)
# {x_rv: x_rv_logprob}

The transforms interface is in the way

I can currently make the transforms interface work in a roundabout way by passing the corresponding random variables as "realized":

import aesara.tensor as at

from aeppl import joint_logprob
from aeppl.transforms import TransformValuesRewrite, DEFAULT_TRANSFORM

X_rv = at.random.halfnormal(0, 3, name="X")
x_vv = X_rv.clone()
x_vv.name = "x"

transform_rewrite = TransformValuesRewrite({x_vv: DEFAULT_TRANSFORM})
tr_logp, _ = joint_logprob(
    realized={X_rv: x_vv},
    extra_rewrites=transform_rewrite
)

I am not sure whether it is possible yet, but we could circumvent this by associating the transform with the random variables:

X_rv = at.random.halfnormal(0, 3, name="X")

transform_rewrite = TransformValuesRewrite({X_rv: DEFAULT_TRANSFORM})
tr_logp, (x_vv,) = joint_logprob(X_rv, extra_rewrites=transform_rewrite)

Which imo makes more sense anyway.

rlouf commented 2 years ago

Current status:

brandonwillard commented 2 years ago

First, I think that the second output of joint_logprob and conditional_logprob should be a list in both cases. It is always easier to unpack the tuple than to get elements from a dictionary:

Before I forget to mention this again: we could always return a FunctionGraph object that encapsulates the results.

brandonwillard commented 2 years ago

I don't think we should special case the return values when a single random variable is passed to joint_logprob or conditional_logprob:

Yeah, definitely not; that kind of stuff only ever complicates the typing and interfaces involved.

brandonwillard commented 2 years ago

I am not sure whether it is possible yet, but we could circumvent this by associating the transform with the random variables:

Random variables and their value variables should be one-to-one (most of the time, at least), no?

rlouf commented 2 years ago

Random variables and their value variables should be one-to-one (most of the time, at least), no?

You're right, it's more about how much time it will take than whether it is possible or not.

It would also be nicer in terms of interface because, afaik, we are transforming the random variables, i.e. applying the transformation to the result of these functions.

rlouf commented 2 years ago

Before I forget to mention this again: we could always return a FunctionGraph object that encapsulates the results.

What would be the benefit here?

brandonwillard commented 2 years ago

Before I forget to mention this again: we could always return a FunctionGraph object that encapsulates the results.

What would be the benefit here?

It's an established Aesara collection type that connects inputs and outputs, and, since we use them at different stages in the process between all of the AePPL interface functions, it's worth considering an approach that uses them. They can also hold meta-information that a caller or the next interface function may want.
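For reference, a minimal sketch of building such a FunctionGraph around a logprob graph could look like this (illustrative only, not what this PR currently returns):

import aesara.tensor as at
from aesara.graph.basic import Constant, graph_inputs
from aesara.graph.fg import FunctionGraph

import aeppl

srng = at.random.RandomStream(0)
x_rv = srng.normal(0.0, 1.0, name="X")

logprob, (x_vv,) = aeppl.joint_logprob(x_rv)

# Collect the non-constant root inputs of the logprob graph and wrap everything in a
# FunctionGraph, which keeps inputs and outputs together and can carry extra Features.
inputs = [v for v in graph_inputs([logprob]) if not isinstance(v, Constant)]
logprob_fg = FunctionGraph(inputs=inputs, outputs=[logprob], clone=False)
print(logprob_fg.inputs)
# [X_vv]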

rlouf commented 2 years ago

Just to be sure, are you suggesting we return the following?

import aeppl

logprob = aeppl.joint_logprob(
    intercept_rv,
    beta_1_rv,
    beta_2_rv,
    logit_rv,
)

print(logprob.inputs)
# [intercept_vv, beta_1_vv, beta_2_vv, logit_vv]

logprob_fn = aesara.function(logprob.inputs, logprob.outputs)

And

import aeppl

logprobs = aeppl.conditional_logprob(
    intercept_rv,
    beta_1_rv,
    beta_2_rv,
    logit_rv,
)

print(logprobs)
# {
#    intercept_rv: intercept_logprob_fg,
#    beta_1_rv: beta_1_logprob_fg,
#    beta_2_rv: beta_2_logprob_fg,
#    logit_rv: logit_logprob_fg
# }

logp_beta_1 = logprobs[beta_1_rv]
logprob_fn = aesara.function(logp_beta_1.inputs, logp_beta_1.outputs)

I don't see why not! It contains more information than the graphs we currently return, looks cleaner and simplifies compilation.

It is also immediately useful: callers currently need to know which value variables to use as an input to, say, beta_1_logprob. A FunctionGraph would already contain that information.

We could return sampling steps in AeMCMC as FunctionGraphs as well.

rlouf commented 2 years ago

I have modified joint_logprob locally so that it returns a FunctionGraph. After using it, I am not completely sold on the idea: to build the FunctionGraph we need to provide all the inputs to the logprob graphs. This undeniably makes compilation easier:

import aesara
import aesara.tensor as at

import aeppl

M = at.matrix("M")
x = at.random.normal(0, 1, size=M.shape[1], name="X")
y = at.random.normal(M.dot(x), 1, name="Y")

logp = aeppl.joint_logprob(x, y)
print(logp.inputs)
# (M, X_vv, Y_vv)

logp_fn = aesara.function(logp.inputs, logp.outputs)

But it puts pressure on the internals of downstream users, such as AeMCMC, which only need to know which value variables hold the values we want to update. All that for a fairly marginal gain over the more explicit:

logp, value_variables = aeppl.joint_logprob(x, y)
print(value_variables)
# [X_vv, Y_vv]

logp_fn = aesara.function([M] + value_variables, logp)

Of course, we could attach a Feature that holds the value variables and could be used by downstream callers. This adds a small cost, which would be outweighed if, down the line, we need to target this region of the graph with logprob-specific rewrites. Still, I would hold off until that becomes useful/necessary.
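Such a Feature would be tiny; something along these lines (the class name and attribute are made up for illustration):

from aesara.graph.features import Feature


class ValueVariablesFeature(Feature):
    """Expose the value variables created by `joint_logprob` on the `FunctionGraph`."""

    def __init__(self, value_variables):
        self.value_variables = tuple(value_variables)

    def on_attach(self, fgraph):
        # Make the value variables reachable from the graph object itself.
        fgraph.value_variables = self.value_variables


# e.g. logprob_fg.attach_feature(ValueVariablesFeature([x_vv, y_vv]))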

An intermediate approach may be to return an OpFromGraph along with the value variables. It not only requires the user to be more explicit when compiling the function (and to set the order of the compiled function's parameters themselves), which is less error-prone, but also allows retrieving the inputs necessary to compute the individual conditional logprob values with logprob.owner.inputs.
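Roughly, the mechanics would look like this (only a sketch using a plain OpFromGraph, not necessarily the Op this PR would define):

import aesara.tensor as at
from aesara.compile.builders import OpFromGraph

import aeppl

srng = at.random.RandomStream(0)
x_rv = srng.normal(0.0, 1.0, name="X")

logprob, values = aeppl.joint_logprob(x_rv)

# Wrap the logprob graph in an `OpFromGraph`; the inputs needed to evaluate it then
# remain retrievable through `.owner.inputs` on the wrapped output.
logprob_op = OpFromGraph(list(values), [logprob], inline=True)
wrapped_logprob = logprob_op(*values)
print(wrapped_logprob.owner.inputs)
# [X_vv]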

rlouf commented 2 years ago

To summarize, let us look at the user interface with the three possibilities outlined above. At 23f2336:

import aeppl
import aesara
import aesara.tensor as at

srng = at.random.RandomStream(0)
M = at.matrix("M")
x = srng.normal(0, 1, size=M.shape[1], name="X")
y = srng.normal(M.dot(x), 1, name="Y")

logprob, value_variables = aeppl.joint_logprob(x, y)
fn = aesara.function([M] + value_variables, logprob)

# We have to infer from the model graph which variables
# are needed to compile the conditional logprobs, or
# use `aesara.graph.basic.graph_inputs`.
logprobs, (x_vv, y_vv) = aeppl.conditional_logprob(x, y)
fn = aesara.function([M, x_vv], logprobs[x])
fn = aesara.function([M, x_vv, y_vv], logprobs[y])

Using FunctionGraphs:

import aeppl
import aesara
import aesara.tensor as at

srng = at.random.RandomStream(0)
M = at.matrix("M")
x = srng.normal(0, 1, size=M.shape[1], name="X")
y = srng.normal(M.dot(x), 1, name="Y")

logprob = aeppl.joint_logprob(x, y)
fn = aesara.function(logprob.inputs, logprob.outputs)
# We need to check the order of inputs before calling the
# compiled function.
#
# We don't know which are the value variables created by
# `joint_logprob`.
print(logprob.inputs)
# [x_vv, y_vv, M]

With OpFromGraph:

import aeppl
import aesara
import aesara.tensor as at

srng = at.random.RandomStream(0)
M = at.matrix("M")
x = srng.normal(0, 1, size=M.shape[1], name="X")
y = srng.normal(M.dot(x), 1, name="Y")

logprob, value_variables = aeppl.joint_logprob(x, y)
fn = aesara.function([M] + value_variables, logprob)

logprobs, (x_vv, y_vv) = aeppl.conditional_logprob(x, y)
print(logprobs[x].owner.inputs)
# [X_vv]
fn = aesara.function([x_vv], logprobs[x])

print(logprobs[y].owner.inputs)
# [M, X_vv, Y_vv]
# We can use any order for the inputs of the compiled function
fn = aesara.function([x_vv, y_vv, M], logprobs[y])

aesara.dprint(logprob)
# LogProbability{inline=True} [id A]
#  |X_vv [id B]
#  |Y_vv [id C]
#  |M [id D]

# Inner graphs:

# LogProbability{inline=True} [id A]
#  >Sum{acc_dtype=float64} [id E]
#  > |MakeVector{dtype='float64'} [id F]
#  >   |Sum{acc_dtype=float64} [id G]
#  >   | |Check{sigma > 0} [id H] 'X_logprob'
#  >   |   |Elemwise{sub,no_inplace} [id I]
#  >   |   | |Elemwise{sub,no_inplace} [id J]
#  >   |   | | |Elemwise{mul,no_inplace} [id K]
#  >   |   | | | |InplaceDimShuffle{x} [id L]
#  >   |   | | | | |TensorConstant{-0.5} [id M]
#  >   |   | | | |Elemwise{pow,no_inplace} [id N]
#  >   |   | | |   |Elemwise{true_div,no_inplace} [id O]
#  >   |   | | |   | |Elemwise{sub,no_inplace} [id P]
#  >   |   | | |   | | |*0-<TensorType(float64, (None,))> [id Q]
#  >   |   | | |   | | |InplaceDimShuffle{x} [id R]
#  >   |   | | |   | |   |TensorConstant{0} [id S]
#  >   |   | | |   | |InplaceDimShuffle{x} [id T]
#  >   |   | | |   |   |TensorConstant{1} [id U]
#  >   |   | | |   |InplaceDimShuffle{x} [id V]
#  >   |   | | |     |TensorConstant{2} [id W]
#  >   |   | | |InplaceDimShuffle{x} [id X]
#  >   |   | |   |Elemwise{log,no_inplace} [id Y]
#  >   |   | |     |Elemwise{sqrt,no_inplace} [id Z]
#  >   |   | |       |TensorConstant{6.283185307179586} [id BA]
#  >   |   | |InplaceDimShuffle{x} [id BB]
#  >   |   |   |Elemwise{log,no_inplace} [id BC]
#  >   |   |     |TensorConstant{1} [id U]
#  >   |   |All [id BD]
#  >   |     |Elemwise{gt,no_inplace} [id BE]
#  >   |       |TensorConstant{1} [id U]
#  >   |       |TensorConstant{0.0} [id BF]
#  >   |Sum{acc_dtype=float64} [id BG]
#  >     |Check{sigma > 0} [id BH] 'Y_logprob'
#  >       |Elemwise{sub,no_inplace} [id BI]
#  >       | |Elemwise{sub,no_inplace} [id BJ]
#  >       | | |Elemwise{mul,no_inplace} [id BK]
#  >       | | | |InplaceDimShuffle{x} [id BL]
#  >       | | | | |TensorConstant{-0.5} [id BM]
#  >       | | | |Elemwise{pow,no_inplace} [id BN]
#  >       | | |   |Elemwise{true_div,no_inplace} [id BO]
#  >       | | |   | |Elemwise{sub,no_inplace} [id BP]
#  >       | | |   | | |*1-<TensorType(float64, (None,))> [id BQ]
#  >       | | |   | | |dot [id BR]
#  >       | | |   | |   |*2-<TensorType(float64, (None, None))> [id BS]
#  >       | | |   | |   |*0-<TensorType(float64, (None,))> [id Q]
#  >       | | |   | |InplaceDimShuffle{x} [id BT]
#  >       | | |   |   |TensorConstant{1} [id U]
#  >       | | |   |InplaceDimShuffle{x} [id BU]
#  >       | | |     |TensorConstant{2} [id BV]
#  >       | | |InplaceDimShuffle{x} [id BW]
#  >       | |   |Elemwise{log,no_inplace} [id BX]
#  >       | |     |Elemwise{sqrt,no_inplace} [id BY]
#  >       | |       |TensorConstant{6.283185307179586} [id BZ]
#  >       | |InplaceDimShuffle{x} [id CA]
#  >       |   |Elemwise{log,no_inplace} [id CB]
#  >       |     |TensorConstant{1} [id U]
#  >       |All [id CC]
#  >         |Elemwise{gt,no_inplace} [id CD]
#  >           |TensorConstant{1} [id U]
#  >           |TensorConstant{0.0} [id CE]

The difference between the current solution and the one using OpFromGraphs mostly matters for the libraries' internals, and more specifically the rewrites. I would keep the current interface and reconsider wrapping these graphs in a separate PR, especially given that the way we encapsulate subgraphs may change in the near future: https://github.com/aesara-devs/aesara/issues/1287

rlouf commented 2 years ago

Why were the changes in "Remove dependency on specific error messages in tests" needed? It's definitely preferable to have that kind of test specificity in these cases.

It was in a separate commit to remove it later as I thought I'd be iterating on the error messages as well. Removed it.

To make all of this coherent I still have to associate transforms with the random variables as in:

import aeppl
import aesara.tensor as at
from aeppl.transforms import TransformValuesRewrite, DEFAULT_TRANSFORM

X_rv = at.random.halfnormal(0, 3, name="X")

transform_rewrite = TransformValuesRewrite({X_rv: DEFAULT_TRANSFORM})
tr_logp, (x_vv,) = aeppl.joint_logprob(X_rv, extra_rewrites=transform_rewrite)

rlouf commented 2 years ago

The tests pass, but I've noticed something while going over them that needs to be addressed:

import aesara
import aesara.tensor as at
import numpy as np

import aeppl
from aeppl.transforms import TransformValuesRewrite, DEFAULT_TRANSFORM

srng = at.random.RandomStream()
x_rv = srng.dirichlet(np.array([0.7, 0.3]))
transform_rewrite = TransformValuesRewrite({x_rv: DEFAULT_TRANSFORM})

# `test_transformed_logprob` passes if I pass this as `realized` 
x_value_var = at.tensor(x_rv.dtype, shape=(None,) * x_rv.ndim)
logprob, _ = aeppl.joint_logprob(realized={x_rv: x_value_var}, extra_rewrites=transform_rewrite)

# But it does not pass if I use the value variable produced here. The forward transform
# assumes the value variable is one-dimensional (which is expected), a shape error
# follows.
logprob, (x_value_var,) = aeppl.joint_logprob(x_rv, extra_rewrites=transform_rewrite)

The solution, it seems to me, is to provide shape (and dtype) information to the value variables created inside conditional_logprob. Which would not be a bad thing, actually.
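Concretely, the value variables could be created from the random variable's own type; a sketch of the idea:

import aesara.tensor as at
import numpy as np

srng = at.random.RandomStream(0)
x_rv = srng.dirichlet(np.array([0.7, 0.3]), name="x")

# Calling the random variable's type creates a new variable with the same dtype and
# static shape information.
x_vv = x_rv.type()
x_vv.name = "x_vv"
print(x_vv.type == x_rv.type)
# True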

brandonwillard commented 2 years ago

The solution, it seems to me, is to provide shape (and dtype) information to the value variables created inside conditional_logprob. Which would not be a bad thing, actually.

What are the TensorTypes in each case?

rlouf commented 2 years ago

In the first case it is TensorType(float64, (None, None)), and in the second TensorType(float64, (2, 2)).

The issue comes from the fact that the x_value_var passed in the current interface must be in the transformed space, thus of the same type as transform.forward(x_rv). In this PR we initialize it with the same type as x_rv, so I need to change that. The difference only shows up for SimplexTransform.

rlouf commented 2 years ago

I fixed this by applying the forward transform to the copy of the corresponding random variable when initializing the value variables. This complicates the internals slightly but makes the interface less ambiguous.
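In rough terms the initialization now looks something like this (illustrative only; SimplexTransform stands in for whatever transform is associated with the variable, and is assumed to be importable from aeppl.transforms):

import aesara.tensor as at
import numpy as np

from aeppl.transforms import SimplexTransform

srng = at.random.RandomStream(0)
x_rv = srng.dirichlet(np.array([0.7, 0.3]), name="x")

transform = SimplexTransform()
# Apply the forward transform to a copy of the random variable and create the value
# variable from the transformed variable's type, so it lives in the transformed space.
transformed = transform.forward(x_rv.clone())
x_vv = transformed.type()
x_vv.name = "x_vv_trans"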

If anything, this shows that the changes proposed in #195 and #119 are necessary. The initialization would not be so complicated if, instead of passing a TransformValuesRewrite instance, we could condition directly on the transformed variable, as in the code below:

import aeppl
import aesara.tensor as at

x_rv = at.random.normal(0, 1)
y = at.exp(x_rv)

logprob, (y_vv,) = aeppl.joint_logprob(y)