Environment Setup: This repository is built using the JAX framework. Begin by setting up the Python environment specified in the requirements.txt
file.
Configuration: The config
folder contains all the configuration flags and their default values. You can add custom flags if needed by modifying these files.
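For example, a custom flag might be added like this (a minimal sketch assuming the configs follow the ml_collections `ConfigDict` pattern suggested by the `--config.*` command-line overrides shown below; the actual layout of the files in the config folder may differ):

```python
# Hypothetical snippet inside a config file in the config folder.
import ml_collections

def get_config():
    config = ml_collections.ConfigDict()
    config.batch_size = 256        # existing default value (illustrative)
    config.my_custom_flag = 0.1    # custom flag added for a new experiment
    return config
```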
Model Architectures: The model
folder includes various model architectures such as VGG, ResNet, WideResNet, PyramidNet, ViT, Swin and CaiT. To add custom models, follow the Flax model template and register your model using the _register_model
function in this folder.
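As an illustration, a custom model could look roughly like the sketch below. The Flax module itself is standard, but the exact signature of `_register_model` is an assumption here, so check the existing models in the model folder for the real template:

```python
import flax.linen as nn

class MyTinyCNN(nn.Module):
    """Hypothetical custom model following the Flax linen style."""
    num_classes: int = 10

    @nn.compact
    def __call__(self, x, train: bool = True):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))  # flatten
        return nn.Dense(features=self.num_classes)(x)

# Hypothetical registration call; check _register_model for its real signature.
# _register_model("my_tiny_cnn", MyTinyCNN)
```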
Dataset Pipeline: The ds_pipeline
folder provides the dataset pipeline, based primarily on the SAM repository. Unlike SAM, this repo uses local ImageNet data instead of tensorflow_datasets. Specify the path to your local dataset folders, ensuring the folder structure is:
ImageNet folder
├───n01440764
│       *.JPEG
├───n01443537
│       *.JPEG
└───...
Optimizers: The optimizer
folder contains the optimizers, including SGD (Momentum), AdamW and RMSProp. You can add custom optimizers by modifying these files.
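For orientation, the core of an SGD-with-momentum update can be expressed as a simple pytree transformation like the sketch below. This is purely illustrative and does not reflect the actual optimizer interface in the optimizer folder, which you should follow when adding your own:

```python
import jax

def sgd_momentum_update(params, grads, velocity, lr=0.1, momentum=0.9):
    """Illustrative SGD-with-momentum step over arbitrary parameter pytrees."""
    new_velocity = jax.tree_util.tree_map(lambda v, g: momentum * v + g, velocity, grads)
    new_params = jax.tree_util.tree_map(lambda p, v: p - lr * v, params, new_velocity)
    return new_params, new_velocity
```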
Training Recipes: The recipe
folder contains .sh
files, each corresponding to a specific model's training script. To run a training script, use the following command:
bash wideresnet-cifar.sh
Alternatively, to launch a run by passing the configuration directly (make sure any flag you override is defined in the config file), use:
python3 -m gnp.main.main --config=the-train-config-py-file --working_dir=your-output-dir --config.anything-else-here
Gradient regularization (GR) can be understood as a gradient norm penalty: a term involving the gradient norm $\|\nabla_{\theta} L(\theta)\|_2$ is added on top of the empirical loss,
$$L(\theta) = L_{\mathcal{S}}(\theta) + \lambda \, \|\nabla_{\theta} L_{\mathcal{S}}(\theta)\|_2$$
The gradient norm is considered a key property characterizing the flatness of the loss surface. By penalizing it, the optimization is encouraged to converge to flatter minima, which improves model generalization.
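As a concrete illustration, the penalized objective can be written directly in JAX as below (a minimal sketch, not the repository's implementation; `loss_fn(params, batch)` stands in for the empirical loss $L_{\mathcal{S}}$). Note that differentiating this objective as written would require second-order derivatives, which motivates the approximation that follows:

```python
import jax
import jax.numpy as jnp

def global_grad_norm(grads):
    """L2 norm over all leaves of a gradient pytree."""
    return jnp.sqrt(sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(grads)))

def penalized_loss(loss_fn, params, batch, lam=0.1):
    """Empirical loss plus lambda times its gradient norm (naive form)."""
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    return loss + lam * global_grad_norm(grads)
```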
Applying the chain rule, the gradient of this penalized loss is:
$$\nabla_{\theta} L(\theta) = \nabla_{\theta} L_{\mathcal{S}}(\theta) + \lambda \cdot \nabla_{\theta}^2 L_{\mathcal{S}}(\theta) \, \frac{\nabla_{\theta} L_{\mathcal{S}}(\theta)}{\|\nabla_{\theta} L_{\mathcal{S}}(\theta)\|}$$
Computing this term exactly requires the Hessian matrix. To avoid this, we use a first-order Taylor expansion to approximate the Hessian-vector product, resulting in:
$$\begin{aligned} \nabla_{\theta} L(\theta) & = \nabla_{\theta} L_{\mathcal{S}}(\theta) + \lambda \cdot \frac{\nabla_{\theta} L_{\mathcal{S}}\left(\theta + r \frac{\nabla_{\theta} L_{\mathcal{S}}(\theta)}{\|\nabla_{\theta} L_{\mathcal{S}}(\theta)\|}\right) - \nabla_{\theta} L_{\mathcal{S}}(\theta)}{r} \\ & = \left(1 - \frac{\lambda}{r}\right) \nabla_{\theta} L_{\mathcal{S}}(\theta) + \frac{\lambda}{r} \cdot \nabla_{\theta} L_{\mathcal{S}}\left(\theta + r \frac{\nabla_{\theta} L_{\mathcal{S}}(\theta)}{\|\nabla_{\theta} L_{\mathcal{S}}(\theta)\|}\right) \end{aligned}$$
where $r$ is a small scalar. The gradient norm penalty therefore has two hyperparameters: the penalty coefficient $\lambda$ and the perturbation radius $r$. For practical convenience, we further set
$$\nabla_{\theta} L(\theta) = (1 - \alpha) \, \nabla_{\theta} L_{\mathcal{S}}(\theta) + \alpha \cdot \nabla_{\theta} L_{\mathcal{S}}\left(\theta + r \frac{\nabla_{\theta} L_{\mathcal{S}}(\theta)}{\|\nabla_{\theta} L_{\mathcal{S}}(\theta)\|}\right), \qquad \alpha = \frac{\lambda}{r}$$
Notably, the SAM algorithm is a special implementation of this scheme where $\alpha$ is always set to 1.
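In JAX, this update direction can be computed with two gradient evaluations per step, roughly as follows (a minimal sketch, not the repository's actual implementation; `loss_fn(params, batch)` again stands in for the empirical loss):

```python
import jax
import jax.numpy as jnp

def gr_gradient(loss_fn, params, batch, r=0.01, alpha=0.5):
    """Approximate the GR gradient with two first-order gradient evaluations."""
    grads = jax.grad(loss_fn)(params, batch)
    norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(grads)))
    # Perturb the parameters along the normalized gradient direction by radius r.
    perturbed = jax.tree_util.tree_map(lambda p, g: p + r * g / (norm + 1e-12), params, grads)
    grads_perturbed = jax.grad(loss_fn)(perturbed, batch)
    # Convex combination of the two gradients; alpha = lambda / r, and alpha = 1 recovers SAM.
    return jax.tree_util.tree_map(lambda g, gp: (1.0 - alpha) * g + alpha * gp, grads, grads_perturbed)
```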
However, GR can lead to serious performance degradation in certain scenarios with adaptive optimizers.
| Model  | Adam  | Adam + GR | Adam + GR + Zero-GR-Warmup |
|--------|-------|-----------|----------------------------|
| ViT-Ti | 14.82 | 13.92     | 13.61                      |
| ViT-S  | 12.07 | 12.40     | 10.68                      |
| ViT-B  | 10.83 | 12.36     | 9.42                       |
Both our empirical observations and theoretical analysis show that the biased gradient estimate introduced by GR can destabilize the gradient statistics of adaptive optimizers, and even cause them to diverge, at the initial stage of training. This is especially pronounced with learning-rate warmup, a technique originally intended to benefit those statistics.
To mitigate this issue, we draw inspiration from warmup techniques and propose three GR warmup strategies: $\lambda$-warmup, $r$-warmup, and zero-warmup GR. Each strategy relaxes the GR effect during the warmup phase in a particular way so that the gradient statistics remain accurate. See the paper for details.
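As a rough sketch of the idea (the exact schedules are defined in the paper; the details below are assumptions for illustration): zero-warmup GR disables GR entirely during the warmup phase, $\lambda$-warmup ramps the penalty coefficient up gradually, and $r$-warmup similarly ramps the perturbation radius.

```python
def gr_coefficient(step, warmup_steps, alpha_target, strategy="zero"):
    """Illustrative GR warmup schedules; not the paper's exact formulation."""
    if strategy == "zero":
        # Zero-warmup GR: switch GR off entirely until warmup ends.
        return 0.0 if step < warmup_steps else alpha_target
    if strategy == "lambda":
        # Lambda-warmup: ramp the penalty strength linearly during warmup.
        return alpha_target * min(1.0, step / warmup_steps)
    raise ValueError(f"unknown strategy: {strategy}")
```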
If you find this repository helpful, please consider citing the following papers:
@inproceedings{zhao2022penalizing,
title={Penalizing gradient norm for efficiently improving generalization in deep learning},
author={Zhao, Yang and Zhang, Hao and Hu, Xiuyuan},
booktitle={International Conference on Machine Learning},
pages={26982--26992},
year={2022},
organization={PMLR}
}
@inproceedings{zhaowill,
title={When Will Gradient Regularization Be Harmful?},
author={Zhao, Yang and Zhang, Hao and Hu, Xiuyuan},
booktitle={Forty-first International Conference on Machine Learning},
year={2024}
}