Codecov report: base coverage 95.18%, head coverage 95.10% (a 0.08% decrease). Coverage data is based on head (ab43194) compared to base (54a5888). Patch coverage: 100.00% of modified lines in the pull request are covered. Note: the current head ab43194 differs from the pull request's most recent head 3951ce5; consider uploading reports for commit 3951ce5 to get more accurate results.
I think we should also make it as easy as possible to compile a logprob graph, as long as the interface remains as unambiguous and flexible as the current one. Let's consider a typical model to illustrate the proposal:
import aesara
import aesara.tensor as at

col_1 = at.as_tensor([1., 2., 3., 5., 1])
col_2 = at.as_tensor([1., 2., 3., 5., 1])

srng = at.random.RandomStream(3)
intercept_rv = srng.normal(-4, 1., name="intercept")
beta_1_rv = srng.normal(0, 2.5, name="beta_1")
beta_2_rv = srng.normal(0, 2.5, name="beta_2")

likelihood = at.sigmoid(
    intercept_rv
    + beta_1_rv * col_1
    + beta_2_rv * col_2
)
logit_rv = srng.bernoulli(likelihood, name="logit")
To compile the `joint_logprob` graph, one currently needs to do the following:
import aeppl

intercept_vv = intercept_rv.clone()
beta_1_vv = beta_1_rv.clone()
beta_2_vv = beta_2_rv.clone()
logit_vv = logit_rv.clone()

values = [intercept_vv, beta_1_vv, beta_2_vv, logit_vv]
logprob = aeppl.joint_logprob({
    intercept_rv: intercept_vv,
    beta_1_rv: beta_1_vv,
    beta_2_rv: beta_2_vv,
    logit_rv: logit_vv,
})
logprob_fn = aesara.function(values, logprob)
Having to create the value variables manually is cumbersome. We can easily improve this by making `joint_logprob` create the value variables and return them as a list. Compilation is now straightforward, without any loss of generality:
import aeppl

logprob, values = aeppl.joint_logprob(
    intercept_rv,
    beta_1_rv,
    beta_2_rv,
    logit_rv,
)
print(values)
# [intercept_vv, beta_1_vv, beta_2_vv, logit_vv]

logprob_fn = aesara.function(values, logprob)
For `conditional_logprob` we need a bit more flexibility, since we only need to specify the parents of a given variable when compiling its conditional log-probability, so I suggest returning the value variables in a map indexed by the corresponding random variables:
import aeppl

logprobs, values = aeppl.conditional_logprob(
    intercept_rv,
    beta_1_rv,
    beta_2_rv,
    logit_rv,
)
print(logprobs)
# {intercept_rv: intercept_logprob, beta_1_rv: beta_1_logprob, beta_2_rv: beta_2_logprob, logit_rv: logit_logprob}
print(values)
# {intercept_rv: intercept_vv, beta_1_rv: beta_1_vv, beta_2_rv: beta_2_vv, logit_rv: logit_vv}

logprob_intercept_fn = aesara.function([values[intercept_rv]], logprobs[intercept_rv])
logprob_logit_fn = aesara.function(list(values.values()), logprobs[logit_rv])
We need to consider the scenario where the value that some RVs take is known at construction time. In this case, we know the values that `logit_rv` takes:
position = at.as_tensor([0, 1, 1, 1, 1])

logprob, values = aeppl.joint_logprob(
    intercept_rv,
    beta_1_rv,
    beta_2_rv,
    realized={logit_rv: position},
)
logprob_fn = aesara.function(values, logprob)
This is a good middle ground between the explicit but cumbersome `conditional_logprob({sigma2_rv: sigma2_vv, Y_rv: Y_data})` and the restrictive `joint_logprob(sigma2_rv, y_rv)`, which requires passing the data to the function at every iteration. The term `realized` is borrowed from measure theory, but we could also use the more widely adopted `observed`.
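For illustration, here is a minimal sketch of the difference under the proposed interface; `sigma2_rv`, `Y_rv`, and `Y_data` are only illustrative names, not variables defined above:

```python
# Restrictive form: the observed data must be passed at every call of the
# compiled function, even though it never changes.
logprob, (sigma2_vv, y_vv) = aeppl.joint_logprob(sigma2_rv, Y_rv)
logprob_fn = aesara.function([sigma2_vv, y_vv], logprob)
logprob_fn(1.0, Y_data)  # `Y_data` passed at every iteration

# With `realized`, the data is part of the graph and only `sigma2_vv` remains
# an input of the compiled function.
logprob, (sigma2_vv,) = aeppl.joint_logprob(sigma2_rv, realized={Y_rv: Y_data})
logprob_fn = aesara.function([sigma2_vv], logprob)
logprob_fn(1.0)
```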
Anyway, these changes are fairly simple to implement and this could be done in this PR.
Now that I have implemented this interface in the last commit, I have a little experience using it, given that many tests needed to be updated. Here are my suggestions and remarks.
First, I think that the second output of `joint_logprob` and `conditional_logprob` should be a list in both cases. It is always easier to unpack the tuple than to get elements from a dictionary:
logprob, (intercept_vv, beta_1_vv, beta_2_vv) = aeppl.joint_logprob(
    intercept_rv,
    beta_1_rv,
    beta_2_rv,
    realized={logit_rv: position},
)
This dictionary can always be built:
rvs_to_sample = [intercept_rv, beta_1_rv, beta_2_rv]
logprob, values = aeppl.joint_logprob(*rvs_to_sample, realized={logit_rv: position})
rv_to_values = {rv: vv for rv, vv in zip(rvs_to_sample, values)}
but I expect this need to arise much less often than the need for the convenience that the `list` output provides.
I don't think we should special-case the return values when a single random variable is passed to `joint_logprob` or `conditional_logprob`:
logprob, (x_vv,) = joint_logprob(x_rv)
logprobs, (x_vv,) = conditional_logprob(x_rv)
print(logprobs)
# {x_rv: x_rv_logprob}
I can currently make the transforms interface work in a roundabout way by passing the corresponding random variables as "realized":
import aesara.tensor as at

from aeppl import joint_logprob
from aeppl.transforms import TransformValuesRewrite, DEFAULT_TRANSFORM

X_rv = at.random.halfnormal(0, 3, name="X")

x_vv = X_rv.clone()
x_vv.name = "x"

transform_rewrite = TransformValuesRewrite({x_vv: DEFAULT_TRANSFORM})
tr_logp, _ = joint_logprob(
    realized={X_rv: x_vv},
    extra_rewrites=transform_rewrite,
)
I am not sure whether it is possible yet, but we could circumvent this by associating the transform with the random variables:
X_rv = at.random.halfnormal(0, 3, name="X")

transform_rewrite = TransformValuesRewrite({X_rv: DEFAULT_TRANSFORM})
tr_logp, (x_vv,) = joint_logprob(X_rv, extra_rewrites=transform_rewrite)
Which imo makes more sense anyway.
Current status:

- `test_value`s appear to not be computed in some tests, but I do not understand what change provoked this since they're all very benign. If someone well versed in test values wants to take a look, I could do with the help.
- Implement the `transform` interface described above.

> First, I think that the second output of `joint_logprob` and `conditional_logprob` should be a list in both cases. It is always easier to unpack the tuple than to get elements from a dictionary:
Before I forget to mention this again: we could always return a `FunctionGraph` object that encapsulates the results.
> I don't think we should special-case the return values when a single random variable is passed to `joint_logprob` or `conditional_logprob`:

Yeah, definitely not; that kind of stuff only ever complicates the typing and interfaces involved.
> I am not sure whether it is possible yet, but we could circumvent this by associating the transform with the random variables:

Random variables and their value variables should be one-to-one (most of the time, at least), no?
> Random variables and their value variables should be one-to-one (most of the time, at least), no?

You're right, it's more about how much time it will take than whether it is possible or not. It would also be nice in terms of the interface because, afaik, we are transforming the random variables, i.e. applying the transformation to the result of these functions.
> Before I forget to mention this again: we could always return a `FunctionGraph` object that encapsulates the results.

What would be the benefit here?
> > Before I forget to mention this again: we could always return a `FunctionGraph` object that encapsulates the results.
>
> What would be the benefit here?

It's an established Aesara collection type that connects inputs and outputs, and, since we use them at different stages of the process between all of the AePPL interface functions, it's worth considering an approach that uses them. They can also hold meta-information that a caller or the next interface function may want.
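For context, a tiny self-contained sketch of the type being discussed (a toy graph, unrelated to the logprob graphs above):

```python
import aesara.tensor as at
from aesara.graph.fg import FunctionGraph

# A `FunctionGraph` ties together a set of inputs and the outputs computed
# from them, so a caller can recover both.
a = at.scalar("a")
b = at.scalar("b")
fg = FunctionGraph(inputs=[a, b], outputs=[a + b], clone=False)
print(fg.inputs)   # [a, b]
print(fg.outputs)  # [add.0]
```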
Just to be sure, are you suggesting we return the following?
import aeppl

logprob = aeppl.joint_logprob(
    intercept_rv,
    beta_1_rv,
    beta_2_rv,
    logit_rv,
)
print(logprob.inputs)
# [intercept_vv, beta_1_vv, beta_2_vv, logit_vv]

logprob_fn = aesara.function(logprob.inputs, logprob.outputs)
And
import aeppl

logprobs = aeppl.conditional_logprob(
    intercept_rv,
    beta_1_rv,
    beta_2_rv,
    logit_rv,
)
print(logprobs)
# {
#     intercept_rv: intercept_logprob_fg,
#     beta_1_rv: beta_1_logprob_fg,
#     beta_2_rv: beta_2_logprob_fg,
#     logit_rv: logit_logprob_fg,
# }

logp_beta_1 = logprobs[beta_1_rv]
logprob_fn = aesara.function(logp_beta_1.inputs, logp_beta_1.outputs)
I don't see why not! It contains more information than the graphs we currently return, looks cleaner and simplifies compilation.
It is also immediately useful: callers currently need to know which value variables to use as inputs to, say, `beta_1_logprob`. A `FunctionGraph` would already contain that information.
We could return sampling steps in AeMCMC as `FunctionGraph`s as well.
I have modified `joint_logprob` locally so that it returns a `FunctionGraph`. After using it, I am not completely sold on the idea; to build the `FunctionGraph` we need to provide all the inputs to the logprob graphs. This undeniably makes compilation easier:
import aesara
import aesara.tensor as at
import aeppl

M = at.matrix("M")
x = at.random.normal(0, 1, size=M.shape[1], name="X")
y = at.random.normal(M.dot(x), 1, name="Y")

logp = aeppl.joint_logprob(x, y)
print(logp.inputs)
# (M, X_vv, Y_vv)

logp_fn = aesara.function(logp.inputs, logp.outputs)
But it puts pressure on the internals of downstream users, such as AeMCMC, which only need to know which value variables' values we want to update. All this for a relatively small gain over the more explicit:

logp, value_variables = aeppl.joint_logprob(x, y)
print(value_variables)
# [X_vv, Y_vv]

logp_fn = aesara.function([M] + value_variables, logp)
Of course, we could attach a `Feature` that holds the value variables and which could be used by downstream callers. This adds a small cost, which would be outweighed if, down the line, we need to target this region of the graph with logprob-specific rewrites. Still, I would hold off until that becomes useful/necessary.
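A rough sketch of what such a `Feature` could look like; the `ValueVariablesFeature` class and the `value_variables` attribute are hypothetical, not existing AePPL or Aesara API:

```python
from aesara.graph.features import Feature


class ValueVariablesFeature(Feature):
    """Hypothetical feature recording which inputs are value variables."""

    def __init__(self, value_variables):
        self.value_variables = list(value_variables)

    def on_attach(self, fgraph):
        # Expose the value variables on the `FunctionGraph` itself.
        fgraph.value_variables = self.value_variables


# `joint_logprob` could attach it to the `FunctionGraph` it returns:
# logprob_fg.attach_feature(ValueVariablesFeature([x_vv, y_vv]))
# logprob_fg.value_variables  # -> [x_vv, y_vv]
```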
An intermediate approach may be to return an `OpFromGraph` along with the value variables. It not only requires the user to be more explicit when compiling the function (and to set the order of the compiled function's parameters themselves), which is less error-prone, but it also allows retrieving the inputs necessary to compute the individual conditional logprob values via `logprob.owner.inputs`.
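A rough sketch of how this could work internally, assuming `logp_graph` is the summed log-probability graph built by `joint_logprob` and `x_vv`, `y_vv`, `M` are its inputs (names are illustrative):

```python
from aesara.compile.builders import OpFromGraph

# Wrap the log-probability graph in an `Op` whose inputs are the value
# variables (and any other graph inputs, e.g. `M`).
logprob_op = OpFromGraph([x_vv, y_vv, M], [logp_graph], inline=True)
logprob = logprob_op(x_vv, y_vv, M)

# Callers can then recover the inputs needed to compute the log-probability:
print(logprob.owner.inputs)
# [x_vv, y_vv, M]
```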
To summarize, let us look at the user interface with the three possibilities outlined above. At 23f2336:
import aeppl
import aesara
import aesara.tensor as at

srng = at.random.RandomStream(0)

M = at.matrix("M")
x = srng.normal(0, 1, size=M.shape[1], name="X")
y = srng.normal(M.dot(x), 1, name="Y")

logprob, value_variables = aeppl.joint_logprob(x, y)
fn = aesara.function([M] + value_variables, logprob)

# We have to infer from the model graph which variables
# are needed to compile the conditional logprobs, or
# use `aesara.graph.basic.graph_inputs`.
logprobs, (x_vv, y_vv) = aeppl.conditional_logprob(x, y)
fn = aesara.function([M, x_vv], logprobs[x])
fn = aesara.function([M, x_vv, y_vv], logprobs[y])
Using `FunctionGraph`s:
import aeppl
import aesara
import aesara.tensor as at
srng = at.random.RandomStream(0)
M = at.matrix("M")
x = srng.normal(0, 1, size=M.shape[1], name="X")
y = srng.normal(M.dot(x), 1, name="Y")
logprob = aeppl.joint_logprob(x, y)
fn = aesara.function(logprob.inputs, logprob.outputs)
# We need to check the order of inputs before calling the
# compiled function.
#
# We don't know which are the value variables created by
# `joint_logprob`.
print(logprob.inputs)
# [x_vv, y_vv, M]
With `OpFromGraph`:
import aeppl
import aesara
import aesara.tensor as at
srng = at.random.RandomStream(0)
M = at.matrix("M")
x = srng.normal(0, 1, size=M.shape[1], name="X")
y = srng.normal(M.dot(x), 1, name="Y")
logprob, value_variables = aeppl.joint_logprob(x, y)
fn = aesara.function([M] + value_variables, logprob)
logprobs, (x_vv, y_vv) = aeppl.conditional_logprob(x, y)
print(logprobs[x].owner.inputs)
# [X_vv]
fn = aesara.function([x_vv], logprobs[x])
print(logprobs[y].owner.inputs)
# [M, X_vv, Y_vv]
# We can use any order for the inputs of the compiled function
fn = aesara.function([x_vv, y_vv, M], logprobs[y])
aesara.dprint(logprob)
# LogProbability{inline=True} [id A]
# |X_vv [id B]
# |Y_vv [id C]
# |M [id D]
# Inner graphs:
# LogProbability{inline=True} [id A]
# >Sum{acc_dtype=float64} [id E]
# > |MakeVector{dtype='float64'} [id F]
# > |Sum{acc_dtype=float64} [id G]
# > | |Check{sigma > 0} [id H] 'X_logprob'
# > | |Elemwise{sub,no_inplace} [id I]
# > | | |Elemwise{sub,no_inplace} [id J]
# > | | | |Elemwise{mul,no_inplace} [id K]
# > | | | | |InplaceDimShuffle{x} [id L]
# > | | | | | |TensorConstant{-0.5} [id M]
# > | | | | |Elemwise{pow,no_inplace} [id N]
# > | | | | |Elemwise{true_div,no_inplace} [id O]
# > | | | | | |Elemwise{sub,no_inplace} [id P]
# > | | | | | | |*0-<TensorType(float64, (None,))> [id Q]
# > | | | | | | |InplaceDimShuffle{x} [id R]
# > | | | | | | |TensorConstant{0} [id S]
# > | | | | | |InplaceDimShuffle{x} [id T]
# > | | | | | |TensorConstant{1} [id U]
# > | | | | |InplaceDimShuffle{x} [id V]
# > | | | | |TensorConstant{2} [id W]
# > | | | |InplaceDimShuffle{x} [id X]
# > | | | |Elemwise{log,no_inplace} [id Y]
# > | | | |Elemwise{sqrt,no_inplace} [id Z]
# > | | | |TensorConstant{6.283185307179586} [id BA]
# > | | |InplaceDimShuffle{x} [id BB]
# > | | |Elemwise{log,no_inplace} [id BC]
# > | | |TensorConstant{1} [id U]
# > | |All [id BD]
# > | |Elemwise{gt,no_inplace} [id BE]
# > | |TensorConstant{1} [id U]
# > | |TensorConstant{0.0} [id BF]
# > |Sum{acc_dtype=float64} [id BG]
# > |Check{sigma > 0} [id BH] 'Y_logprob'
# > |Elemwise{sub,no_inplace} [id BI]
# > | |Elemwise{sub,no_inplace} [id BJ]
# > | | |Elemwise{mul,no_inplace} [id BK]
# > | | | |InplaceDimShuffle{x} [id BL]
# > | | | | |TensorConstant{-0.5} [id BM]
# > | | | |Elemwise{pow,no_inplace} [id BN]
# > | | | |Elemwise{true_div,no_inplace} [id BO]
# > | | | | |Elemwise{sub,no_inplace} [id BP]
# > | | | | | |*1-<TensorType(float64, (None,))> [id BQ]
# > | | | | | |dot [id BR]
# > | | | | | |*2-<TensorType(float64, (None, None))> [id BS]
# > | | | | | |*0-<TensorType(float64, (None,))> [id Q]
# > | | | | |InplaceDimShuffle{x} [id BT]
# > | | | | |TensorConstant{1} [id U]
# > | | | |InplaceDimShuffle{x} [id BU]
# > | | | |TensorConstant{2} [id BV]
# > | | |InplaceDimShuffle{x} [id BW]
# > | | |Elemwise{log,no_inplace} [id BX]
# > | | |Elemwise{sqrt,no_inplace} [id BY]
# > | | |TensorConstant{6.283185307179586} [id BZ]
# > | |InplaceDimShuffle{x} [id CA]
# > | |Elemwise{log,no_inplace} [id CB]
# > | |TensorConstant{1} [id U]
# > |All [id CC]
# > |Elemwise{gt,no_inplace} [id CD]
# > |TensorConstant{1} [id U]
# > |TensorConstant{0.0} [id CE]
The difference between the current solution and the one using `OpFromGraph`s mostly matters for the library's internals, and more specifically the rewrites. I would keep the current interface and reconsider wrapping these graphs in a separate PR, especially given that the way we encapsulate subgraphs may change in the near future: https://github.com/aesara-devs/aesara/issues/1287
Why were the changes in "Remove dependency on specific error messages in tests" needed? It's definitely preferable to have that kind of test specificity in these cases.
It was in a separate commit to remove it later as I thought I'd be iterating on the error messages as well. Removed it.
To make all of this coherent I still have to associate transforms with the random variables as in:
import aeppl
import aesara.tensor as at

from aeppl.transforms import TransformValuesRewrite, DEFAULT_TRANSFORM

X_rv = at.random.halfnormal(0, 3, name="X")

transform_rewrite = TransformValuesRewrite({X_rv: DEFAULT_TRANSFORM})
tr_logp, (x_vv,) = aeppl.joint_logprob(X_rv, extra_rewrites=transform_rewrite)
The tests pass, but I've noticed something while going over the tests that needs to be addressed:
import aesara
import aesara.tensor as at
import numpy as np

import aeppl
from aeppl.transforms import TransformValuesRewrite, DEFAULT_TRANSFORM

srng = at.random.RandomStream()
x_rv = srng.dirichlet(np.array([0.7, 0.3]))

transform_rewrite = TransformValuesRewrite({x_rv: DEFAULT_TRANSFORM})

# `test_transformed_logprob` passes if I pass this as `realized`
x_value_var = at.tensor(x_rv.dtype, shape=(None,) * x_rv.ndim)
logprob, _ = aeppl.joint_logprob(realized={x_rv: x_value_var}, extra_rewrites=transform_rewrite)

# But it does not pass if I use the value variable produced here. The forward transform
# assumes the value variable is one-dimensional (which is expected); a shape error
# follows.
logprob, (x_value_var,) = aeppl.joint_logprob(x_rv, extra_rewrites=transform_rewrite)
The solution, it seems to me, is to provide shape (and dtype) information to the value variables created inside `conditional_logprob`, which would not be a bad thing, actually.
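A minimal sketch of what that initialization could look like, assuming we derive the type directly from the random variable (not necessarily the final implementation):

```python
# Give the value variable the same dtype and static shape as the random
# variable it stands for, by instantiating the RV's own type.
x_value_var = x_rv.type()
x_value_var.name = "x"
```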
> The solution, it seems to me, is to provide shape (and dtype) information to the value variables created inside `conditional_logprob`, which would not be a bad thing, actually.

What are the `TensorType`s in each case?
In the first case `TensorType(float64, (None, None))`, and in the second `TensorType(float64, (2, 2))`.
The issue comes from the fact that the `x_value_var` passed in the current interface must be in the transformed space, thus of the same type as `transform.forward(x_rv)`. In this PR we initialize it with the same type as `x_rv`, so I need to change that. The difference only shows up for `SimplexTransform`.
I fixed this by applying the forward transform to a copy of the corresponding random variable when initializing the value variables. This complicates the internals slightly but makes the interface less ambiguous.
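Roughly, the idea is the following; a sketch, not the exact implementation, assuming `transform` is the transform associated with `x_rv`:

```python
# When a transform is associated with `x_rv`, derive the value variable's type
# from the forward-transformed variable instead of from `x_rv` itself.
x_vv = transform.forward(x_rv).type()
x_vv.name = "x"
```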
If anything, this shows that the changes proposed in #195 and #119 are necessary. The initialization would not be so complicated if, instead of passing a `TransformValuesRewrite` instance, we could condition directly on the transformed variable, as in the code below:
import aeppl
import aesara.tensor as at
x_rv = at.random.normal(0, 1)
y = at.exp(x_rv)
logprob, (y_vv,) = aeppl.joint_logprob(y)
Exploration following #193. Different approach from #180. Naming conventions follow Koller & Friedman, "Probabilistic Graphical Models: Principles and Techniques".

- We can now get a dictionary that maps the original random variables of a model to their associated conditional log-probabilities using `conditional_logprob`;
- `conditional_logprob` replaces `factorized_joint_logprob`;
- I also removed the `sum` argument to `joint_logprob`; it is unclear to me what mathematical quantity `sum=False` corresponds to;
- The choice of the general keyword `realized` instead of `observed` is intentional. The distinction shows up, for instance, when building the logprob for an MCMC step in a Metropolis-within-Gibbs sampler: one needs to condition on the values taken by the variables not targeted by the MCMC step, but we cannot really say that they're "observed" (see the sketch below).
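A hedged sketch of that Gibbs-style situation, reusing the model from the first example; `intercept_current` and `beta_2_current` are hypothetical current values held by the sampler:

```python
import aesara
import aeppl

# Log-probability used by the step that updates `beta_1_rv`: the other
# variables are fixed at their current values, but they are not "observed".
logprob, (beta_1_vv,) = aeppl.joint_logprob(
    beta_1_rv,
    realized={
        intercept_rv: intercept_current,
        beta_2_rv: beta_2_current,
        logit_rv: position,
    },
)
logprob_fn = aesara.function([beta_1_vv], logprob)
```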
Closes #162