
Implementing Radial BNN #12


silasbrack commented 2 years ago

Hi,

I’m trying to fit a radial BNN posterior variational approximation as per this paper.

However, since I'll be training a BNN, I don't want to write a custom guide that defines this variational approximation for every layer by hand, so I've been trying to implement a custom AutoGuide which automatically puts a radial approximation on all of my weights.

The radial approximation is defined as `w = mu + sigma * (epsilon_MFVI / ||epsilon_MFVI||) * r`, where I just need to sample epsilon_MFVI from an independent standard normal distribution, normalize it, and multiply it by r, a scalar sampled from a standard normal.
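In plain PyTorch, the sampling step I have in mind would look roughly like this (just a sketch; `loc` and `scale` stand in for one layer's variational parameters):

```python
import torch

loc = torch.zeros(50, 10)          # example variational means for one layer
scale = 0.1 * torch.ones(50, 10)   # example variational standard deviations

eps_mfvi = torch.randn_like(loc)         # epsilon_MFVI ~ N(0, I)
direction = eps_mfvi / eps_mfvi.norm()   # normalize over the layer's weights
r = torch.randn(())                      # scalar radius r ~ N(0, 1)
w = loc + scale * direction * r          # radial sample of the layer's weights
```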

How could I go about implementing this in TyXe? Is there a smarter way of implementing this variational approximation?

P.S. Big fan of this project!

Thanks in advance.

hpplyt commented 2 years ago

Hi, this should overall be doable. Essentially you'd need to:

1. Implement a `RadialNormal` distribution class that inherits from `pyro.distributions.Distribution`. The easiest approach would probably be to inherit from `pyro.distributions.Normal` and override the `rsample` and `log_prob` methods; as far as I remember the parameterization is the same as for a Normal distribution, so you can inherit the boilerplate code.
2. Implement a corresponding autoguide class. If you just want something quick and dirty that works, you should be able to subclass `pyro.infer.autoguide.AutoNormal`, copy-paste its `forward` method, and change the line where it instantiates `dist.Normal` to use your `RadialNormal` instead.
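A rough, untested sketch of (1) might look something like this (the class name `RadialNormal` and the choice to normalize over all of a site's dimensions are placeholder decisions; the paper normalizes over each layer's full weight vector):

```python
import torch
import pyro.distributions as dist


class RadialNormal(dist.Normal):
    """Radial posterior sketch: same (loc, scale) parameters as a Normal, but
    the MFVI noise is split into a direction and a scalar magnitude."""

    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        # epsilon_MFVI ~ N(0, I), normalized to a direction on the unit sphere.
        # This normalizes over all of the site's dimensions and assumes a single
        # particle; with a non-trivial sample_shape the normalization should
        # exclude the sample dimensions.
        eps = torch.randn(shape, dtype=self.loc.dtype, device=self.loc.device)
        direction = eps / eps.norm()
        # Scalar radius r ~ N(0, 1).
        r = torch.randn((), dtype=self.loc.dtype, device=self.loc.device)
        return self.loc + self.scale * direction * r

    # log_prob is inherited from Normal here, which is not the true radial
    # density; see the note on the KL divergence below.
```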

And then the AutoRadial guide would (hopefully :-) ) work in place of an AutoNormal guide as in the examples.
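For reference, plugging such a guide into TyXe might then look roughly like this (`AutoRadial` being the hypothetical `AutoNormal` subclass from (2), with `net`, `prior` and `likelihood` set up as in the README examples):

```python
import functools

import tyxe

# AutoRadial: hypothetical subclass of pyro.infer.autoguide.AutoNormal from (2).
# net, prior and likelihood are assumed to be set up as in the TyXe examples.
guide_builder = functools.partial(AutoRadial, init_scale=1e-2)
bnn = tyxe.VariationalBNN(net, prior, likelihood, guide_builder)
```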


(2) is unfortunately a bit ugly; ideally we'd have some kind of autoguide factory class in tyxe that can generate autoguides for a given distribution, to make adding custom distributions easier. I'll give this some more thought when I get the chance.

As an additional note on (1), you might also want to implement the KL divergence between the RadialNormal and a Normal distribution (I think that's what they use as a prior in the paper). For that you need to register the KL as a decorated function like:

@torch.distributions.kl.register_kl(RadialNormal, dist.Normal)
def _kl(q, p):
    ...

where q is a RadialNormal and p a Normal object.

Sorry this is all a bit more involved than it should be, but I hope this helps. If you need any more details or if I overlooked any issues, let me know. And if you have a go at an implementation, feel free to link your repo or a gist here; I'm happy to take a look at it.

silasbrack commented 2 years ago

Hey, sorry for the delay; thanks a lot for the help!

With your tips I managed to implement a script for running VI on BNNs with this radial approximation. Fortunately, the KL divergence between the radial posterior and a normal prior is the same as for the mean-field approximation (up to an additive constant), so I didn't bother updating the KL-divergence calculation, just the sampling.

I've actually had great results with this posterior. In general, it seems to me that the mean-field approximation often struggles to converge to an accurate solution, while the radial posterior consistently outperforms both the mean-field and low-rank approximations.

Feel free to take a look at it in https://github.com/silasbrack/approximate-inference-for-bayesian-neural-networks/blob/main/src/guides/radial.py