
Gradient Regularization in Deep Learning

Related Works

  1. "Penalizing Gradient Norm for Efficiently Improving Generalization in Deep Learning"[ICML2022], by Yang Zhao, Hao Zhang and Xiuyuan Hu.
  2. "When Will Gradient Regularization Be Harmful?"[ICML2024], by Yang Zhao, Hao Zhang and Xiuyuan Hu.

Upgrade [2024.6.15]

  1. JAX Framework Update: Upgraded the training framework to the latest version (JAX 0.4.28).
  2. New Paper Implementation: Integrated the implementation of our latest research paper into this repository.
  3. Additional Model Architectures: Included the Swin and CaiT Transformer architectures in the model list.

Training using this repo

Short intro

1. Overview

Gradient regularization (GR) can be understood as a gradient norm penalty: a term involving the gradient norm $||\nabla_{\theta} L_{\mathcal{S}}(\theta)||_2$ is added on top of the empirical loss,

$$L(\theta) = L_{\mathcal{S}}(\theta) + \lambda\, ||\nabla_{\theta} L_{\mathcal{S}}(\theta)||_2$$

The gradient norm is considered a key property characterizing the flatness of the loss surface. By penalizing it, the optimization is encouraged to converge to flatter minima, which improves model generalization.
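As a concrete illustration (not code from this repository), the penalized loss can be written directly in JAX by differentiating through `jax.grad`; the `empirical_loss` below is a toy stand-in for $L_{\mathcal{S}}$:

```python
import jax
import jax.numpy as jnp

# Toy empirical loss L_S(theta): linear-regression MSE, just for illustration.
def empirical_loss(params, batch):
    x, y = batch
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def global_grad_norm(grads):
    # L2 norm over the flattened pytree of gradients.
    leaves = jax.tree_util.tree_leaves(grads)
    return jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))

def penalized_loss(params, batch, lam=0.01):
    # L(theta) = L_S(theta) + lambda * ||grad_theta L_S(theta)||_2
    grads = jax.grad(empirical_loss)(params, batch)
    return empirical_loss(params, batch) + lam * global_grad_norm(grads)

# Differentiating this loss asks JAX to differentiate through jax.grad,
# which is exact but requires second-order derivatives.
grad_fn = jax.jit(jax.grad(penalized_loss))
```

This direct formulation is exact but costly; the next section shows the cheaper approximation used in practice.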

2. Practical Gradient Computation of Gradient Norm

By the chain rule, the gradient of this regularized loss is:

$$\nabla_{\theta} L(\theta) = \nabla_{\theta} L_{\mathcal{S}}(\theta) + \lambda \cdot \nabla_{\theta}^2 L_{\mathcal{S}}(\theta)\, \frac{\nabla_{\theta} L_{\mathcal{S}}(\theta)}{||\nabla_{\theta} L_{\mathcal{S}}(\theta)||}$$

Computing this term directly would require the full Hessian matrix. To avoid this, we use a first-order Taylor expansion to approximate the Hessian-vector product with a finite difference of gradients, resulting in:

$$\begin{split} \nabla_{\theta} L(\theta) & = \nabla_{\theta} L_{\mathcal{S}}(\theta) + \lambda \cdot \frac{\nabla_{\theta} L_{\mathcal{S}}\!\left(\theta + r\,\frac{\nabla_{\theta} L_{\mathcal{S}}(\theta)}{||\nabla_{\theta} L_{\mathcal{S}}(\theta)||}\right) - \nabla_{\theta} L_{\mathcal{S}}(\theta)}{r} \\ & = \left(1 - \frac{\lambda}{r}\right) \nabla_{\theta} L_{\mathcal{S}}(\theta) + \frac{\lambda}{r} \cdot \nabla_{\theta} L_{\mathcal{S}}\!\left(\theta + r\,\frac{\nabla_{\theta} L_{\mathcal{S}}(\theta)}{||\nabla_{\theta} L_{\mathcal{S}}(\theta)||}\right) \end{split}$$

where $r$ is a small scalar. The gradient norm penalty therefore has two hyperparameters: the penalty coefficient $\lambda$ and the perturbation radius $r$. For practical convenience, we further set,

$$\nabla_{\theta} L(\theta) = (1 - \alpha)\, \nabla_{\theta} L_{\mathcal{S}}(\theta) + \alpha \cdot \nabla_{\theta} L_{\mathcal{S}}\!\left(\theta + r\,\frac{\nabla_{\theta} L_{\mathcal{S}}(\theta)}{||\nabla_{\theta} L_{\mathcal{S}}(\theta)||}\right), \quad \alpha = \frac{\lambda}{r}$$

Notably, the SAM algorithm is a special implementation of this scheme where $\alpha$ is always set to 1.
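A minimal JAX sketch of this two-point gradient computation is given below; the function name and the hyperparameter values (`r`, `alpha`) are illustrative placeholders rather than the repository's actual training code:

```python
import jax
import jax.numpy as jnp

def gr_gradient(params, batch, loss_fn, r=0.05, alpha=0.5):
    """Two-point approximation of the GR gradient:
    (1 - alpha) * grad(theta) + alpha * grad(theta + r * g / ||g||)."""
    grads = jax.grad(loss_fn)(params, batch)

    # Global L2 norm over the gradient pytree.
    leaves = jax.tree_util.tree_leaves(grads)
    grad_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))

    # Perturb the parameters along the normalized gradient direction by radius r.
    perturbed = jax.tree_util.tree_map(
        lambda p, g: p + r * g / (grad_norm + 1e-12), params, grads)
    perturbed_grads = jax.grad(loss_fn)(perturbed, batch)

    # Mix the two gradients; alpha = lambda / r, and alpha = 1 recovers
    # the SAM update direction.
    return jax.tree_util.tree_map(
        lambda g, gp: (1.0 - alpha) * g + alpha * gp, grads, perturbed_grads)
```

The resulting gradient pytree can then be passed to any optimizer update in place of the plain gradient.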

3. Be Careful When Using Gradient Regularization with Adaptive Optimizers

GR can cause serious performance degradation in certain adaptive-optimization scenarios.

Error rate (%) on CIFAR-10:

| Model  | Adam  | Adam + GR | Adam + GR + Zero-GR-Warmup |
|--------|-------|-----------|----------------------------|
| ViT-Ti | 14.82 | 13.92     | 13.61                      |
| ViT-S  | 12.07 | 12.40     | 10.68                      |
| ViT-B  | 10.83 | 12.36     | 9.42                       |

Through both empirical observations and theoretical analysis, we find that the biased gradient estimate introduced by GR can destabilize, and even diverge, the gradient statistics of adaptive optimizers at the initial stage of training, especially when combined with a learning-rate warmup, a technique originally intended to benefit those statistics.

To mitigate this issue, we draw inspiration from warmup techniques and propose three GR warmup strategies: $\lambda$-warmup, $r$-warmup, and zero-warmup GR. Each of the three strategies relaxes the GR effect during the warmup phase in a certain way so that the gradient statistics remain accurate. See the paper for details.
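As a rough illustration of the idea (a hypothetical helper, not the repository's exact schedules), zero-warmup GR can be read as keeping the GR strength at zero during the learning-rate warmup phase and switching it on afterwards, while $\lambda$-warmup ramps the strength up gradually:

```python
import jax.numpy as jnp

def gr_alpha_schedule(step, warmup_steps, alpha_max, strategy="zero"):
    """Illustrative GR warmup schedules; see the papers for the precise
    definitions of zero-, lambda-, and r-warmup."""
    if strategy == "zero":
        # Zero-warmup GR: no GR effect until warmup ends, then full strength.
        return jnp.where(step < warmup_steps, 0.0, alpha_max)
    if strategy == "lambda":
        # Lambda-warmup: ramp the penalty strength linearly over warmup.
        return alpha_max * jnp.minimum(step / warmup_steps, 1.0)
    raise ValueError(f"unknown strategy: {strategy}")

# Example: alpha = gr_alpha_schedule(step, warmup_steps=10_000, alpha_max=0.5),
# then pass alpha into the GR gradient computation sketched above.
```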

End

If you find this work helpful, you can cite the papers as:


```bibtex
@inproceedings{zhao2022penalizing,
  title={Penalizing gradient norm for efficiently improving generalization in deep learning},
  author={Zhao, Yang and Zhang, Hao and Hu, Xiuyuan},
  booktitle={International Conference on Machine Learning},
  pages={26982--26992},
  year={2022},
  organization={PMLR}
}

@inproceedings{zhaowill,
  title={When Will Gradient Regularization Be Harmful?},
  author={Zhao, Yang and Zhang, Hao and Hu, Xiuyuan},
  booktitle={Forty-first International Conference on Machine Learning},
  year={2024}
}
```