aesara-devs / aeppl

Tools for an Aesara-based PPL.
https://aeppl.readthedocs.io
MIT License

`ValuedVariable` still present in the logdensity graph #222

Closed: rlouf closed this issue 1 year ago.

rlouf commented 1 year ago

The following model still has a `ValuedVariable` in its log-density graph:

```python
import aeppl
import aesara
import aesara.tensor as at

srng = at.random.RandomStream(0)

X_at = at.matrix("X")
tau_rv = srng.halfcauchy(1)
beta_rv = srng.normal(0, tau_rv, size=X_at.shape[-1])

eta = X_at @ beta_rv
p = at.sigmoid(-eta)
Y_rv = srng.bernoulli(p)

logdensity, vvs = aeppl.joint_logprob(Y_rv, beta_rv, tau_rv)
```
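For reference, here is a sketch of the joint log-density this call is meant to build, written purely in terms of the value variables y, β and τ. Here σ is the logistic sigmoid, x_i is the i-th row of `X`, τ is the normal's scale (its standard deviation), and p_HC is the half-Cauchy density with whatever parameters `srng.halfcauchy(1)` sets:

```latex
\log p(y, \beta, \tau)
  = \sum_i \Big[ y_i \log \sigma(-x_i^\top \beta)
      + (1 - y_i) \log\big(1 - \sigma(-x_i^\top \beta)\big) \Big]
  + \sum_j \log \mathcal{N}\big(\beta_j \mid 0, \tau^2\big)
  + \log p_{\mathrm{HC}}(\tau)
```

In particular, the Bernoulli term should depend on the value β, not on `beta_rv` itself. A random-variable node that is still reachable from `logdensity` is therefore a sign that something was not replaced.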

Indeed, printing the graph shows it:

```python
aesara.dprint(logdensity)
# Sum{acc_dtype=float64} [id A]
#  |MakeVector{dtype='float64'} [id B]
#    |Sum{acc_dtype=float64} [id C]
#    | |Check{0 <= p <= 1} [id D]
#    |   |Elemwise{switch,no_inplace} [id E]
#    |   | |Elemwise{and_,no_inplace} [id F]
#    |   | | |Elemwise{le,no_inplace} [id G]
#    |   | | | |InplaceDimShuffle{x} [id H]
#    |   | | | | |TensorConstant{0} [id I]
#    |   | | | |<TensorType(int64, (?,))> [id J]
#    |   | | |Elemwise{le,no_inplace} [id K]
#    |   | |   |<TensorType(int64, (?,))> [id J]
#    |   | |   |InplaceDimShuffle{x} [id L]
#    |   | |     |TensorConstant{1} [id M]
#    |   | |Elemwise{switch,no_inplace} [id N]
#    |   | | |<TensorType(int64, (?,))> [id J]
#    |   | | |Elemwise{log,no_inplace} [id O]
#    |   | | | |Elemwise{sigmoid,no_inplace} [id P]
#    |   | | |   |Elemwise{mul,no_inplace} [id Q]
#    |   | | |     |TensorConstant{(1,) of -1.0} [id R]
#    |   | | |     |dot [id S]
#    |   | | |       |X [id T]
#    |   | | |       |ValuedVariable [id U]
#    |   | | |         |normal_rv{0, (0, 0), floatX, False}.1 [id V]
#    |   | | |         | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F9742B3A3C0>) [id W]
#    |   | | |         | |InplaceDimShuffle{x} [id X]
#    |   | | |         | | |Shape_i{1} [id Y]
#    |   | | |         | |   |X [id T]
#    |   | | |         | |TensorConstant{11} [id Z]
#    |   | | |         | |TensorConstant{0} [id BA]
#    |   | | |         | |halfcauchy_rv{0, (0, 0), floatX, False}.1 [id BB]
#    |   | | |         |   |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F9742D27BA0>) [id BC]
#    |   | | |         |   |TensorConstant{[]} [id BD]
#    |   | | |         |   |TensorConstant{11} [id Z]
#    |   | | |         |   |TensorConstant{1} [id BE]
#    |   | | |         |   |TensorConstant{1.0} [id BF]
#    |   | | |         |<TensorType(float64, (?,))> [id BG]
#    |   | | |Elemwise{log,no_inplace} [id BH]
#    |   | |   |Elemwise{sub,no_inplace} [id BI]
#    |   | |     |InplaceDimShuffle{x} [id BJ]
#    |   | |     | |TensorConstant{1.0} [id BK]
#    |   | |     |Elemwise{sigmoid,no_inplace} [id P]
#    |   | |InplaceDimShuffle{x} [id BL]
#    |   |   |TensorConstant{-inf} [id BM]
#    |   |All [id BN]
#    |   | |Elemwise{le,no_inplace} [id BO]
#    |   |   |InplaceDimShuffle{x} [id BP]
#    |   |   | |TensorConstant{0.0} [id BQ]
#    |   |   |Elemwise{sigmoid,no_inplace} [id P]
#    |   |All [id BR]
#    |     |Elemwise{le,no_inplace} [id BS]
#    |       |Elemwise{sigmoid,no_inplace} [id P]
#    |       |InplaceDimShuffle{x} [id BT]
#    |         |TensorConstant{1.0} [id BU]
#    |Sum{acc_dtype=float64} [id BV]
#    | |Check{sigma > 0} [id BW]
#    |   |Elemwise{sub,no_inplace} [id BX]
#    |   | |Elemwise{sub,no_inplace} [id BY]
#    |   | | |Elemwise{mul,no_inplace} [id BZ]
#    |   | | | |InplaceDimShuffle{x} [id CA]
#    |   | | | | |TensorConstant{-0.5} [id CB]
#    |   | | | |Elemwise{pow,no_inplace} [id CC]
#    |   | | |   |Elemwise{true_div,no_inplace} [id CD]
#    |   | | |   | |Elemwise{sub,no_inplace} [id CE]
#    |   | | |   | | |<TensorType(float64, (?,))> [id BG]
#    |   | | |   | | |InplaceDimShuffle{x} [id CF]
#    |   | | |   | |   |TensorConstant{0} [id BA]
#    |   | | |   | |InplaceDimShuffle{x} [id CG]
#    |   | | |   |   |<TensorType(float64, ())> [id CH]
#    |   | | |   |InplaceDimShuffle{x} [id CI]
#    |   | | |     |TensorConstant{2} [id CJ]
#    |   | | |InplaceDimShuffle{x} [id CK]
#    |   | |   |Elemwise{log,no_inplace} [id CL]
#    |   | |     |Elemwise{sqrt,no_inplace} [id CM]
#    |   | |       |TensorConstant{6.283185307179586} [id CN]
#    |   | |InplaceDimShuffle{x} [id CO]
#    |   |   |Elemwise{log,no_inplace} [id CP]
#    |   |     |<TensorType(float64, ())> [id CH]
#    |   |All [id CQ]
#    |     |Elemwise{gt,no_inplace} [id CR]
#    |       |<TensorType(float64, ())> [id CH]
#    |       |TensorConstant{0.0} [id CS]
#    |Sum{acc_dtype=float64} [id CT]
#      |Elemwise{switch,no_inplace} [id CU]
#        |Elemwise{ge,no_inplace} [id CV]
#        | |<TensorType(float64, ())> [id CH]
#        | |TensorConstant{1} [id BE]
#        |Elemwise{add,no_inplace} [id CW]
#        | |Elemwise{log,no_inplace} [id CX]
#        | | |TensorConstant{2} [id CY]
#        | |Check{beta > 0} [id CZ]
#        |   |Elemwise{sub,no_inplace} [id DA]
#        |   | |Elemwise{sub,no_inplace} [id DB]
#        |   | | |Elemwise{neg,no_inplace} [id DC]
#        |   | | | |Elemwise{log,no_inplace} [id DD]
#        |   | | |   |TensorConstant{3.141592653589793} [id DE]
#        |   | | |Elemwise{log,no_inplace} [id DF]
#        |   | |   |TensorConstant{1.0} [id BF]
#        |   | |Elemwise{log1p,no_inplace} [id DG]
#        |   |   |Elemwise{pow,no_inplace} [id DH]
#        |   |     |Elemwise{true_div,no_inplace} [id DI]
#        |   |     | |Elemwise{sub,no_inplace} [id DJ]
#        |   |     | | |<TensorType(float64, ())> [id CH]
#        |   |     | | |TensorConstant{1} [id BE]
#        |   |     | |TensorConstant{1.0} [id BF]
#        |   |     |TensorConstant{2} [id DK]
#        |   |All [id DL]
#        |     |Elemwise{gt,no_inplace} [id DM]
#        |       |TensorConstant{1.0} [id BF]
#        |       |TensorConstant{0.0} [id DN]
#        |TensorConstant{-inf} [id DO]
```

Not sure if it helps, but there is no `ValuedVariable` in the graph of the following model, which drops `tau_rv` and gives `beta_rv` a fixed scale of `1.`:

```python
import aeppl
import aesara
import aesara.tensor as at

srng = at.random.RandomStream(0)

X_at = at.matrix("X")
beta_rv = srng.normal(0, 1., size=X_at.shape[-1])

eta = X_at @ beta_rv
p = at.sigmoid(-eta)
Y_rv = srng.bernoulli(p)

logdensity, vvs = aeppl.joint_logprob(Y_rv, beta_rv)
```

brandonwillard commented 1 year ago

Ah, yeah, this was something I wanted to address in the original PR: instead of the code at https://github.com/aesara-devs/aeppl/blob/7890228715320c2d1ff6989ff522829b5d6411ba/aeppl/joint_logprob.py#L223-L226, we need a simple rewrite that universally removes the `ValuedVariable`s.
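Here is a minimal sketch of what such a universal rewrite could do. It assumes that each `ValuedVariable` node takes `(random_variable, value_variable)` as inputs (the order shown in the `dprint` output in the issue) and that the log-density should only ever see the value. The rewrite walks the graph once, replaces every `ValuedVariable` with its value input, and rebuilds only the nodes above a replacement. The `Node` class, `remove_valued_variables` and `op_names` are hypothetical plain-Python stand-ins, not aeppl's implementation. In aeppl this would presumably be an Aesara graph rewrite applied to the final log-density graph.

```python
from dataclasses import dataclass
from typing import Dict, Set, Tuple


@dataclass(eq=False)  # identity semantics, like graph variables
class Node:
    """Hypothetical stand-in for a graph variable: an op name plus inputs."""

    op: str
    inputs: Tuple["Node", ...] = ()
    name: str = ""


def remove_valued_variables(root: Node) -> Node:
    """Return a graph in which every ValuedVariable(rv, value) is replaced
    by `value`; nodes with no ValuedVariable below them are reused as-is."""
    memo: Dict[int, Node] = {}

    def rewrite(node: Node) -> Node:
        if id(node) in memo:
            return memo[id(node)]
        if node.op == "ValuedVariable":
            # Assumed input order: (random_variable, value_variable).
            result = rewrite(node.inputs[1])
        else:
            new_inputs = tuple(rewrite(inp) for inp in node.inputs)
            if all(new is old for new, old in zip(new_inputs, node.inputs)):
                result = node  # nothing changed underneath; reuse the node
            else:
                result = Node(node.op, new_inputs, node.name)
        memo[id(node)] = result
        return result

    return rewrite(root)


def op_names(root: Node) -> Set[str]:
    """Collect the op names of every node reachable from `root`."""
    names, stack, seen = set(), [root], set()
    while stack:
        node = stack.pop()
        if id(node) not in seen:
            seen.add(id(node))
            names.add(node.op)
            stack.extend(node.inputs)
    return names


# Toy graph shaped like the one in the issue.
X = Node("input", name="X")
tau_value = Node("input", name="tau_value")
beta_value = Node("input", name="beta_value")
beta_rv = Node("normal_rv", (Node("halfcauchy_rv", name="tau_rv"),), name="beta_rv")
dot = Node("dot", (X, Node("ValuedVariable", (beta_rv, beta_value))))
normal_term = Node("normal_logpdf", (beta_value, tau_value))
logdensity = Node("sum", (Node("bernoulli_logpmf", (dot,)), normal_term))

cleaned = remove_valued_variables(logdensity)
print(sorted(op_names(cleaned)))
# ['bernoulli_logpmf', 'dot', 'input', 'normal_logpdf', 'sum']
```

Because the pass rebuilds only the nodes that sit above a replacement, subgraphs such as `normal_term` that never touched a `ValuedVariable` are shared with the original graph. The random-variable nodes drop out of the result entirely, since nothing references them once their `ValuedVariable` wrappers are gone.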