danielkelshaw / WeightUncertainty

PyTorch implementation of 'Weight Uncertainty in Neural Networks'
https://arxiv.org/abs/1505.05424
MIT License
16 stars 2 forks

Question on 3.2 of the Paper #31

Closed PaulScemama closed 1 year ago

PaulScemama commented 1 year ago

Hey I really enjoyed going through your implementation of the paper, so thank you!

I had a quick question: in section 3.2 of the paper, the authors provide an optimization procedure. I understand steps $1-3$ where they reparameterize the Diagonal Gaussian. Then in $4$ they let $f(\pmb{w}, \theta) = \log q(\pmb{w} \mid \theta) - \log P(\pmb{w})P(\mathcal{D} \mid \pmb{w})$. I understand this to be a different form of $(2)$ but for a single sample $\pmb{w}$. Am I correct in thinking this?

Then in $5-6$ they define the gradient with respect to the variational parameter $\mu$ as $\nabla_{\mu} = \frac{\partial f(\pmb{w}, \theta)}{\partial \pmb{w}} + \frac{\partial f(\pmb{w}, \theta)}{\partial \mu}$ and then also a similar expression for $\rho$. I see $\frac{\partial f(\pmb{w}, \theta)}{\partial \pmb{w}}$ as the gradient of the loss $(2)$ with respect to the sampled weight $\pmb{w}$.

My question is then, where do you do the calculations in equation $(3)$ of the paper? It seems to me you just take the gradient of the loss function with respect to $\mu$?

I'm sure I'm not understanding something, so thanks for any light you can shed on my problem! Thanks!

danielkelshaw commented 1 year ago

Hey, thanks for reaching out! It's been quite a while since I've looked at the paper but I'll do my best to clarify any points.

> I had a quick question: in section 3.2 of the paper, the authors provide an optimization procedure. I understand steps $1-3$ where they reparameterize the Diagonal Gaussian. Then in $4$ they let $f(\pmb{w}, \theta) = \log q(\pmb{w} \mid \theta) - \log P(\pmb{w})P(\mathcal{D} \mid \pmb{w})$. I understand this to be a different form of $(2)$ but for a single sample $\pmb{w}$. Am I correct in thinking this?

As you said steps 1-3 are simply the re-parameterisation trick, allowing us to compute the gradient wrt. the parameters and alleviating the issues of differentiating wrt. a random sample. You're correct in saying that step 4 is the same as equation (2) but for a single sample -- they state in the paper: "Since each additive term in the approximate cost in (2) uses the same weight samples, the gradients of (2) are only affected by the parts of the posterior distribution characterised by the weight samples." I take this to mean that, since we are ultimately only interested in computing gradients of the ELBO, a single sample is sufficient.
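To make steps 1-4 concrete, here's a rough numpy sketch for a single sample. This is my own illustration rather than the repo's code: the standard-normal prior and the zeroed-out likelihood term are placeholders.

```python
import numpy as np

def log_gaussian(x, mu, sigma):
    """Log density of a diagonal Gaussian, summed over all weights."""
    return np.sum(-0.5 * np.log(2 * np.pi * sigma**2)
                  - (x - mu)**2 / (2 * sigma**2))

rng = np.random.default_rng(0)
mu = np.zeros(3)               # variational means
rho = np.full(3, -1.0)         # pre-softplus scale parameters
sigma = np.log1p(np.exp(rho))  # sigma = softplus(rho), always positive

# steps 1-3: reparameterised sample w = mu + sigma * eps
eps = rng.standard_normal(3)
w = mu + sigma * eps

log_q = log_gaussian(w, mu, sigma)     # log q(w | theta)
log_prior = log_gaussian(w, 0.0, 1.0)  # log P(w), placeholder prior
log_lik = 0.0                          # placeholder for log P(D | w)

# step 4: single-sample estimate of the cost
f = log_q - log_prior - log_lik
```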

> Then in $5-6$ they define the gradient with respect to the variational parameter $\mu$ as $\nabla_{\mu} = \frac{\partial f(\pmb{w}, \theta)}{\partial \pmb{w}} + \frac{\partial f(\pmb{w}, \theta)}{\partial \mu}$ and then also a similar expression for $\rho$. I see $\frac{\partial f(\pmb{w}, \theta)}{\partial \pmb{w}}$ as the gradient of the loss $(2)$ with respect to the sampled weight $\pmb{w}$.

> My question is then, where do you do the calculations in equation $(3)$ of the paper? It seems to me you just take the gradient of the loss function with respect to $\mu$?

Let me first provide the code of the forward pass for convenience:

https://github.com/danielkelshaw/WeightUncertainty/blob/39a1cb54b257e27d67d3a84135be4709e7cc90b4/torchwu/bayes_linear.py#L59-L87

We see on L84 that the log posterior is a function of `w_log_posterior` and `b_log_posterior`, computed as per steps 1-3. On L51-L52 we instantiate the `GaussianVariational` class, which in turn registers the parameters `self.mu` and `self.rho`; these are used to produce the weights and biases for the linear layer. When we call `.backward()` in the training loop, the gradients of the loss wrt. these parameters are computed via autograd. So while steps 5-6 in the algorithm provide analytical formulations for the gradients $\nabla_{\mu}, \nabla_{\rho}$, these are computed internally and we do not need to write them explicitly.
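For intuition: since $w = \mu + \sigma\epsilon$ with $\partial w / \partial \mu = 1$, the analytical gradient in step 5 is just the total derivative through the reparameterisation, which is exactly what autograd computes. Here's a quick finite-difference check of that identity (my own sketch, with a toy stand-in for $f$, not the repo's code):

```python
import numpy as np

# toy setup: w = mu + softplus(rho) * eps, and f depends on w and mu directly
mu, rho, eps = 0.3, -0.5, 1.7
softplus = lambda r: np.log1p(np.exp(r))

def f(mu, rho):
    w = mu + softplus(rho) * eps  # steps 1-3
    return w**2 + 0.1 * mu**2     # toy stand-in for log q - log P(w)P(D|w)

# step 5 analytically: df/dmu = (df/dw) * (dw/dmu) + direct df/dmu
w = mu + softplus(rho) * eps
analytic = 2 * w * 1.0 + 0.2 * mu

# central finite difference (what autograd would return, up to O(h^2))
h = 1e-6
numeric = (f(mu + h, rho) - f(mu - h, rho)) / (2 * h)

assert abs(analytic - numeric) < 1e-4
```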

Looking through the code again I can see that this might be unclear. I hope that clears things up; let me know if you need a better explanation. While I'm unlikely to update the repo myself at this time, I'd welcome a pull request if you think you can make the code easier to understand.

PaulScemama commented 1 year ago

Hey Daniel, thanks for the quick and instructive response! You've definitely cleared things up. Ah, so it would appear that torch has built-in gradients for the reparameterization trick, so you don't have to explicitly take the gradient $\nabla_{\pmb{w}}$ and then scale + shift as in the analytical formulations in the paper to get $\nabla_{\mu}$ and $\nabla_{\rho}$?

I completely understand that you wouldn't update the repo yourself any time soon. I'm just now getting really interested in Bayesian deep learning and, more generally, in articulating uncertainty in high-dimensional ML problems.

I have started writing up my own implementation here: https://github.com/PaulScemama/nn-weight-uncertainty, which has taken inspiration from your code as well as others'. It's been very instructive in understanding how to go from probabilistic formulas / loss functions to an actual implementation. I will of course cite/reference you heavily, although I don't think anyone will look at my code :). Is that okay?

Let me know what you think. And thanks again - I've been sifting through some of your other repos and the clarity and diligence of your code is really great and inspiring!

danielkelshaw commented 1 year ago

Hey, sorry for the delay getting back to you. So torch doesn't have that mechanism built in, but it works because we define the reparameterisation trick ourselves:

https://github.com/danielkelshaw/WeightUncertainty/blob/39a1cb54b257e27d67d3a84135be4709e7cc90b4/torchwu/samplers/gaussian_variational.py#L40-L59

Both `self.mu` and `self.rho` are `nn.Parameter`s; these are the objects we want the gradients of. By writing our sampling as a function of these, torch is able to backpropagate wrt. these parameters. If we instead sampled directly from `N(mu, sigma)`, torch would not be able to differentiate wrt. `mu` and `sigma`. I hope that's clearer?
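Stripped of the repo's class structure, the idea is roughly this (a sketch in numpy, not the exact code):

```python
import numpy as np

rng = np.random.default_rng(42)

mu = np.array([0.0, 0.5])     # learnable mean (an nn.Parameter in the repo)
rho = np.array([-2.0, -2.0])  # learnable pre-softplus scale

# sigma = softplus(rho) keeps the standard deviation positive
sigma = np.log1p(np.exp(rho))

# the noise is sampled independently of mu and rho ...
eps = rng.standard_normal(mu.shape)

# ... so the sample is a deterministic, differentiable function of mu and rho
w = mu + sigma * eps
```

Because `w` is built from `mu` and `rho` by plain arithmetic, autograd can trace gradients back to them; only the `eps` draw is random.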

I'll take a look at your repo in the next few days, but it looks like you're off to a great start! It's absolutely fine for you to use my code - I hope it's helpful.

Let me know if you've got any other questions, happy to help!

PaulScemama commented 1 year ago

No worries at all!

Ah yes that definitely makes sense.

So let's see if I got this right: because `self.w` is "sampled" via a deterministic function of `self.mu` and `self.rho` plus some independent, fixed noise, rather than by drawing directly from a distribution parameterized by `self.mu` and `self.rho`, we can get gradients with respect to `self.mu` and `self.rho`. That's the point of the reparameterization trick: represent the distribution so that a sample from it is a deterministic function of the parameters of interest, and then the gradients for those parameters follow via backpropagation.

I think `torch.distributions` now has an `.rsample()` method for some distributions (like Multivariate Normal) which samples via the reparameterization trick for you automatically (see https://pytorch.org/docs/stable/distributions.html)! Although I do appreciate the explicitness of constructing the reparameterized normal oneself.
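E.g. something along these lines, if I'm reading the docs right (a sketch, assuming a recent torch version):

```python
import torch

mu = torch.zeros(3, requires_grad=True)
rho = torch.full((3,), -1.0, requires_grad=True)
sigma = torch.nn.functional.softplus(rho)  # keeps the scale positive

dist = torch.distributions.Normal(mu, sigma)
w = dist.rsample()   # reparameterised draw: gradients flow back to mu, rho
w.sum().backward()

# mu.grad and rho.grad are now populated; by contrast, dist.sample()
# detaches from the graph, so calling .backward() on it would error out
```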

And thank you so much! I'm currently struggling to get it to train at all :sweat_smile: but that was a first go of it. Hopefully I'll be able to debug it by the end of this weekend, but I think it has to do with the complexity cost term:

$$KL(q(w) \,\|\, p(w)) \approx \text{log-posterior} - \text{log-prior}$$ being very large and not really coming down during training. I've only spent a couple of hours trying to figure it out, so I'll keep at it. I'll let you know if I figure it out / need some help :).

Thanks very much!

danielkelshaw commented 1 year ago

> So let's see if I got this right: because self.w is "sampled" via a deterministic equation w.r.t. self.mu and self.rho and then adding some independent, fixed noise...instead of by taking a sample from a distribution parameterized by self.mu and self.rho, we can get the gradient of self.mu and self.rho. I.e. that's the point of the reparameterization trick - to represent a distribution so that a sample from it is a deterministic function of the parameters of interest...then we can get the gradient for those parameters of interest via backpropagation.

To the best of my knowledge I think you're absolutely right here.

Interesting to know that they've incorporated it into a sampling method now; I'll take a look.

Good luck with getting it to train! I'll close this issue for now, but feel free to re-open / make a new issue if you have any more questions.