adamhaber opened this issue 5 years ago
Thanks for raising the issue! I agree completely - I have been discussing this with some other users who see a similar need (e.g., https://twitter.com/ML_deep/status/1188387178507694081?s=20) @csuter @brianwa84 @jvdillon
I'd love it. Should work with all flavors of JD, ideally. I know it's just sugar around a lambda, but it spells out intent and makes code read more nicely.
Fwiw, I strongly urge against this sugar. (I'll even beg if it helps.) To me, the connection between closures and unnormalized densities is the second most elegant part of TFP.
I also disagree that it makes code more readable. While the lambda spells out clearly how the thing works, the cond_lp seems to me only to obfuscate. Also, the created object is not a distribution, since the density is unnormalized. Finally, the simplified version isn't really a fair comparison, since the cast and expand_dims would need to be done there too. A better comparison would be:
lp = lambda *x: model.log_prob(x + (df['y'],))
and that reads purdy darn nicely to me!
I'm happy to go down a long list of other reasons, but the tl;dr is that I claim this sugar only feels like it solves a problem: it actually adds cognitive burden (yet another thing to learn), runs the risk of making a user think it's required, and obfuscates what is otherwise a one-liner. If we applied KonMari to software design, I claim we'd be quite happy with lambdas.
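To make the comparison concrete, here is a self-contained toy version of that one-liner (the model and the observed data are invented purely for illustration):

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

x_obs = np.linspace(-1., 1., 20).astype(np.float32)
y_obs = (2. * x_obs + 0.5).astype(np.float32)

# Toy regression; the last node (`y`) is the one we observe.
model = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),                  # w
    tfd.Normal(loc=0., scale=1.),                  # b
    lambda b, w: tfd.Independent(                  # y | w, b
        tfd.Normal(loc=w * x_obs + b, scale=1.),
        reinterpreted_batch_ndims=1),
])

# The closure: an unnormalized log-posterior over (w, b), with y pinned.
lp = lambda *latents: model.log_prob(latents + (y_obs,))
lp(tf.constant(0.), tf.constant(0.))  # ==> scalar, usable as a target_log_prob_fn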
If the only thing this sugar does is construct a conditioned unnormalized log-probability function, it is not super useful. However, if it simultaneously handles other properties of the distribution (e.g. the shapes, dtypes, etc.), then it becomes more compelling. I've had good experience using something I call JointDistributionPosterior, which takes the conditioning as a constructor arg and produces a distribution-like object, e.g.:
jd = JointDistribution(...)
jd.event_shape == [(1, 1), (2,), (3,)]
jd.dtype == [tf.int32, tf.float32, tf.float64]

jdp = JointDistributionPosterior(jd, conditioning=(None, tf.zeros([2]), None))
jdp.event_shape == [(1, 1), (3,)]
jdp.dtype == [tf.int32, tf.float64]
jdp.unnormalized_log_prob(
    tf.nest.map_structure(lambda s, d: tf.zeros(s, dtype=d),
                          jdp.event_shape, jdp.dtype)) == tf.Tensor
It's very easy to write something like that yourself even if we never add it to TFP.
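As a rough sketch of "write it yourself" (the names mirror the example above; the implementation details are only a guess at the intended semantics, not an existing TFP class):

class JointDistributionPosterior(object):
  """A JointDistribution plus partial conditioning values.

  `conditioning` matches the JD's parts: None for each free (latent) part,
  and a value for each observed part.
  """

  def __init__(self, joint_dist, conditioning):
    self._jd = joint_dist
    self._conditioning = tuple(conditioning)

  @property
  def event_shape(self):
    return [s for s, c in zip(self._jd.event_shape, self._conditioning)
            if c is None]

  @property
  def dtype(self):
    return [d for d, c in zip(self._jd.dtype, self._conditioning)
            if c is None]

  def unnormalized_log_prob(self, free_values):
    free_values = list(free_values)
    # Re-interleave the free values with the pinned (observed) values, in the
    # order the underlying joint distribution expects.
    full_values = [free_values.pop(0) if c is None else c
                   for c in self._conditioning]
    return self._jd.log_prob(full_values)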
I think the current problem is that, to construct an (unnormalized) conditional posterior for inference, the APIs are quite inconsistent across the different JD* flavors, and they also require the user to understand the call signature:
# JointDistributionSequential
init_state = [var1, var2, ...]  # <== a list of tensors

lp = lambda x: mdl_jdseq.log_prob(x + [observed])
lp(init_state)  # <== this works, but not when you plug it into mcmc.sample,
                #     which means the user gets an error (much) later downstream.

lp = lambda *x: mdl_jdseq.log_prob(x + (observed,))
lp = lambda *x: mdl_jdseq.log_prob(list(x) + [observed])
# ^ Another alternative, arguably the "right" version, since mdl_jdseq.sample([...])
#   returns a list; by the contract of jd.log_prob(jd.sample([...])), the input
#   to jd.log_prob should also be a list.
lp(*init_state)

# JointDistributionNamed
# Not sure what the best practice is here, as there are many ways to construct
# a dict-like object for mdl_jdname.log_prob - nonetheless, additional user
# input is needed.
import collections
Model = collections.namedtuple('Model', [...])
lp = lambda *x: mdl_jdname.log_prob(Model(*x, observed))
lp(*init_state)

# JointDistributionCoroutine
lp = lambda x: mdl_jdcoroutine.log_prob(x + [observed])
lp(init_state)  # <== same caveat: works here, but not when plugged into
                #     mcmc.sample, so the user gets an error (much) later.

lp = lambda *x: mdl_jdcoroutine.log_prob(x + (observed,))  # <== the canonical version,
                                                           #     as coroutine JD samples are tuples
lp = lambda *x: mdl_jdcoroutine.log_prob(list(x) + [observed])
lp(*init_state)
Syntactic sugar would make sure this is consistent across all the JD* flavors.
If we don't introduce an additional API for this, we should definitely make it clearer in the docstrings and documentation.
BTW, all the code above is based on the assumption that the last node(s) are the observed ones.
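For what it's worth, a minimal sketch of what such sugar could look like for the list/tuple-flavored JDs, under the same last-node-is-observed assumption (the helper and its name are made up, not an existing TFP API):

def make_unnormalized_log_prob(jd, *observed):
  """Hypothetical sugar: pin the trailing observed node(s) of `jd`.

  Returns f(*latents) == jd.log_prob(list(latents) + list(observed)), i.e.
  an unnormalized log-posterior over the remaining (latent) nodes. Assumes
  the observed node(s) come last, as in the snippets above; the Named flavor
  would additionally need the part names.
  """
  def target_log_prob(*latents):
    return jd.log_prob(list(latents) + list(observed))
  return target_log_prob

# Usage, matching the snippets above:
#   lp = make_unnormalized_log_prob(mdl_jdseq, observed)
#   lp(*init_state)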
+1 to better docstrings. I agree there's a learning curve here, but I feel this learning curve is "worth it" since the current approach ensures the unnormalized posterior is merely a thin accessor to the full joint (this being the inferential base). Furthermore, by not codifying this accessor we emphasize that all downstream inference logic is agnostic--any function will suffice.
As for the different call styles, I see this difference as one of the key points of having different JD flavors. The reason for the current style is that we wanted to preserve the d.log_prob(d.sample()) pattern yet also have d.sample() be interpretable wrt the model as supplied to __init__. If it turns out this difference is more pain than benefit, I'd rather see us change the JointDistribution than build new sugar on top.
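To spell out that contract with a toy model (the model is invented for illustration):

import tensorflow_probability as tfp

tfd = tfp.distributions

d = tfd.JointDistributionSequential([
    tfd.Normal(0., 1.),            # z
    lambda z: tfd.Normal(z, 1.),   # y | z
])

# The invariant the current design preserves: whatever structure d.sample()
# returns (a list here; a tuple for the Coroutine flavor, a dict for the
# Named flavor) is exactly what d.log_prob() accepts.
x = d.sample(seed=1)   # ==> a list of two scalar Tensors, [z, y]
d.log_prob(x)          # ==> a scalar Tensor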
In many TFP Bayesian use cases, it's very helpful to specify a joint distribution using a JointDistribution* object - it makes sampling (for prior-predictive checks) straightforward (sorry), and exposes a log_prob function necessary for the sampler. However, since many (most?) of these cases involve some sort of conditioning, we end up writing a function closure, which is very confusing and possibly error prone (mostly shape errors, but also type errors). Exposing some sort of conditioning method, instead, could be amazing.
For example, for a JointDistributionSequential (which is represented by a list), perhaps something along these lines?
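Purely as an illustration of the kind of sugar meant here (condition_on is a made-up name, not an existing TFP method, and the toy model is invented as well):

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def condition_on(jd, *observed):
  # Hypothetical sugar: pin the trailing node(s) and return an unnormalized
  # log-density over the remaining (latent) nodes.
  return lambda *latents: jd.log_prob(latents + observed)

x = np.linspace(-1., 1., 10).astype(np.float32)
model = tfd.JointDistributionSequential([
    tfd.Normal(0., 1.),                                # slope
    lambda m: tfd.Independent(tfd.Normal(m * x, 1.),   # y | slope
                              reinterpreted_batch_ndims=1),
])
y_obs = 0.3 * x

cond_lp = condition_on(model, y_obs)
cond_lp(tf.constant(0.1))  # ==> scalar, ready to hand to tfp.mcmc as a target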
Thanks in advance!