Since I have been refactoring the tests for distributions, I thought I'd chime in. I think the `batch_log_pdf` methods are overloaded to also implement the correct vector scaling for `ScalePoutine`. I wonder if we could have a single function for computing the pdf (which operates on tensors and basically does what `batch_log_pdf` does, minus the scaling), and a separate utility function that scales the result appropriately for the Poutines. @jpchen, @eb8680 are much more familiar with this and can comment on whether this is feasible.
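For concreteness, a minimal sketch of that split (the `scale_log_pdf` helper is hypothetical, not an existing Pyro function):

```python
import torch

def scale_log_pdf(log_pdf_vec, scale):
    # hypothetical utility: apply ScalePoutine's weight outside the
    # distribution, so the density code itself stays scaling-free
    return log_pdf_vec * scale

# the distribution would compute only the raw vectorized density...
raw = torch.tensor([-1.2, -0.7, -3.1])  # stand-in for a batch_log_pdf result
# ...and the poutine layer would rescale it, e.g. for map_data subsampling
scaled = scale_log_pdf(raw, scale=100.0 / 3.0)
```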
it does more than that. essentially `batch_log_pdf` is as if you ran `log_pdf` in parallel across `dim=1` of the tensor. The reason it's not implemented for every distribution is that there's some hairiness with higher-dimensional Choleskys in parallel (e.g. in `Normal`).
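concretely, the relationship is roughly this (a toy diagonal-Normal sketch, not pyro's real classes):

```python
import math
import torch

def log_pdf(x, mu, sigma):
    # scalar log density of one sample under a diagonal Normal
    return (-0.5 * ((x - mu) / sigma) ** 2
            - torch.log(sigma) - 0.5 * math.log(2 * math.pi)).sum()

def batch_log_pdf(x, mu, sigma):
    # the same computation run "in parallel" across the batch dimension:
    # one value per row, without a python-level loop
    return (-0.5 * ((x - mu) / sigma) ** 2
            - torch.log(sigma) - 0.5 * math.log(2 * math.pi)).sum(-1)

x, mu, sigma = torch.randn(5, 2), torch.zeros(2), torch.ones(2)
looped = torch.stack([log_pdf(row, mu, sigma) for row in x])
assert torch.allclose(looped, batch_log_pdf(x, mu, sigma))
```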
To add to @jpchen's explanation, `batch_log_pdf` is necessary for, among other things, gradient estimators that take advantage of vectorized `map_data`.
Rather than a generic implementation of `batch_log_pdf`, we should probably have a generic implementation of `log_pdf` which sums along the batch dimension. As @jpchen says, some distributions do have wrinkles, but we can just override the generic implementations in that case.
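Roughly (a sketch, not the actual Distribution base class):

```python
class Distribution:
    def batch_log_pdf(self, x):
        # vectorized log density, one value per batch element;
        # each subclass supplies this
        raise NotImplementedError

    def log_pdf(self, x):
        # generic default: sum the vectorized result over the batch
        # dimension; distributions with wrinkles override this
        return self.batch_log_pdf(x).sum()
```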
This is a separate issue, but regarding `Normal`, batched Cholesky will probably be very useful to us at some point for fast GP implementations. Maybe we should open a PyTorch issue about it?
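For reference, newer PyTorch versions did eventually add batched Cholesky support; a minimal sketch of what that enables, assuming a version with `torch.linalg`:

```python
import torch

# a batch of 8 random SPD matrices, shape (8, 5, 5)
A = torch.randn(8, 5, 5)
spd = A @ A.transpose(-1, -2) + 1e-3 * torch.eye(5)
# one call factors the whole batch: L has shape (8, 5, 5), spd = L @ L^T
L = torch.linalg.cholesky(spd)
```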
+1. It looks like most distributions share a lot of code except for the dimension to be summed and this scaling logic. I am sure there will be hairiness with some distributions, but it would be great to consolidate that and expose a single interface like `scipy.stats`.
```python
# shape juggling inside a distribution's batch_log_pdf to support
# both single samples and batches
if x.dim() == 1 and _mu.dim() == 1 and batch_size == 1:
    return self.log_pdf(x, _mu, _sigma)
elif x.dim() == 1:
    x = x.expand(batch_size, x.size(0))
```
> Rather than a generic implementation of `batch_log_pdf`, we should probably have a generic implementation of `log_pdf` which sums along the batch dimension.
An alternative option is to have all methods return the vectorized output and require the summation downstream if that's desired. This avoids explicitly implementing a `log_pdf` vs `batch_log_pdf` pair for functions like `log_pdf`, `pdf`, `log_cdf`, `cdf`, `log_survival_function`, etc.
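A sketch of that convention in PyTorch terms, using `torch.distributions` as a stand-in for Pyro's classes:

```python
import torch
from torch.distributions import Normal

q = Normal(torch.zeros(10), torch.ones(10))
z = q.rsample()
vec = q.log_prob(z)   # stays vectorized: shape (10,)
# summation happens downstream, in the inference code, only if desired
loss = -vec.sum()
```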
Good point. I think the original pyro devs intend `batch_log_pdf` to behave as you describe. I like the idea of batching within `log_pdf`, say via an extra `batch_axis` or `axis` arg, rather than via a separate function. I'm curious: in Edward did you find a good consistent convention for which axes to batch over, or do you need to specify an `axis` arg?
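E.g., something along these lines (a hypothetical signature, not an existing Pyro API):

```python
class Distribution:
    def _unreduced_log_pdf(self, x):
        # per-element log density; supplied by each subclass (hypothetical)
        raise NotImplementedError

    def log_pdf(self, x, batch_axis=None):
        lp = self._unreduced_log_pdf(x)
        if batch_axis is None:
            return lp.sum()  # old log_pdf behavior: a scalar
        # keep the batch axis, sum over everything else (old batch_log_pdf)
        other = tuple(d for d in range(lp.dim()) if d != batch_axis)
        return lp.sum(dim=other) if other else lp
```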
Our decision for the convention was ultimately not to have a convention. :) That is, by requiring summation downstream, I mean the log prob is always returned vectorized. In inference implementations such as `build_reparam_loss_and_gradients`, we have lines like `q_log_prob += tf.reduce_sum(qz.log_prob(qz_sample))`. I've personally found this design to have the most flexibility and the least cognitive overhead.

We do assume `batch_shape` is the outer-most (left-most) dimensions of the input to `log_prob`, if that's what you mean.
this is possible now, especially with pytorch's new broadcasting support. in the past, 1d vectors and 2d vectors were not broadcastable, so identical operations would require two separate implementations involving a series of `expand`s and `squeeze`s.
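for example (a minimal illustration of the broadcasting behavior):

```python
import torch

mu = torch.zeros(3)    # 1d parameter
x = torch.randn(4, 3)  # 2d batch of samples
# broadcasting aligns the trailing dims automatically; previously this
# required mu.expand(4, 3) (and matching squeezes on the way back out)
diff = x - mu          # shape (4, 3)
```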
@fritzo what dustin is saying is that the summing over the batch dimension is handled by the downstream inference algorithms, which consume the vector that `log_pdf` would return, so the distribution doesn't know about the batches. of course this means that the logic for handling the right batch dimension is put into each inference implementation, which i feel might have its own disadvantages... we can also make the explicit assumption that `batch_dim = 0` (as we do now, and as edward does).
great point. seems like the answer is to figure out distribution shapes a la #153, and then the correct `log_pdf` method follows? (i.e. stay vectorized over the repeated dims, and have the objectives be aware of this.)
Could we provide a default implementation of `batch_log_pdf` as a simple for loop? Or do we want to instead implement correct handling of `NotImplementedError`s everywhere `batch_log_pdf` is used?

Disclaimer: I don't understand what `batch_log_pdf` does, and there is no docstring.

Edited to not sum the result.
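A default along those lines could be as simple as this sketch (real distributions would override it with a vectorized version):

```python
import torch

def default_batch_log_pdf(dist, x):
    # fallback: call log_pdf once per batch element and stack the results
    # (unsummed, per the edit above); correct but slow compared to a
    # vectorized override
    return torch.stack([dist.log_pdf(xi) for xi in x])
```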