xxxnell / how-do-vits-work

(ICLR 2022 Spotlight) Official PyTorch implementation of "How Do Vision Transformers Work?"
https://arxiv.org/abs/2202.06709
Apache License 2.0

Potential mistake in loss landscape visualization. #10

Closed. sjtuytc closed this issue 2 years ago.

sjtuytc commented 2 years ago

Hi, thanks for your great work. I'd like to discuss the L2 loss term in the loss landscape visualization. I found that the L2 loss you compute is significantly larger (roughly 10x) than the classification loss, so the landscape visualization is essentially a visualization of the L2 loss. However, "weight decay" in PyTorch is implemented slightly differently from an explicit L2 penalty: simply adding the sum of squared parameter norms to the loss is not the same as applying weight decay inside Adam-like optimizers (see https://bbabenko.github.io/weight-decay/). So although the computed L2 term looks much larger than the classification loss, in ViT practice the weight decay does not actually dominate the classification loss, precisely because of how weight decay is implemented in PyTorch.
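
To illustrate the distinction, here is a minimal sketch (not code from this repository; the `nn.Linear` stand-in model and the `wd` value are placeholders). An explicit L2 penalty is part of the loss, so its gradient passes through Adam's adaptive scaling, whereas AdamW's decoupled weight decay is applied directly to the weights at update time and never appears in the reported loss value:

```python
# Minimal sketch: explicit L2 penalty vs. decoupled weight decay (AdamW).
import torch
import torch.nn as nn

model = nn.Linear(128, 10)          # stand-in for a real ViT
criterion = nn.CrossEntropyLoss()
x, y = torch.randn(32, 128), torch.randint(0, 10, (32,))
wd = 1e-4

# (a) Explicit L2 regularization: the penalty is added to the loss, so its
#     gradient is rescaled by the optimizer's adaptive terms.
opt_a = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0)
l2 = sum(p.pow(2).sum() for p in model.parameters())
loss_a = criterion(model(x), y) + 0.5 * wd * l2
opt_a.zero_grad()
loss_a.backward()
opt_a.step()

# (b) Decoupled weight decay: the decay is applied directly to the weights
#     inside the update step and is not part of the loss value at all.
opt_b = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=wd)
loss_b = criterion(model(x), y)     # no penalty term in the reported loss
opt_b.zero_grad()
loss_b.backward()
opt_b.step()
```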

sjtuytc commented 2 years ago

@xxxnell

xxxnell commented 2 years ago

Hi @sjtuytc ,

Thank you for your thoughtful feedback, and I agree with both of your points. First, l2 regularization dominates the loss landscape visualizations, so in my humble opinion we need to be careful about making claims based on loss landscape visualizations alone, and other analyses can be helpful. Second, as you pointed out and as I mentioned in the Appendix, "in the strict sense, the weight decay is not l2 regularization, but we neglect the difference" for the sake of simplicity. Thanks again for your support!
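
For anyone reproducing the visualization, a small diagnostic along the lines of this discussion (a hypothetical sketch, not code from this repository) is to compare the magnitudes of the classification loss and the explicit l2 term before interpreting the landscape:

```python
# Hypothetical sketch: how much of the visualized loss is the l2 term?
import torch
import torch.nn as nn

model = nn.Linear(128, 10)          # stand-in for a real ViT
criterion = nn.CrossEntropyLoss()
x, y = torch.randn(32, 128), torch.randint(0, 10, (32,))
wd = 1e-4

with torch.no_grad():
    nll = criterion(model(x), y).item()
    l2 = 0.5 * wd * sum(p.pow(2).sum() for p in model.parameters()).item()

print(f"classification loss: {nll:.4f}")
print(f"l2 penalty term:     {l2:.6f}")
print(f"l2 share of total:   {l2 / (nll + l2):.2%}")
```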

sjtuytc commented 2 years ago

Sure, I agree with you! Thanks for your great work!