JavierAntoran / Bayesian-Neural-Networks

Pytorch implementations of Bayes By Backprop, MC Dropout, SGLD, the Local Reparametrization Trick, KF-Laplace, SG-HMC and more

[Question] BBB vs BBB w/ Local Reparameterization #14

Open danielkelshaw opened 4 years ago

danielkelshaw commented 4 years ago

Hi @JavierAntoran @stratisMarkou,

First of all, thanks for making all of this code available - it's been great to look through!

I'm currently spending some time working through the Weight Uncertainty in Neural Networks paper in order to implement Bayes-by-Backprop. I was struggling to understand the difference between your implementation of Bayes-by-Backprop and Bayes-by-Backprop with Local Reparameterization.

I was under the impression that the local reparameterization trick was the following:

https://github.com/JavierAntoran/Bayesian-Neural-Networks/blob/022b9cedb69be3fecc83d4b0efe4b5a848119c2a/src/Bayes_By_Backprop/model.py#L58-L66

However, this same approach is used in both methods.

The main difference I see in the code is that the Local Reparameterization version calculates the KL divergence in closed form, which is possible because of the Gaussian prior / posterior distributions.

I was wondering if my understanding of the local reparameterization method was wrong, or if I had simply misunderstood the code?

Any guidance would be much appreciated!

danielkelshaw commented 4 years ago

Furthermore, your implementation of the closed-form KL divergence is the same as the one in equation 10 of the Auto-Encoding Variational Bayes paper:

https://github.com/JavierAntoran/Bayesian-Neural-Networks/blob/022b9cedb69be3fecc83d4b0efe4b5a848119c2a/src/Bayes_By_Backprop_Local_Reparametrization/model.py#L25-L28

I was wondering if you could provide any detail on how you arrived at the equation that you implemented in the code?

Thanks again!

JavierAntoran commented 4 years ago

Hi @danielkelshaw,

Thanks for your question. Similarly to the regular reparametrisation trick, the local reparametrisation trick is used to estimate gradients with respect to parameters of a distribution. However, the local reparametrisation trick takes advantage of the fact that, for a fixed input and Gaussian distributions over the weights, the resulting distribution over activations is also Gaussian. Instead of sampling all the weights individually and then combining them with the inputs to compute a sample from the activations, we can directly sample from the distribution over activations. This results in a lower variance gradient estimator which in turn makes training faster and more stable. Using the local reparametrisation trick is always recommended if possible.
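
In symbols, for a single linear unit with fixed input x and independent weights W_ij ~ N(mu_ij, sigma_ij^2), the pre-activation a_j = sum_i x_i W_ij is itself Gaussian:

    a_j ~ N( sum_i x_i * mu_ij ,  sum_i x_i^2 * sigma_ij^2 )

This is why one noise draw per activation (and per data point) is enough, instead of one noise draw per weight.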

The code for both gradient estimators is similar but not quite the same. In the code you referenced, if you look closely, you can see that we first sample the Gaussian weights:

        W = self.W_mu + 1 * std_w * eps_W
        b = self.b_mu + 1 * std_b * eps_b

And then pass the input through a linear layer with parameters that we just sampled:

        output = torch.mm(X, W) + b.unsqueeze(0).expand(X.shape[0], -1)  # (batch_size, n_output) 

On the other hand, for the local reparametrisation trick, we compute the parameters of the Gaussian over activations directly:

        act_W_mu = torch.mm(X, self.W_mu)  # mean of the Gaussian over activations
        act_W_std = torch.sqrt(torch.mm(X.pow(2), std_w.pow(2)))  # std of the Gaussian over activations

And then sample from the distribution over activations.

        act_W_out = act_W_mu + act_W_std * eps_W  # (batch_size, n_output)
        act_b_out = self.b_mu + std_b * eps_b

        output = act_W_out + act_b_out.unsqueeze(0).expand(X.shape[0], -1)
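
To tie the two snippets together, here is a minimal self-contained sketch showing both estimators side by side. It is illustrative only, not the repo's exact class: the name `ToyBayesLinear` and the `softplus` parametrisation of the standard deviations are assumptions made for this example.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ToyBayesLinear(nn.Module):
        """Illustrative mean-field Gaussian linear layer with both gradient estimators."""

        def __init__(self, n_in, n_out):
            super().__init__()
            self.W_mu = nn.Parameter(0.1 * torch.randn(n_in, n_out))
            self.W_rho = nn.Parameter(torch.full((n_in, n_out), -3.0))  # std_w = softplus(W_rho)
            self.b_mu = nn.Parameter(torch.zeros(n_out))
            self.b_rho = nn.Parameter(torch.full((n_out,), -3.0))

        def forward_weight_sample(self, X):
            # Regular reparametrisation: one noise draw per weight
            std_w, std_b = F.softplus(self.W_rho), F.softplus(self.b_rho)
            W = self.W_mu + std_w * torch.randn_like(std_w)
            b = self.b_mu + std_b * torch.randn_like(std_b)
            return torch.mm(X, W) + b.unsqueeze(0).expand(X.shape[0], -1)

        def forward_local_reparam(self, X):
            # Local reparametrisation: one noise draw per activation
            std_w, std_b = F.softplus(self.W_rho), F.softplus(self.b_rho)
            act_W_mu = torch.mm(X, self.W_mu)
            act_W_std = torch.sqrt(torch.mm(X.pow(2), std_w.pow(2)))
            act_W_out = act_W_mu + act_W_std * torch.randn_like(act_W_mu)
            act_b_out = self.b_mu + std_b * torch.randn_like(std_b)
            return act_W_out + act_b_out.unsqueeze(0).expand(X.shape[0], -1)

    layer = ToyBayesLinear(10, 3)
    X = torch.randn(32, 10)
    print(layer.forward_weight_sample(X).shape, layer.forward_local_reparam(X).shape)

Per data point, both forward passes produce activations with the same marginal distribution; they differ only in how the gradients with respect to the variational parameters are estimated.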

With regard to the KL divergence: the form used in regular Bayes By Backprop is more general but requires MC sampling to estimate it. It has the benefit of allowing for non-Gaussian priors and non-Gaussian approximate posteriors (note that our code implements the former but not the latter). We use the same weight samples to compute the model predictions and the KL divergence, which saves compute and reduces variance through the use of common random numbers.
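
As a rough sketch of what that MC estimate looks like (illustrative, not the repo's exact code; shown with a simple unit-Gaussian prior and the hypothetical helper name `mc_kl_estimate`):

    import torch
    from torch.distributions import Normal

    def mc_kl_estimate(W, W_mu, W_std, prior_std=1.0):
        # log q(W | mu, std) - log p(W), evaluated at the same weight sample W
        # that was used to compute the predictions (common random numbers)
        q = Normal(W_mu, W_std)
        p = Normal(torch.zeros_like(W), prior_std * torch.ones_like(W))
        return (q.log_prob(W) - p.log_prob(W)).sum()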

When running the local reparametrization trick, we sample activations instead of weights, so we don't have access to the weight samples needed for an MC estimate of the KL divergence. Because of this, we opted for the closed-form implementation. It restricts us to a Gaussian prior but has lower variance and results in faster convergence.

With regard to your second question: the KL divergence between two Gaussians can be obtained in closed form by solving a Gaussian integral. See: https://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians
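
For completeness: for univariate Gaussians q = N(mu_q, sigma_q^2) and p = N(mu_p, sigma_p^2), the integral works out to KL(q || p) = log(sigma_p / sigma_q) + (sigma_q^2 + (mu_q - mu_p)^2) / (2 * sigma_p^2) - 1/2, summed over the independent weight dimensions. A quick illustrative check against torch.distributions (the function name `kl_gauss` is just for this example):

    import torch
    from torch.distributions import Normal, kl_divergence

    def kl_gauss(mu_q, sig_q, mu_p, sig_p):
        # Closed-form KL( N(mu_q, sig_q^2) || N(mu_p, sig_p^2) ), summed over dimensions
        return (torch.log(sig_p / sig_q)
                + (sig_q.pow(2) + (mu_q - mu_p).pow(2)) / (2 * sig_p.pow(2))
                - 0.5).sum()

    mu_q, sig_q = torch.randn(5), 0.1 + torch.rand(5)
    mu_p, sig_p = torch.zeros(5), torch.ones(5)
    print(kl_gauss(mu_q, sig_q, mu_p, sig_p))                             # closed form
    print(kl_divergence(Normal(mu_q, sig_q), Normal(mu_p, sig_p)).sum())  # torch's version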

danielkelshaw commented 4 years ago

@JavierAntoran - thank you for taking the time to help explain this, I really appreciate it!

I found your explanation of the local reparameterisation trick very intuitive and feel like I've got a much better grasp of that now.

I'm very interested in learning more about Bayesian neural networks. I was wondering if you had any recommended reading that would help get me up to speed with some more of the theory?

JavierAntoran commented 4 years ago

For general ideas about re-casting learning as inference, I would check out chapter 41 of David MacKay's Information Theory, Inference, and Learning Algorithms. Yarin Gal's thesis is also a good source.

On the more practical side, the tutorial made by the guys at Papercup is quite nice.

Other than that, read the papers implemented in this repo and try to understand both the algorithm and implementation.