Closed schenock closed 3 years ago
For parametrizing an axis-aligned Gaussian you are using a `Normal` wrapped into an `Independent`, and you add a patch for the undefined KL divergence.

I was wondering: isn't it possible to achieve the same (an axis-aligned multivariate Gaussian) using a `MultivariateNormal` instance? For example:
```python
mu = torch.zeros(batch_size, latent_dim)
log_sigma = torch.ones(batch_size, latent_dim)
cov = torch.stack([torch.diag(sigma) for sigma in torch.exp(log_sigma)])
mvn = MultivariateNormal(mu, cov)
```
```python
mvn.batch_shape, mvn.event_shape
# (torch.Size([batch_size]), torch.Size([latent_dim]))
```
considering KL is defined for a (`MultivariateNormal`, `MultivariateNormal`) pair.
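A minimal sketch of the idea (variable names and shapes are illustrative, assuming a recent PyTorch): an `Independent(Normal)` and a diagonal-covariance `MultivariateNormal` describe the same axis-aligned Gaussian, and `kl_divergence` is registered directly for the `MultivariateNormal` pair, so no KL patch is needed with that parameterization. Note the diagonal of the covariance must hold variances (`sigma ** 2`), not standard deviations:

```python
import torch
from torch.distributions import (Normal, Independent,
                                 MultivariateNormal, kl_divergence)

batch_size, latent_dim = 4, 3
mu = torch.zeros(batch_size, latent_dim)
sigma = torch.ones(batch_size, latent_dim)  # standard deviations

# Axis-aligned Gaussian via Independent: reinterpret the last
# batch dimension of Normal as an event dimension.
ind = Independent(Normal(mu, sigma), 1)

# The same distribution via MultivariateNormal with a diagonal
# covariance; diag_embed builds a batch of diagonal matrices.
mvn = MultivariateNormal(mu, torch.diag_embed(sigma ** 2))

# Both assign identical log-probabilities to the same samples.
x = torch.randn(batch_size, latent_dim)
assert torch.allclose(ind.log_prob(x), mvn.log_prob(x), atol=1e-5)

# KL works out of the box for (MultivariateNormal, MultivariateNormal).
other = MultivariateNormal(mu + 1.0, torch.diag_embed(sigma ** 2))
kl = kl_divergence(mvn, other)
print(kl.shape)  # torch.Size([4]), one KL value per batch element
```

The trade-off is cost: `MultivariateNormal` materializes a `latent_dim x latent_dim` covariance per batch element and uses a Cholesky factorization internally, whereas `Independent(Normal)` stays O(latent_dim).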
Closing this. See the discussion here: https://github.com/stefanknegt/Probabilistic-Unet-Pytorch/issues/1