stefanknegt / Probabilistic-Unet-Pytorch

A Probabilistic U-Net for segmentation of ambiguous images implemented in PyTorch
Apache License 2.0

KL Divergence for Independent #13

Closed schenock closed 3 years ago

schenock commented 3 years ago

For parametrizing an axis-aligned Gaussian you are using a Normal wrapped in an Independent, and you add a patch for the otherwise undefined KL divergence.
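For context, here is a minimal sketch of that construction (illustrative names and sizes, not the repo's exact code). Wrapping a factorized Normal in Independent moves the last dimension from the batch shape into the event shape, so log_prob and kl_divergence reduce over the latent dimension; on older PyTorch versions the KL between two Independent instances was not registered, hence the patch:

```python
import torch
from torch.distributions import Normal, Independent

batch_size, latent_dim = 4, 6  # illustrative sizes
mu = torch.zeros(batch_size, latent_dim)
log_sigma = torch.zeros(batch_size, latent_dim)

# Reinterpret the last batch dim as an event dim -> axis-aligned Gaussian
dist = Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1)
dist.batch_shape, dist.event_shape  # (torch.Size([4]), torch.Size([6]))
```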

I was wondering: isn't it possible to achieve the same thing (an axis-aligned multivariate Gaussian) using a MultivariateNormal instance? For example:

```python
import torch
from torch.distributions import MultivariateNormal

batch_size, latent_dim = 4, 6  # example sizes
mu = torch.zeros(batch_size, latent_dim)
log_sigma = torch.ones(batch_size, latent_dim)

# Build a batch of diagonal covariance matrices from the per-dim values
cov = torch.stack([torch.diag(sigma) for sigma in torch.exp(log_sigma)])

mvn = MultivariateNormal(mu, cov)

mvn.batch_shape, mvn.event_shape  # (torch.Size([batch_size]), torch.Size([latent_dim]))
```

considering that the KL divergence is defined for a (MultivariateNormal, MultivariateNormal) pair.
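For reference, a minimal sketch (illustrative names and sizes, not from the repo) comparing the two constructions. On recent PyTorch versions both yield the same per-sample KL; note that MultivariateNormal takes a covariance matrix, i.e. the squared scale of the underlying Normal:

```python
import torch
from torch.distributions import Normal, Independent, MultivariateNormal, kl_divergence

batch_size, latent_dim = 4, 6  # illustrative sizes
mu_p = torch.zeros(batch_size, latent_dim)
mu_q = torch.randn(batch_size, latent_dim)
sigma = torch.rand(batch_size, latent_dim) + 0.5  # per-dim standard deviations

# Axis-aligned Gaussian as an Independent-wrapped Normal (scale = sigma)
p_ind = Independent(Normal(mu_p, sigma), 1)
q_ind = Independent(Normal(mu_q, sigma), 1)

# The same distribution as a MultivariateNormal with diagonal covariance;
# the covariance diagonal is sigma ** 2, not sigma
p_mvn = MultivariateNormal(mu_p, covariance_matrix=torch.diag_embed(sigma ** 2))
q_mvn = MultivariateNormal(mu_q, covariance_matrix=torch.diag_embed(sigma ** 2))

kl_ind = kl_divergence(p_ind, q_ind)   # shape: (batch_size,)
kl_mvn = kl_divergence(p_mvn, q_mvn)   # shape: (batch_size,)
print(torch.allclose(kl_ind, kl_mvn))  # True (up to numerical precision)
```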

schenock commented 3 years ago

Closing this; see the discussion here: https://github.com/stefanknegt/Probabilistic-Unet-Pytorch/issues/1