pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Problems with MeanFieldELBO [bug] with TransformedDistributions in guide and model #2149

Closed Mi-Przystupa closed 4 years ago

Mi-Przystupa commented 4 years ago

This is a copy of an issue I posted for pytorch: https://github.com/pytorch/pytorch/issues/29698

I think this is an issue with the _kl_independent_independent function, which is summing over batch dimensions when it shouldn't...


I think the KL divergence is summing over the batch dimensions when it shouldn't, at least for the Gaussian case.

Here's the KL function for transformed distributions:

@register_kl(TransformedDistribution, TransformedDistribution)
def _kl_transformed_transformed(p, q):
    if p.transforms != q.transforms:
        raise NotImplementedError
    if p.event_shape != q.event_shape:
        raise NotImplementedError
    # extra_event_dim = len(p.event_shape) - len(p.base_dist.event_shape)
    extra_event_dim = len(p.event_shape)
    base_kl_divergence = kl_divergence(p.base_dist, q.base_dist)  # call to _kl_independent_independent below
    # this will again sum the KL over each entry in the batch
    return _sum_rightmost(base_kl_divergence, extra_event_dim)
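
To make the shape arithmetic concrete, here is a minimal sketch using the shapes from the reproducer below (ExpTransform is only a stand-in for the flows, to keep the snippet self-contained; it doesn't change the event shape):

import torch
from torch.distributions import Independent, Normal, TransformedDistribution, kl_divergence
from torch.distributions.transforms import ExpTransform

# base distributions: batch_shape (3,), event_shape (10,)
base_p = Independent(Normal(torch.zeros(3, 10), torch.ones(3, 10)), 1)
base_q = Independent(Normal(torch.ones(3, 10) * 2, torch.ones(3, 10)), 1)
p = TransformedDistribution(base_p, [ExpTransform()])
q = TransformedDistribution(base_q, [ExpTransform()])

print(len(p.event_shape))                                 # 1  <- current extra_event_dim
print(len(p.event_shape) - len(p.base_dist.event_shape))  # 0  <- the commented-out line
print(kl_divergence(base_p, base_q).shape)                # torch.Size([3]), already reduced over the event dim
# so summing one more rightmost dim collapses the batch dim to a scalar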

Here's the independent-independent KL:

@register_kl(Independent, Independent)
def _kl_independent_independent(p, q):
    shared_ndims = min(p.reinterpreted_batch_ndims, q.reinterpreted_batch_ndims)
    p_ndims = p.reinterpreted_batch_ndims - shared_ndims
    q_ndims = q.reinterpreted_batch_ndims - shared_ndims
    p = Independent(p.base_dist, p_ndims) if p_ndims else p.base_dist
    q = Independent(q.base_dist, q_ndims) if q_ndims else q.base_dist
    kl = kl_divergence(p, q)
    if shared_ndims:
        # this line gets called when base_dist is Gaussian
        kl = sum_rightmost(kl, shared_ndims)
    return kl

To Reproduce

I think something like this will do it:

p = Gaussian(torch.zeros(3, 10), torch.ones(3, 10)).independent(1)
q = Gaussian(torch.ones(3, 10) * 2, torch.ones(3, 10)).independent(1)
kl = kl_divergence(p, q)
print(kl.shape)  # should be torch.Size([3]) but instead prints torch.Size([])

Expected behavior

Maybe I misunderstand the kl_divergence function, but I don't think it should be summing over the batch dimension.

fehiepsi commented 4 years ago

@gamerDecathlete I think the issue is in your construction of Gaussian. I tested with the Normal distribution and see the expected result:

import torch
from torch.distributions import kl_divergence
from pyro.distributions import Normal

p = Normal(torch.zeros(3, 10), torch.ones(3, 10)).independent(1)
q = Normal(torch.ones(3, 10) * 2, torch.ones(3, 10)).independent(1)
kl = kl_divergence(p, q)
assert kl.shape == (3,)
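
As a sanity check on the values as well (a follow-up sketch assuming the same shapes): per dimension, KL(N(0,1) || N(2,1)) = (1 + 2**2) / 2 - 1/2 = 2, and the 10 event dimensions are summed per batch element, so each entry should come out to 20:

import torch
from torch.distributions import kl_divergence
from pyro.distributions import Normal

p = Normal(torch.zeros(3, 10), torch.ones(3, 10)).independent(1)
q = Normal(torch.ones(3, 10) * 2, torch.ones(3, 10)).independent(1)
kl = kl_divergence(p, q)
assert kl.shape == (3,)
assert torch.allclose(kl, torch.full((3,), 20.0))  # 2 nats per dim, summed over 10 event dims only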

Could you share your implementation and ask the question in the forum instead?

Mi-Przystupa commented 4 years ago

Sorry, I forgot the TransformedDistribution class. This is closer to how it looks in my model:

import torch
from torch.distributions import kl_divergence
from pyro.distributions import Normal
from pyro.distributions.transforms import PlanarFlow
from pyro.distributions import TransformedDistribution

flows = [PlanarFlow(10), PlanarFlow(10)]
p_base = Normal(torch.zeros(3, 10), torch.ones(3, 10)).independent(1)
q_base = Normal(torch.ones(3, 10) * 2, torch.ones(3, 10)).independent(1)
p = TransformedDistribution(p_base, flows)
q = TransformedDistribution(q_base, flows)
kl = kl_divergence(q, p)
assert kl.shape == (3,)  # this should fail

If it still isn't reproducible, maybe my versions are just out of date.

fehiepsi commented 4 years ago

Thanks, I get it now! It looks like the commented-out line

# extra_event_dim = len(p.event_shape) - len(p.base_dist.event_shape)

is the correct behavior for the KL of transformed-transformed distributions.
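
For reference, here is a sketch of what the corrected registration could look like if the fix is simply to switch to that line (the actual change in #2165 may be structured differently):

from torch.distributions import TransformedDistribution, kl_divergence, register_kl
from torch.distributions.utils import _sum_rightmost

@register_kl(TransformedDistribution, TransformedDistribution)
def _kl_transformed_transformed(p, q):
    if p.transforms != q.transforms:
        raise NotImplementedError
    if p.event_shape != q.event_shape:
        raise NotImplementedError
    # only sum over event dims introduced on top of the base distribution;
    # the base KL has already been reduced over the base event dims
    extra_event_dim = len(p.event_shape) - len(p.base_dist.event_shape)
    base_kl_divergence = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(base_kl_divergence, extra_event_dim)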

fritzo commented 4 years ago

Hi @gamerDecathlete, can you paste your code for Gaussian? I have not seen that distribution. Also, have you tried first running your model with Trace_ELBO and pyro.enable_validation(True) to ensure there are no other shape errors?

The reason _kl_independent_independent sums out some dimensions is that the rightmost reinterpreted_batch_ndims-many dimensions are actually event dimensions, not batch dimensions. Internally, we compute kl_divergence(p.base_dist, q.base_dist), which has too many batch dimensions. Then we sum out some of those to match the Independent assumptions.
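
To illustrate that point with plain torch distributions (a minimal sketch, independent of Pyro):

import torch
from torch.distributions import Independent, Normal, kl_divergence

base_p = Normal(torch.zeros(3, 10), torch.ones(3, 10))    # batch_shape (3, 10), event_shape ()
base_q = Normal(torch.ones(3, 10) * 2, torch.ones(3, 10))
p = Independent(base_p, 1)                                 # batch_shape (3,), event_shape (10,)
q = Independent(base_q, 1)

print(kl_divergence(base_p, base_q).shape)  # torch.Size([3, 10]), one "batch" dim too many
print(kl_divergence(p, q).shape)            # torch.Size([3]), the reinterpreted dim was summed out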

Mi-Przystupa commented 4 years ago

@fritzo I posted the exact snippet that @fehiepsi confirmed reproduces the behavior I was describing in relation to the MeanFieldELBO. I'll try the enable_validation suggestion you mention, but otherwise I'm not sure I quite understand your comment. You are right that Gaussian doesn't exist; I meant Normal (thanks for clarifying that).

fritzo commented 4 years ago

I believe #2165 fixed this (thanks @fehiepsi !). Feel free to reopen if you continue to have issues.