team-approx-bayes / dl-with-bayes

Contains code for the NeurIPS 2019 paper "Practical Deep Learning with Bayesian Principles"

Small damping hyper-parameter leads to divergence in training loss #7

Open uryuk opened 3 years ago

uryuk commented 3 years ago

Hello @emtiyaz @kazukiosawa! Thank you for this great work and your research!

I am trying to follow your work and apply it to a classification problem, hoping that uncertainty estimation will help me reduce OOD misclassifications. I first tried to apply VOGN, but after a suggestion from your side in the parallel issue (https://github.com/team-approx-bayes/dl-with-bayes/issues/6) I switched to OGN and am now trying to tune its hyper-parameters.

However, I noticed that setting the damping parameter low (1e-5 or lower) leads to divergence: the training loss grows instead of decreasing.

I observed similar behavior when training on ImageNet (which I used as a test bed to verify my assumptions): raising the damping to 1e-3 or even 1e-1 helps to stabilize training and convergence.

From your research, I understood that it should be the other way around, and that a small damping value helps to control eigenvalues close to zero.

I think I am missing something. Can you please advise?

My parameters for OGN that keep the training loss slowly decreasing:

"optim_name": "DistributedSecondOrderOptimizer",
"optim_args": {
  "curv_type": "Cov",
  "curv_shapes": {
    "Conv2d": "Diag",
    "Linear": "Diag",
    "BatchNorm1d": "Diag",
    "BatchNorm2d": "Diag"
  },
  "lr": 1.6e-3,
  "momentum": 0.9,
  "momentum_type": "raw",
  "non_reg_for_bn": true
},
"curv_args": {
  "damping": 1e-1,
  "ema_decay": 0.999
},

kazukiosawa commented 3 years ago

@uryuk Thanks for your question! (and sorry for the late response)

Yes, increasing the damping parameter stabilizes the training.
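A simplified way to see why a larger damping stabilizes training is to look at a single preconditioned step. This is an illustration under stated assumptions, not the exact update in pytorch-sso: the curvature is taken to be diagonal with nonnegative entries $h_i$ (presumably the case for a gradient-covariance curvature such as `Cov`, whose diagonal entries are averages of squared gradients), momentum is ignored, and the symbols $\gamma$ (damping), $\eta$ (learning rate) and $g_i$ (gradient) are notation introduced here.

```latex
$$
\Delta\theta_i = -\eta\,\frac{g_i}{h_i + \gamma},
\qquad
|\Delta\theta_i| \le \frac{\eta\,|g_i|}{\gamma}
\quad \text{whenever } h_i \ge 0 .
$$
```

Under this model the damping acts as a floor on every denominator. With $\gamma = 10^{-5}$, a direction whose curvature estimate is close to zero can receive a step of up to $10^{5}\,\eta|g_i|$, whereas $\gamma = 10^{-1}$ caps it at $10\,\eta|g_i|$, which is consistent with the divergence reported above for small damping.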

To stabilize the training, you can also increase l2_reg (the coefficient of L2 regularization). Both the damping and (the exponential moving average of) l2_reg are added to the diagonal elements of the curvature before it is inverted in SecondOrderOptimizer.
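The effect of adding both damping and l2_reg to the curvature diagonal can be sketched in a few lines of Python. This is a hypothetical, simplified model, not the pytorch-sso implementation: `diag_second_order_step` is an illustrative name, the curvature is assumed to be diagonal, momentum and the curvature EMA are omitted, and l2_reg is modeled only on the curvature side (its weight-decay term on the gradient is ignored). The learning rate 1.6e-3 is taken from the configuration posted in the issue.

```python
from typing import List


def diag_second_order_step(
    grads: List[float],
    curv_diag: List[float],
    lr: float,
    damping: float,
    l2_reg: float = 0.0,
) -> List[float]:
    """Hypothetical diagonal preconditioned step (illustration only).

    damping and l2_reg are both added to every diagonal curvature entry
    before it is inverted, so the denominator for parameter i is
    curv_diag[i] + l2_reg + damping.
    """
    if len(grads) != len(curv_diag):
        raise ValueError("grads and curv_diag must have the same length")
    return [-lr * g / (h + l2_reg + damping) for g, h in zip(grads, curv_diag)]


if __name__ == "__main__":
    lr = 1.6e-3              # learning rate from the posted OGN config
    grads = [0.1, 0.1]       # same gradient magnitude in both directions
    curv_diag = [1.0, 1e-8]  # the second direction has near-zero curvature

    # Compare a small damping, a large damping, and a small damping plus l2_reg.
    for damping, l2_reg in [(1e-5, 0.0), (1e-1, 0.0), (1e-5, 1e-1)]:
        step = diag_second_order_step(grads, curv_diag, lr, damping, l2_reg)
        largest = max(abs(s) for s in step)
        print(f"damping={damping:g}, l2_reg={l2_reg:g}: largest |step| = {largest:.3g}")
```

In this toy setting the near-zero-curvature direction gets a step of roughly 16 with damping 1e-5, versus roughly 0.0016 when either the damping or l2_reg is 0.1, which illustrates why raising either one keeps the update bounded.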

See these parts:
https://github.com/cybertronai/pytorch-sso/blob/master/torchsso/curv/curvature.py#L103-L104
https://github.com/cybertronai/pytorch-sso/blob/master/torchsso/optim/secondorder.py#L246
https://github.com/cybertronai/pytorch-sso/blob/ab71354440600d14cfa276b3decbc8ec54122ce8/torchsso/curv/curvature.py#L230-L232

Also, if you want to try VOGN instead of OGN, see the following to understand the relationship between l2_reg of SecondOrderOptimizer and the parameters of VIOptimizer: https://github.com/cybertronai/pytorch-sso/blob/master/torchsso/optim/vi.py#L77

I hope this helps you.

uryuk commented 3 years ago

@kazukiosawa Thank you very much for the answer. The relationship is clear now; I will experiment more with l2_reg.