zhaoyang-0204 / gnp

gradient norm penalty
Apache License 2.0

Question about Figure 3 in the paper #1

Open JayC1208 opened 3 months ago

JayC1208 commented 3 months ago

Hi, I am walking through the experiments via the code, and I find it hard to understand the result of Figure 3 in "Penalizing Gradient Norm for Efficiently Improving Generalization in Deep Learning".

In this figure, the gradient norm when $\alpha$ is 0.8 stays above zero and seems a bit large compared to the other cases, while the testing error rate remains low. Your paper suggests that a smaller gradient norm is desirable because it indicates flat minima, so this result is counter-intuitive.

Could you explain this in more detail?

zhaoyang-0204 commented 3 months ago

Hello!

Gradient norm is an intriguing and somewhat mysterious attribute in deep learning. Defined as the magnitude of the gradient, it can be interpreted as a metric indicating the overall curvature of the loss surface. As far as I know, explicitly elucidating the correlation between gradient norm and model generalization remains challenging, particularly in practical deep learning applications.
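To make the quantity concrete, here is a minimal sketch of the gradient norm and the penalized objective $L(\theta) + \lambda \lVert \nabla L(\theta) \rVert$. It is written in JAX purely for illustration and is not copied from this repo; the toy loss, the parameter shapes, and the name `lam` are assumptions for the example.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Toy quadratic loss, just for illustration.
    x, y = batch
    pred = x @ params
    return jnp.mean((pred - y) ** 2)

def grad_norm(params, batch):
    # ||∇L(θ)||: the magnitude of the gradient of the loss w.r.t. the parameters.
    grads = jax.grad(loss_fn)(params, batch)
    return jnp.sqrt(jnp.sum(grads ** 2))

def penalized_loss(params, batch, lam=0.1):
    # L(θ) + λ·||∇L(θ)||: the gradient-norm-penalized objective.
    return loss_fn(params, batch) + lam * grad_norm(params, batch)

# Example usage with toy data:
# params = jnp.zeros(3)
# batch = (jnp.ones((4, 3)), jnp.zeros(4))
# print(penalized_loss(params, batch))
```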

In some sense, you may regard the gradient norm penalty as analogous to the weight penalty commonly used in deep learning. Typically, we expect to confine the search to a narrow weight region to facilitate more rapid convergence. However, it is important to avoid an excessive weight penalty, as it may lead to suboptimal minima that do not meet the requirements of the task.

By explicitly penalizing the gradient norm in our paper, we aim to converge towards minima with a flatter loss landscape. Conversely, over-regularization can steer the training towards bad minima, where the focus is placed disproportionately on the gradient norm of the loss surface and overshadows the task requirements. So we should set an appropriate regularization strength during training. Through empirical observations, we have found that setting $\alpha = 0.8$ yields the best performance for our training scenarios.
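For intuition, below is a rough sketch of how such a penalized gradient can be approximated with two gradient evaluations, in the same JAX style as above. The step size `r`, the helper name `gnp_grad`, and the exact blending are illustrative assumptions rather than the exact code in this repository; the point is that $\alpha$ interpolates between the plain task gradient ($\alpha = 0$) and a gradient taken at a point perturbed along the normalized gradient direction, which is why pushing $\alpha$ too high over-weights the penalty.

```python
def gnp_grad(params, batch, alpha=0.8, r=0.01):
    # Plain task gradient ∇L(θ).
    g = jax.grad(loss_fn)(params, batch)
    g_norm = jnp.sqrt(jnp.sum(g ** 2)) + 1e-12  # avoid division by zero
    # Gradient at a point nudged along the normalized gradient direction.
    g_perturbed = jax.grad(loss_fn)(params + r * g / g_norm, batch)
    # alpha balances the two terms: alpha = 0 recovers plain SGD,
    # while larger alpha puts more weight on flattening the loss surface.
    return (1.0 - alpha) * g + alpha * g_perturbed
```

Plugging this blended gradient into a standard optimizer step is what produces the under- vs. over-regularization trade-off described above.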

Hope this can give you some intuition.