pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Provide default implementation of batch_log_pdf #144

fritzo closed this issue 7 years ago

fritzo commented 7 years ago

Could we provide a default implementation of batch_log_pdf as a simple for loop?

import torch

class Distribution(object):
    ...
    def batch_log_pdf(self, x, batch_size, *args, **kwargs):
        # default: one log_pdf evaluation per element of the batch
        result = torch.zeros(batch_size)
        for i in range(batch_size):
            result[i] = self.log_pdf(x[i], *args, **kwargs)
        return torch.autograd.Variable(result)   # Caller decides whether to .sum().

Or do we want to instead implement correct handling of NotImplementedErrors everywhere batch_log_pdf is used?

Disclaimer: I don't understand what batch_log_pdf does, and there is no docstring.

Edited to not sum the result.

neerajprad commented 7 years ago

Since I have been refactoring the tests for distributions, I thought I'd chime in. I think batch_log_pdf is overloaded to also implement the correct vector scaling for ScalePoutine. I wonder if we could have just a single function for computing the pdf (one that operates on tensors and basically does what batch_log_pdf is doing, minus the scaling), and a separate utility function that applies the scaling for the Poutines appropriately.
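Roughly what I have in mind, with hypothetical names (scale_log_pdf is not an existing pyro function, just a sketch of the separation):

def scale_log_pdf(log_pdf_vec, scale=1.0):
    # apply the ScalePoutine-style rescaling outside the distribution, so
    # distributions themselves only compute raw per-datum log probs
    return scale * log_pdf_vec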

@jpchen, @eb8680 are much more familiar with this, and can comment on whether this is feasible.

jpchen commented 7 years ago

It does more than that. Essentially, batch_log_pdf is what you would get if you ran log_pdf in parallel across dim=1 of the tensor. The reason it's not implemented for every distribution is that there is some hairiness with higher-dimensional Choleskys in parallel (e.g. in Normal).
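For concreteness, here is roughly that relationship for a diagonal Normal, written as a sketch in current PyTorch syntax (not the actual pyro code, and it sidesteps the Cholesky issue entirely):

import math
import torch

def log_pdf(x_i, mu, sigma):
    # log density of a single d-dimensional point under an independent Normal
    return (-0.5 * ((x_i - mu) / sigma) ** 2
            - torch.log(sigma)
            - 0.5 * math.log(2 * math.pi)).sum()

def batch_log_pdf(x, mu, sigma):
    # the same arithmetic run in parallel over the rows of x, summing only
    # the event dimension so we get one log prob per row
    return (-0.5 * ((x - mu) / sigma) ** 2
            - torch.log(sigma)
            - 0.5 * math.log(2 * math.pi)).sum(-1)

# so batch_log_pdf(x, mu, sigma)[i] equals log_pdf(x[i], mu, sigma)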

eb8680 commented 7 years ago

To add to @jpchen's explanation, batch_log_pdf is necessary for, among other things, gradient estimators that take advantage of vectorized map_data.

Rather than a generic implementation of batch_log_pdf, we should probably have a generic implementation of log_pdf which sums along the batch dimension. As @jpchen says, some distributions do have wrinkles, but we can just override the generic implementations in that case.
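A minimal sketch of that default, assuming batch_log_pdf returns one log prob per datum (not the actual pyro base class):

class Distribution(object):
    def batch_log_pdf(self, x, *args, **kwargs):
        # each distribution implements this: a vector with one log prob per datum
        raise NotImplementedError

    def log_pdf(self, x, *args, **kwargs):
        # generic default: sum the per-datum log probs over the batch dimension;
        # distributions with wrinkles override this instead
        return self.batch_log_pdf(x, *args, **kwargs).sum()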

This is a separate issue, but regarding Normal, batched Cholesky will probably be very useful to us at some point for fast GP implementations. Maybe we should open a PyTorch issue about it?

neerajprad commented 7 years ago

+1. It looks like most distributions share a lot of code except for the dimension to be summed over and this scaling logic. I am sure there will be hairiness with some distributions, but it would be great to consolidate that and expose a single interface like scipy.stats. For example, shape handling like the excerpt below gets repeated:

        if x.dim() == 1 and _mu.dim() == 1 and batch_size == 1:
            return self.log_pdf(x, _mu, _sigma)
        elif x.dim() == 1:
            x = x.expand(batch_size, x.size(0))
dustinvtran commented 7 years ago

> Rather than a generic implementation of batch_log_pdf, we should probably have a generic implementation of log_pdf which sums along the batch dimension.

An alternative option is to have all methods return the vectorized output and require the summation to happen downstream, if that's desired. This avoids explicitly implementing a log_pdf vs batch_log_pdf split for every function: log_pdf, pdf, log_cdf, cdf, log_survival_function, etc.

fritzo commented 7 years ago

Good point. I think the original pyro devs intend batch_log_pdf to behave as you describe. I like the idea of batching within log_pdf, say via an extra batch_axis or axis arg, rather than via a separate function. I'm curious: in Edward, did you find a good consistent convention for which axes to batch over, or do you need to specify an axis arg?
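To illustrate the axis-arg idea, something like this (batch_axis and _elementwise_log_pdf are hypothetical names, not pyro API, and the keyword-only syntax is Python 3 only):

class Distribution(object):
    def log_pdf(self, x, *args, batch_axis=None, **kwargs):
        # hypothetical: compute elementwise log probs, then reduce everything
        # except the batch axis (or reduce everything if no batch axis is given)
        lp = self._elementwise_log_pdf(x, *args, **kwargs)
        if batch_axis is None:
            return lp.sum()
        dims = [d for d in range(lp.dim()) if d != batch_axis]
        return lp.sum(dim=dims) if dims else lp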

dustinvtran commented 7 years ago

Our decision for the convention was ultimately not to have a convention. :) That is, by requiring summation downstream, I mean the log prob is always returned vectorized. In inference implementations such as build_reparam_loss_and_gradients, we have lines like q_log_prob += tf.reduce_sum(qz.log_prob(qz_sample)). I've personally found this design to have the most flexibility and the least cognitive overhead.

We do assume batch_shape is the outer-most (left-most) dimensions of the input to log_prob if that's what you mean.
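In pyro terms, that pattern would look something like the sketch below; accumulate_log_prob and the batch_log_pdf_fn callable are placeholders, not pyro API:

def accumulate_log_prob(batch_log_pdf_fn, samples):
    # the distribution side stays vectorized; the inference code decides
    # where to reduce, mirroring tf.reduce_sum in the Edward line above
    log_prob = 0.0
    for s in samples:                       # e.g. one tensor per latent site
        log_prob = log_prob + batch_log_pdf_fn(s).sum()
    return log_prob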

jpchen commented 7 years ago

This is possible now, especially with PyTorch's new broadcasting support. In the past, 1d and 2d tensors were not broadcastable, so identical operations required two separate implementations involving a series of expands and squeezes.
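A minimal illustration of the broadcasting point (current PyTorch):

import torch

x = torch.randn(5, 3)   # a batch of 5 three-dimensional points
mu = torch.randn(3)     # a single parameter vector
diff = x - mu           # broadcasts to shape (5, 3); no expand()/squeeze() needed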

@fritzo, what Dustin is saying is that the summing over the batch dimension is handled by the downstream inference algorithms, which consume the vector that log_pdf would return, so the distribution doesn't know about batches. Of course, this means that the logic for handling the right batch dimension lives in each inference implementation, which I feel might have its own disadvantages... We could also make the explicit assumption that batch_dim = 0, as we do now and as Edward does.

ngoodman commented 7 years ago

Great point. It seems like the answer is to figure out distribution shapes à la #153, and then the correct log_pdf method follows? (i.e. stay vectorized over the repeated dims, and have the objectives be aware of this.)