piEsposito / blitz-bayesian-deep-learning

A simple and extensible library to create Bayesian Neural Network layers on PyTorch.
GNU General Public License v3.0
921 stars 106 forks

freeze_, unfreeze_ vs model.eval()/model.train() #69

Open pythonometrist opened 3 years ago

pythonometrist commented 3 years ago

Thanks for an excellent and well thought out framework.

It looks like, from a model-evaluation perspective, freeze/unfreeze have no role to play: to predict on a held-out set, it is enough to wrap the forward pass in torch.no_grad() and call model.eval().

After all, none of the trainable parameters, including the posterior distribution parameters rho and mu, are affected by freeze/unfreeze (unless we call model.eval()). Is that correct?
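For reference, the prediction recipe from the question looks like this in plain PyTorch. This is a minimal sketch with a hypothetical nn.Sequential model (not from the thread). It checks that model.eval() only flips each submodule's training flag without modifying any parameter, and that torch.no_grad() only stops gradient tracking.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in model; any nn.Module behaves the same way here.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(8, 1))
before = copy.deepcopy(model.state_dict())

model.eval()  # recursively sets .training = False on every submodule
with torch.no_grad():  # no autograd graph is built inside this block
    preds = model(torch.randn(16, 4))

all_eval = all(not m.training for m in model.modules())
unchanged = all(torch.equal(before[k], v) for k, v in model.state_dict().items())
print(all_eval, unchanged, preds.requires_grad)  # True True False
```

Calling model.train() flips the flags back; neither call touches the weights.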

piEsposito commented 3 years ago

Hello @pythonometrist, and thank you so much for using BLiTZ.

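So, the .eval() method mostly affects dropout (and batchnorm) in PyTorch's built-in layers: it switches off dropout's randomness and makes batchnorm use its running statistics instead of the current batch's, so a sample's prediction does not depend on (or leak information from) the other samples in the batch.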
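A minimal sketch of what .eval() changes in those two built-in layers, using plain PyTorch on random data (the shapes and p=0.5 are arbitrary choices): dropout becomes the identity, and batchnorm normalizes with its stored running statistics, so a sample's output is the same whether it is passed alone or inside a batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 4)

drop = nn.Dropout(p=0.5)
drop.train()
y_train = drop(x)  # random mask: each entry is zeroed or scaled by 1 / (1 - p) = 2
drop.eval()
y_eval = drop(x)  # identity at evaluation time

bn = nn.BatchNorm1d(4)
bn.train()
bn(x)  # normalizes with this batch's statistics and updates running_mean / running_var
bn.eval()
with torch.no_grad():
    alone = bn(x[:1])  # eval mode uses the stored running statistics...
    in_batch = bn(x)[:1]  # ...so the first sample's output ignores its batch-mates

print(torch.equal(y_eval, x), torch.allclose(alone, in_batch))
```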

BLiTZ's freeze_ and unfreeze_ methods disable and enable the reparameterization (weight sampling) in the Bayesian layers. When we freeze a layer, it uses only the mean of the trainable weight distribution as the layer's weights; after unfreezing, it samples the weights from that distribution again.
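To make the distinction concrete, here is a minimal sketch with a hypothetical ToyBayesianLinear layer and a hypothetical set_frozen helper standing in for freeze_()/unfreeze_(); neither is BLiTZ's actual implementation. The toy layer samples w = mu + softplus(rho) * eps on every forward pass unless its freeze flag is set, and it never reads self.training, so model.eval() alone leaves its outputs random.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyBayesianLinear(nn.Module):
    """Hypothetical Gaussian-weight linear layer; NOT BLiTZ's BayesianLinear."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight_mu = nn.Parameter(torch.randn(out_features, in_features) * 0.1)
        self.weight_rho = nn.Parameter(torch.full((out_features, in_features), -3.0))
        self.bias_mu = nn.Parameter(torch.zeros(out_features))
        self.bias_rho = nn.Parameter(torch.full((out_features,), -3.0))
        self.freeze = False  # plain Python flag, independent of self.training

    def forward(self, x):
        if self.freeze:
            # Frozen: deterministic, uses only the means.
            return F.linear(x, self.weight_mu, self.bias_mu)
        # Unfrozen: reparameterization trick, w = mu + softplus(rho) * eps, eps ~ N(0, 1).
        w = self.weight_mu + F.softplus(self.weight_rho) * torch.randn_like(self.weight_mu)
        b = self.bias_mu + F.softplus(self.bias_rho) * torch.randn_like(self.bias_mu)
        return F.linear(x, w, b)


def set_frozen(model, frozen):
    # Hypothetical helper playing the role of freeze_() / unfreeze_().
    for m in model.modules():
        if isinstance(m, ToyBayesianLinear):
            m.freeze = frozen


torch.manual_seed(0)
model = nn.Sequential(ToyBayesianLinear(3, 2))
x = torch.randn(5, 3)

model.eval()  # flips .training only; the toy layer never reads it
with torch.no_grad():
    still_random = not torch.equal(model(x), model(x))
    set_frozen(model, True)
    frozen_out = model(x)
    deterministic = torch.equal(frozen_out, model(x))

print(still_random, deterministic)  # True True
```

Under these assumptions, freezing gives one deterministic prediction from the means, while leaving the layers unfrozen and averaging several forward passes under torch.no_grad() gives a prediction with an uncertainty estimate.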

Controlling this behavior matters specifically when we want to specify Bayesian priors for our surrogate distribution by training the frozen network on the data before enabling randomness and uncertainty.
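The thread doesn't include code for this workflow. As a rough sketch, the two phases could look like the following, under several assumptions: a hypothetical bias-free toy layer (not BLiTZ's BayesianLinear), an N(0, 1) prior, synthetic regression data, and a simplified loss of MSE plus KL divided by the number of samples. The frozen phase fits the means mu; the unfrozen phase then samples weights around those fitted means and trains rho as well.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyBayesianLinear(nn.Module):
    """Hypothetical bias-free Gaussian-weight layer with a freeze flag; not BLiTZ's code."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(out_features, in_features))
        self.rho = nn.Parameter(torch.full((out_features, in_features), -3.0))
        self.freeze = False

    def forward(self, x):
        if self.freeze:
            return F.linear(x, self.mu)  # means only, so rho gets no gradient
        sigma = F.softplus(self.rho)
        return F.linear(x, self.mu + sigma * torch.randn_like(sigma))  # sampled weights

    def kl(self):
        # KL(N(mu, sigma^2) || N(0, 1)) summed over weights; the N(0, 1) prior is an assumption.
        sigma = F.softplus(self.rho)
        return 0.5 * (sigma**2 + self.mu**2 - 1 - 2 * torch.log(sigma)).sum()


torch.manual_seed(0)
x = torch.randn(256, 3)
y = x @ torch.tensor([[1.5], [-2.0], [0.5]]) + 0.1 * torch.randn(256, 1)  # synthetic data
layer = ToyBayesianLinear(3, 1)
opt = torch.optim.Adam(layer.parameters(), lr=0.05)

# Phase 1: frozen, i.e. ordinary deterministic regression on the means.
layer.freeze = True
rho_before = layer.rho.detach().clone()
start_mse = F.mse_loss(layer(x), y).item()
for _ in range(200):
    opt.zero_grad()
    F.mse_loss(layer(x), y).backward()
    opt.step()
pretrained_mse = F.mse_loss(layer(x), y).item()
rho_untouched = torch.equal(layer.rho.detach(), rho_before)

# Phase 2: unfrozen, so weights are sampled and rho is trained as well.
# Loss = MSE + KL / N, a simplified ELBO-style objective (an assumption).
layer.freeze = False
for _ in range(200):
    opt.zero_grad()
    loss = F.mse_loss(layer(x), y) + layer.kl() / len(x)
    loss.backward()
    opt.step()

with torch.no_grad():
    sampled_differs = not torch.equal(layer(x), layer(x))
print(start_mse, pretrained_mse, rho_untouched, sampled_differs)
```

In this toy setup, rho stays untouched during the frozen phase because the frozen forward pass never reads it, so it has no gradient and Adam skips it until phase 2.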

Hope I've helped.