Closed: Mi-Przystupa closed this issue 4 years ago
@gamerDecathlete I think the issue happens in your construction of Gaussian. I tested with the Normal distribution and see the expected result:
import torch
from torch.distributions import kl_divergence
from pyro.distributions import Normal
p = Normal(torch.zeros(3, 10), torch.ones(3, 10)).independent(1)
q = Normal(torch.ones(3, 10) * 2, torch.ones(3, 10)).independent(1)
kl = kl_divergence(p, q)
assert kl.shape == (3,)
Could you share your implementation and ask the question in the forum instead?
Sorry, I forgot the TransformedDistribution class. This is closer to how it looks in my model:
import torch
from torch.distributions import kl_divergence
from pyro.distributions import Normal
from pyro.distributions.transforms import PlanarFlow
from pyro.distributions import TransformedDistribution
flows = [PlanarFlow(10), PlanarFlow(10)]
p_base = Normal(torch.zeros(3, 10), torch.ones(3, 10)).independent(1)
q_base = Normal(torch.ones(3, 10) * 2, torch.ones(3, 10)).independent(1)
p = TransformedDistribution(p_base, flows)
q = TransformedDistribution(q_base, flows)
kl = kl_divergence(q, p)
assert kl.shape == (3,)  # this should fail
If it still isn't reproducible, then maybe my versions are just dated.
Thanks, I got it now! It looks like the commented line
# extra_event_dim = len(p.event_shape) - len(p.base_dist.event_shape)
is the correct behavior for the KL of transformed-transformed distributions.
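For reference, here is a rough sketch of how that line could slot into a KL rule for pairs of TransformedDistributions. It assumes the same overall structure as the stock _kl_transformed_transformed in torch.distributions.kl (with torch's _sum_rightmost helper); it is not the actual patch:
import torch
from torch.distributions import TransformedDistribution, kl_divergence
from torch.distributions.kl import register_kl
from torch.distributions.utils import _sum_rightmost

@register_kl(TransformedDistribution, TransformedDistribution)
def _kl_transformed_transformed(p, q):
    # KL is only tractable when both sides share the same transforms,
    # since the intractable Jacobian terms then cancel.
    if p.transforms != q.transforms:
        raise NotImplementedError
    if p.event_shape != q.event_shape:
        raise NotImplementedError
    # Sum out only the event dims the transforms added; the base
    # distribution's own event dims are already handled by its KL rule.
    extra_event_dim = len(p.event_shape) - len(p.base_dist.event_shape)
    base_kl = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(base_kl, extra_event_dim)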
Hi @gamerDecathlete, can you paste your code for Gaussian? I have not seen that distribution. Also, have you tried first running your model with Trace_ELBO and pyro.enable_validation(True) to ensure there are no other shape errors?
The reason _kl_independent_independent sums out some dimensions is that the rightmost reinterpreted_batch_ndims-many dimensions are actually event dimensions, not batch dimensions. Internally, we compute kl_divergence(p.base_dist, q.base_dist), which has too many batch dimensions. Then we sum out some of those to match the Independent assumptions.
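As a minimal sketch of that logic (mirroring what torch.distributions.kl does internally; _sum_rightmost is torch's helper for summing the rightmost dims of a tensor):
from torch.distributions import Independent, kl_divergence
from torch.distributions.utils import _sum_rightmost

def kl_independent_sketch(p: Independent, q: Independent):
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    # The base KL still has one term per reinterpreted dim...
    base_kl = kl_divergence(p.base_dist, q.base_dist)
    # ...so sum the rightmost reinterpreted_batch_ndims of them, since
    # Independent treats those as event dims, not batch dims.
    return _sum_rightmost(base_kl, p.reinterpreted_batch_ndims)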
@fritzo I posted the exact snippet that @fehiepsi confirmed exhibits the behavior I was commenting on in relation to the MeanFieldELBO. I'll try the enable_validation thing you mention, but otherwise I'm not sure I quite understand your comment. You are right, I don't think Gaussian exists; I meant Normal (thanks for clarifying that).
I believe #2165 fixed this (thanks @fehiepsi!). Feel free to reopen if you continue to have issues.
This is a copy of an issue I posted for pytorch: https://github.com/pytorch/pytorch/issues/29698
I think this is an issue with the _kl_independent_independent function, which is summing over batch dimensions when it shouldn't...
I think the KL divergence is summing over the batch dimensions when it shouldn't, at least for the Gaussian case.
Here's the TransformedDistribution KL function:
Here's the _kl_independent_independent KL:
To Reproduce
I think something like this will do it:
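A sketch of the reproduction, matching the snippet posted above (PlanarFlow is the transform name in the Pyro version this issue was filed against):
import torch
from torch.distributions import kl_divergence
from pyro.distributions import Normal, TransformedDistribution
from pyro.distributions.transforms import PlanarFlow

# Two Independent Normals sharing the same flow transforms.
flows = [PlanarFlow(10), PlanarFlow(10)]
p_base = Normal(torch.zeros(3, 10), torch.ones(3, 10)).independent(1)
q_base = Normal(torch.ones(3, 10) * 2, torch.ones(3, 10)).independent(1)
p = TransformedDistribution(p_base, flows)
q = TransformedDistribution(q_base, flows)

kl = kl_divergence(q, p)
assert kl.shape == (3,)  # fails: the batch dimension gets summed out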
Expected behavior
Maybe I misunderstand the kl_divergence function, but I don't think it should be summing over the batch dimensions.
How you installed it (conda, pip, source): conda / pip