xxxnell / how-do-vits-work

(ICLR 2022 Spotlight) Official PyTorch implementation of "How Do Vision Transformers Work?"
https://arxiv.org/abs/2202.06709
Apache License 2.0

how to compute feature map variances? #6

Closed: LostXine closed this issue 2 years ago

LostXine commented 2 years ago

Hello,

Thank you for your great work!

I wonder how you compute the feature map variances. As I understand it, you first extract the representation of every sample, which gives a vector of length D (let's just flatten the 2D tensor, or concatenate all tokens). Then you compute the variance of each element of this vector over all the samples, which gives D variances. Finally, you take the mean of these D variances and report it as the feature map variance.
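The procedure described in the question can be sketched in plain Python as follows. This is a minimal illustration, not code from the repository: the function names are hypothetical, each sample is assumed to be a (tokens, channels) map that is flattened row-major to D values, and population variance (dividing by the number of samples N) is an assumption.

```python
from statistics import pvariance


def flatten(feature_map):
    """Flatten a (tokens, channels) nested list into one vector of length D (hypothetical helper)."""
    return [value for token in feature_map for value in token]


def mean_elementwise_variance(features):
    """features: list of N samples, each already flattened to a list of D floats.

    Returns the mean over the D elements of each element's variance across samples.
    """
    d = len(features[0])
    if any(len(f) != d for f in features):
        raise ValueError("all samples must have the same length D")
    # Variance of element j across the N samples, giving D variances.
    # pvariance divides by N (population variance); this choice is an assumption.
    per_element = [pvariance([f[j] for f in features]) for j in range(d)]
    # Average the D variances into a single scalar to report.
    return sum(per_element) / d
```

For example, two samples flattened to `[0.0, 1.0]` and `[2.0, 3.0]` give a variance of 1.0 for each element, so the reported mean is 1.0.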

Did I understand you correctly? Sorry if I missed this in your existing documentation or description.

Thank you and I'm looking forward to your reply.

Best,

xxxnell commented 2 years ago

Hi, thank you for your support!

Yes, you explained correctly. I calculated the variance of the tokens per channel and then averaged them.
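The reply describes a per-sample computation along the token axis: for each channel, take the variance across tokens, then average over channels. A minimal sketch of that reading follows. It assumes each feature map is a list of T tokens with C channel values each (for a CNN, T would be the H*W spatial positions), uses population variance, and averages the per-sample values over the batch at the end. These are assumptions, the function names are hypothetical, and this is not the code from the released notebook.

```python
from statistics import pvariance


def featuremap_variance(feature_map):
    """feature_map: list of T tokens, each a list of C channel values.

    Returns the mean over channels of each channel's variance across tokens.
    """
    num_channels = len(feature_map[0])
    # For each channel c, the variance of its values over the T tokens
    # (population variance, an assumption).
    per_channel = [
        pvariance([token[c] for token in feature_map]) for c in range(num_channels)
    ]
    # Average the C per-channel variances.
    return sum(per_channel) / num_channels


def batch_featuremap_variance(batch):
    """Average featuremap_variance over a batch of feature maps (assumed aggregation)."""
    return sum(featuremap_variance(fm) for fm in batch) / len(batch)
```

For example, the map `[[0.0, 10.0], [2.0, 10.0]]` has a variance of 1.0 in the first channel and 0.0 in the second, giving 0.5. Note that this averages over tokens within each sample, whereas the procedure in the question takes the variance over samples; the two generally give different numbers.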

I will release the code after re-implementing it. I can't give a definite timeline, but I’ll try hard to release the whole code for the feature map variance by this Friday!

LostXine commented 2 years ago

Thanks for the prompt response! Have a nice day.

xxxnell commented 2 years ago

I just released featuremap_variance.ipynb (Colab notebook). I'm sorry for the late release.

LostXine commented 2 years ago

Thank you so much!