AdamCobb / hamiltorch

PyTorch-based library for Riemannian Manifold Hamiltonian Monte Carlo (RMHMC) and inference in Bayesian neural networks
BSD 2-Clause "Simplified" License
426 stars 63 forks source link

Compute the KL divergence between a Gaussian mixture model prior and a normal (Gaussian) posterior #22

Open neuronphysics opened 1 year ago

neuronphysics commented 1 year ago

Hi,

I am trying to compute the KL divergence between a Gaussian mixture model (GMM) prior and a normal posterior. It has no closed-form expression, so it must be approximated. One option is a Monte Carlo estimate: average log q(z) − log p(z) over samples z drawn from the posterior q. How would you suggest implementing this with your library?

import torch
import torch.nn as nn
from torch.distributions.normal import Normal
from torch.distributions import MultivariateNormal, OneHotCategorical, MixtureSameFamily, Categorical
from torch.distributions.independent import Independent
class VGMM(nn.Module):
    """Encoder with a diagonal-Gaussian posterior and an input-conditioned GMM prior.

    Both distributions are produced from the same input ``u``:

    * posterior ``q(z | u)``: ``Normal(enc_mean, exp(0.5 * enc_logvar))``.
      It has no event dimension, so ``post_dist.log_prob(z)`` is per latent
      dimension. Sum over the last axis to get a per-example log density.
    * prior ``p(z | u)``: ``MixtureSameFamily`` over ``n_mixtures`` Gaussians.
      Each component is wrapped in ``Independent(..., 1)``, so
      ``prior_dist.log_prob(z)`` is already summed over the latent dimension.

    Assumes ``u`` is ``(batch, u_dim)`` (required by ``BatchNorm1d`` when
    ``batch_norm=True``) -- TODO confirm for other callers.

    Args:
        u_dim: size of the input features.
        h_dim: hidden width of the encoder and prior networks.
        z_dim: latent dimension.
        n_mixtures: number of prior mixture components.
        device: stored on the module. Sampling now follows the device/dtype of
            the tensors passed in, so this is no longer used for noise.
        batch_norm: insert ``BatchNorm1d`` after the hidden linear layers.
    """

    def __init__(self,
                 u_dim,
                 h_dim,
                 z_dim,
                 n_mixtures,
                 device,
                 batch_norm=False,
                 ):
        super().__init__()
        self.n_mixtures = n_mixtures
        self.u_dim = u_dim
        self.h_dim = h_dim
        self.z_dim = z_dim
        self.device = device
        self.batch_norm = batch_norm

        # Layer order and attribute names are kept exactly as before so that
        # previously saved state_dicts still load.
        encoder_layers = [nn.Linear(self.u_dim, self.h_dim)]
        if self.batch_norm:
            encoder_layers.append(nn.BatchNorm1d(self.h_dim))
        encoder_layers += [nn.ReLU(), nn.Linear(self.h_dim, self.h_dim)]
        if self.batch_norm:
            encoder_layers.append(nn.BatchNorm1d(self.h_dim))
        encoder_layers.append(nn.ReLU())
        self.enc = nn.Sequential(*encoder_layers)

        self.enc_mean = nn.Linear(self.h_dim, self.z_dim)
        self.enc_logvar = nn.Linear(self.h_dim, self.z_dim)

        # Distribution classes used to assemble the prior/posterior in forward().
        self.dist = MixtureSameFamily
        self.comp = Normal
        self.mix = Categorical

        layers_prior = [nn.Linear(self.u_dim, self.h_dim)]
        if self.batch_norm:
            layers_prior.append(nn.BatchNorm1d(self.h_dim))
        layers_prior.append(nn.ReLU())
        self.prior = nn.Sequential(*layers_prior)

        # One mean head and one log-variance head per mixture component.
        self.prior_mean = nn.ModuleList(
            [nn.Linear(self.h_dim, self.z_dim) for _ in range(n_mixtures)]
        )
        self.prior_logvar = nn.ModuleList(
            [nn.Linear(self.h_dim, self.z_dim) for _ in range(n_mixtures)]
        )
        # Unnormalised mixture logits.
        self.prior_weights = nn.Linear(self.h_dim, n_mixtures)

    def forward(self, u):
        """Build prior and posterior for ``u`` and draw one posterior sample.

        Returns:
            ``(prior_dist, post_dist, z_t)``. ``z_t`` is a reparameterised
            sample from the posterior, with the same shape as the encoder mean.
        """
        encoder_input = self.enc(u)
        enc_mean = self.enc_mean(encoder_input)
        # The log-variance is deliberately left unconstrained. A Softplus here
        # would force log-variance > 0, i.e. every posterior variance > 1.
        enc_logvar = self.enc_logvar(encoder_input)

        prior_input = self.prior(u)
        # Stack component heads along dim 1: (batch, n_mixtures, z_dim) for 2-D u.
        prior_mean = torch.stack(
            [head(prior_input) for head in self.prior_mean], dim=1
        )
        prior_logvar = torch.stack(
            [head(prior_input) for head in self.prior_logvar], dim=1
        )
        prior_w = self.prior_weights(prior_input)
        prior_sigma = torch.exp(0.5 * prior_logvar)
        prior_dist = self.dist(
            self.mix(logits=prior_w),
            # Independent(..., 1) turns the last axis (z_dim) into the event dim.
            Independent(self.comp(prior_mean, prior_sigma), 1),
        )
        post_dist = self.comp(enc_mean, torch.exp(0.5 * enc_logvar))
        z_t = self.reparametrization(enc_mean, enc_logvar)
        return prior_dist, post_dist, z_t

    def reparametrization(self, mu, log_var):
        """Return ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, 1)``.

        The noise is created with ``randn_like``, so it matches the device and
        dtype of ``log_var``. The former constant ``+1e-7`` offset was removed:
        it shifted every sample rather than acting as a variance floor.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    @staticmethod
    def mc_kl_divergence(prior_dist, post_dist, n_samples=1):
        """Monte Carlo estimate of ``KL(post_dist || prior_dist)`` per example.

        KL between a Gaussian and a Gaussian mixture has no closed form, so it
        is estimated as the mean of ``log q(z) - log p(z)`` over
        ``z ~ post_dist``. The samples are drawn with ``rsample``, so the
        estimate is differentiable with respect to the posterior parameters.

        Args:
            prior_dist: mixture prior as returned by ``forward``. Its
                ``log_prob`` is already summed over the latent dimension.
            post_dist: factorised ``Normal`` posterior as returned by
                ``forward``. Its ``log_prob`` is summed over the last axis here.
            n_samples: number of Monte Carlo samples (must be >= 1).

        Returns:
            Tensor with shape ``post_dist.batch_shape[:-1]`` (one KL per example).

        Raises:
            ValueError: if ``n_samples < 1``.
        """
        if n_samples < 1:
            raise ValueError("n_samples must be >= 1")
        z = post_dist.rsample((n_samples,))
        log_q = post_dist.log_prob(z).sum(-1)
        log_p = prior_dist.log_prob(z)
        return (log_q - log_p).mean(0)

How would you suggest using the library to compute the KL term? Thanks in advance.