xxxnell / how-do-vits-work

(ICLR 2022 Spotlight) Official PyTorch implementation of "How Do Vision Transformers Work?"
https://arxiv.org/abs/2202.06709
Apache License 2.0

How to plot the Hessian max eigenvalue spectra? #12

Open Dong1P opened 2 years ago

Dong1P commented 2 years ago

I read your paper and learned a lot from it. I would also like to see the code for plotting the Hessian max eigenvalue spectra. May I know if you have any plans to release it?

Best,

xxxnell commented 2 years ago

Hi @Dong1P ,

Thank you for your support. I have not yet released the code for the Hessian eigenvalue spectra visualization (e.g., Figs. 1c and 4). Instead, I provide some useful information below.

Hessian Max Eigenvalue Spectrum: My implementation uses PyHessian (https://github.com/amirgholami/PyHessian) and the pseudo-code below is extremely simple.

Source: Appendix A3 in "Blurs Behave Like Ensembles: Spatial Smoothings to Improve Accuracy, Uncertainty, and Robustness" (ICML 2022), https://arxiv.org/abs/2105.12639.

It computes and gathers the top-k (e.g., top-5) Hessian eigenvalues mini-batch-wise using power iteration.

import torch
from pyhessian import hessian  # https://github.com/amirgholami/PyHessian
from tqdm import tqdm

criterion = torch.nn.CrossEntropyLoss()  # NLL loss for classification
max_eigens = []  # a list of batch-wise top-k Hessian max eigenvalues
model = model.cuda()
for xs, ys in tqdm(dataset_train):  # iterate over mini-batches
    xs, ys = xs.cuda(), ys.cuda()
    # NLL + L2 on data-augmented (`transform`) batches; note that `transform` and
    # `weight_decay` are custom arguments that vanilla PyHessian does not accept
    hessian_comp = hessian(model, criterion, data=(xs, ys), transform=transform, weight_decay=weight_decay, cuda=True)
    # collect the top-5 Hessian eigenvalues via power iteration (https://en.wikipedia.org/wiki/Power_iteration)
    top_eigenvalues, top_eigenvectors = hessian_comp.eigenvalues(top_n=5)
    max_eigens = max_eigens + top_eigenvalues  # aggregate the batch-wise top-5 eigenvalues

PyHessian does not accept transform and weight_decay as arguments by default, so you will need to modify its code for more rigorous results.
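
If you prefer not to patch PyHessian itself, one possible workaround is a sketch of mine under assumptions (not necessarily the modification used for the paper; `transform` is assumed to operate on a batched tensor): fold the L2 term into the criterion via a closure over the model's parameters, and apply the augmentation before constructing the Hessian object.

import torch
from pyhessian import hessian

def make_criterion(model, weight_decay):
    nll = torch.nn.CrossEntropyLoss()
    def criterion(outputs, targets):
        # L2 penalty over the parameters; autograd tracks it, so it also
        # contributes weight_decay * I to the Hessian
        l2 = sum(p.pow(2).sum() for p in model.parameters())
        return nll(outputs, targets) + 0.5 * weight_decay * l2
    return criterion

xs_aug = transform(xs)  # apply the data augmentation manually
hessian_comp = hessian(model, make_criterion(model, weight_decay), data=(xs_aug, ys), cuda=True)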

Visualization: Hessian spectra (a list of real values, i.e., max_eigens) are visualized by using kernel density estimation (https://en.wikipedia.org/wiki/Kernel_density_estimation). See also https://scikit-learn.org/stable/modules/density.html.
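
For example, here is a minimal KDE sketch with scikit-learn (the bandwidth value is an arbitrary assumption; tune it for your spectra):

import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KernelDensity

eigens = np.asarray(max_eigens)[:, None]  # shape (n_samples, 1)
kde = KernelDensity(kernel="gaussian", bandwidth=1.0).fit(eigens)  # bandwidth is a guess
grid = np.linspace(eigens.min(), eigens.max(), 500)[:, None]
density = np.exp(kde.score_samples(grid))  # score_samples returns log-density
plt.plot(grid[:, 0], density)
plt.xlabel("Hessian max eigenvalue")
plt.ylabel("Density")
plt.show()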

yukimmmmiao commented 2 years ago

Thanks for your great work, @xxxnell; I have learned a lot from it. However, I have a few questions about the Hessian max eigenvalue spectra. Are the NN weights w in your pseudo-code fixed at the trained values, or not? And why do you gather only the top-5 largest Hessian eigenvalues? In my opinion, the negative eigenvalues (the smallest ones) also play an important role in the loss landscape.

xxxnell commented 2 years ago

Hi @yukimmmmiao , thank you for the kind words.

I assumed that the largest Hessian eigenvalues have a dominant influence on optimization (Ghorbani et al. (ICML 2019); see also Liu et al. (NeurIPS 2020)). I agree that the smallest Hessian eigenvalues also play an important role in optimization. To be clear, the algorithm returns the eigenvalues of greatest absolute value, so the spectrum contains not only the largest positive eigenvalues but also the most negative ones. However, this algorithm neglects near-zero Hessian eigenvalues, and I would like to leave a detailed analysis of those for future work. In my code, the NN weights are fixed: the Hessian eigenvalues were measured from saved checkpoints in separate jobs, not during the optimization runs, for simplicity.
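
As a toy illustration (my own sketch, not from the paper) of why power iteration can return negative values:

import numpy as np

# Power iteration converges to the eigenvector whose eigenvalue has the
# largest *absolute* value, so a strongly negative eigenvalue dominates.
A = np.diag([-3.0, 2.0, 1.0])  # symmetric "Hessian" with spectrum {-3, 2, 1}
v = np.random.randn(3)
for _ in range(100):
    v = A @ v
    v /= np.linalg.norm(v)
eigenvalue = v @ A @ v  # Rayleigh quotient recovers the signed eigenvalue
print(eigenvalue)  # approximately -3.0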

dgcnz commented 2 months ago

Hi @xxxnell, do you have any tips on what arguments to use for the pyhessian.eigenvalues function? If I just specify top_n=5 I never get negative eigenvalues, and by playing around with the parameters I've been able to find some when setting tol=1e-1 and top_n=50. Could you please share the parameters you used for your paper? Thanks! :)

xxxnell commented 2 months ago

Hi @dgcnz, thank you for reaching out. The occurrence of negative Hessian eigenvalues largely depends on the dataset and model configuration. I was wondering whether you're working with a smaller dataset, e.g., CIFAR, with data augmentations, and using a small model, e.g., a Ti-sized model.

dgcnz commented 2 months ago

Thanks for your answer, @xxxnell 😄. We're currently testing on Rotational MNIST, which, as far as I understand, may be too small/easy to consistently find negative eigenvalues in?

Also, the datasets you used to obtain negative Hessian eigenvalues were 10% of CIFAR and ImageNet, right? Did you by any chance test on a smaller dataset?

For context, we're comparing a CNN with a rotationally equivariant CNN, and we were hoping to find a pattern similar to the one your work found for ViT vs. ResNet.

xxxnell commented 2 months ago

Unfortunately, I haven't tested on datasets smaller/easier than CIFAR. The setting top_n=5 or 10 is for CIFAR, and I believe that some level of task difficulty contributes to the observation of negative values. Intuitively, although we might expect consistent behavior on smaller datasets, task difficulty can influence the observations; negative Hessian eigenvalues might simply not appear in tasks that are too easy. Consequently, using a higher top_n might make more sense in order to observe them.
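
For instance, a hedged sketch (reusing hessian_comp from the block above; the argument names follow PyHessian's eigenvalues signature, but the specific values here are illustrative, not the paper's settings):

# request more eigenvalues and check for negative ones among them
top_eigenvalues, _ = hessian_comp.eigenvalues(top_n=50, maxIter=100, tol=1e-3)
negatives = [e for e in top_eigenvalues if e < 0]
print(len(negatives), "negative eigenvalues found")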

Please feel free to reach out via email (namuk.park@gmail.com or parkn4@gene.com) if you'd like to provide more detailed information about your settings. I'd be happy to discuss at some point.