I am trying to compute the KL divergence between a Gaussian mixture model prior and a normal (Gaussian) posterior. It is analytically intractable unless some approximation is used, but it can be estimated via Monte Carlo sampling. I was wondering how you would suggest implementing this with your library.
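Concretely, assuming the usual ELBO direction KL(q(z|u) || p(z|u)) with q the posterior and p the mixture prior, the estimator I have in mind is the average log-density ratio over samples drawn from the posterior, i.e. KL ≈ (1/S) * sum_s [log q(z_s) - log p(z_s)] with z_s ~ q(z|u). For reference, this is the model I am working with: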
import torch
import torch.nn as nn
from torch.distributions.normal import Normal
from torch.distributions import MultivariateNormal, OneHotCategorical, MixtureSameFamily, Categorical
from torch.distributions.independent import Independent
class VGMM(nn.Module):
    def __init__(
        self,
        u_dim,
        h_dim,
        z_dim,
        n_mixtures,
        device,
        batch_norm=False,
    ):
        super().__init__()
        self.n_mixtures = n_mixtures
        self.u_dim = u_dim
        self.h_dim = h_dim
        self.z_dim = z_dim
        self.device = device
        self.batch_norm = batch_norm

        # Encoder producing the parameters of the Gaussian posterior q(z | u).
        encoder_layers = [nn.Linear(self.u_dim, self.h_dim)]
        if self.batch_norm:
            encoder_layers.append(nn.BatchNorm1d(self.h_dim))
        encoder_layers += [
            nn.ReLU(),
            nn.Linear(self.h_dim, self.h_dim),
        ]
        if self.batch_norm:
            encoder_layers.append(nn.BatchNorm1d(self.h_dim))
        encoder_layers.append(nn.ReLU())
        self.enc = nn.Sequential(*encoder_layers)
        self.enc_mean = nn.Linear(self.h_dim, self.z_dim)
        self.enc_logvar = nn.Linear(self.h_dim, self.z_dim)

        # Distribution classes used to build the mixture prior and the posterior.
        self.dist = MixtureSameFamily
        self.comp = Normal
        self.mix = Categorical

        # Prior network producing the parameters of the Gaussian mixture prior p(z | u).
        layers_prior = [nn.Linear(self.u_dim, self.h_dim)]
        if self.batch_norm:
            layers_prior.append(nn.BatchNorm1d(self.h_dim))
        layers_prior.append(nn.ReLU())
        self.prior = nn.Sequential(*layers_prior)
        self.prior_mean = nn.ModuleList(
            [nn.Linear(self.h_dim, self.z_dim) for _ in range(n_mixtures)]
        )
        self.prior_logvar = nn.ModuleList(
            [nn.Linear(self.h_dim, self.z_dim) for _ in range(n_mixtures)]
        )
        self.prior_weights = nn.Linear(self.h_dim, n_mixtures)

    def forward(self, u):
        # Posterior parameters.
        encoder_input = self.enc(u)
        enc_mean = self.enc_mean(encoder_input)
        enc_logvar = nn.Softplus()(self.enc_logvar(encoder_input))

        # Mixture prior parameters: one mean / log-variance head per component.
        prior_input = self.prior(u)
        prior_mean = torch.cat(
            [self.prior_mean[n](prior_input).unsqueeze(1) for n in range(self.n_mixtures)],
            dim=1,
        )
        prior_logvar = torch.cat(
            [self.prior_logvar[n](prior_input).unsqueeze(1) for n in range(self.n_mixtures)],
            dim=1,
        )
        prior_w = self.prior_weights(prior_input)
        prior_sigma = prior_logvar.exp().sqrt()

        # Gaussian mixture prior and diagonal Gaussian posterior as torch distributions.
        prior_dist = self.dist(
            self.mix(logits=prior_w),
            Independent(self.comp(prior_mean, prior_sigma), 1),
        )
        post_dist = self.comp(enc_mean, enc_logvar.exp().sqrt())

        z_t = self.reparametrization(enc_mean, enc_logvar)
        return prior_dist, post_dist, z_t

    def reparametrization(self, mu, log_var):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, 1).
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu).add_(1e-7)
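So far the best I could come up with is a rough sketch along the following lines, using the prior_dist and post_dist returned by forward above (mc_kl and n_samples are just placeholder names I made up, and I am assuming the KL(posterior || prior) direction):

def mc_kl(post_dist, prior_dist, n_samples=10):
    # Draw reparameterized samples from the Gaussian posterior so the estimate
    # stays differentiable w.r.t. the encoder parameters.
    z = post_dist.rsample((n_samples,))    # (n_samples, batch, z_dim)
    log_q = post_dist.log_prob(z).sum(-1)  # Normal is per-dimension, so sum over z_dim
    log_p = prior_dist.log_prob(z)         # MixtureSameFamily event dim is z_dim, already summed
    return (log_q - log_p).mean(0)         # Monte Carlo average: one KL estimate per batch element

As far as I can tell, torch.distributions.kl.kl_divergence has no registered analytic form for a MixtureSameFamily prior against a Normal posterior, which is why I am falling back to sampling.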
Is something along these lines reasonable, or how would you suggest using the library to compute the KL term? Thanks in advance.