freeman-1995 opened this issue 3 years ago.
Hi, could you share the command you used? The problem seems to be caused by the use of talking-head attention, and I am reproducing the results now. We have not checked the original ViT model or other Transformer models; we are currently studying these models beyond DeiT.
@ChengyueGongR I used the code provided by Ross (https://github.com/rwightman/pytorch-image-models) and randomly chose an image from the ImageNet validation set. I ran a forward pass through the vit_base_patch32 model and recorded every transformer block's output. Each output has shape 1x49x768, where 1 is the batch size, 49 is the sequence length, and 768 is the dimension of each token. I used torch.nn.CosineSimilarity() to compute the similarity among the 49 tokens, which gives 49x48 similarity values (every pair of distinct tokens). I then averaged them to get the result above.
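For anyone trying to reproduce these numbers, the averaging described above can be sketched in NumPy. This is a minimal sketch, assuming a recorded block output is already available as an array of shape (batch, tokens, dim); the function name `mean_pairwise_token_similarity` is hypothetical, and it computes the same normalised dot product as torch.nn.CosineSimilarity rather than calling it. Note that ViT-B/32 at 224x224 produces 49 patch tokens plus one class token, so whether the class token is dropped before averaging changes the token count (49 vs 50) and can shift the result.

```python
import numpy as np

def mean_pairwise_token_similarity(block_output):
    """Average cosine similarity over all ordered pairs of distinct tokens.

    Hypothetical helper. block_output: array of shape (batch, tokens, dim),
    e.g. (1, 49, 768). Returns one averaged value per batch element.
    """
    x = np.asarray(block_output, dtype=np.float64)
    # L2-normalise each token vector; the small floor avoids division by zero
    # for all-zero tokens (similar in spirit to CosineSimilarity's eps).
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    x = x / np.maximum(norms, 1e-8)
    # (batch, tokens, tokens) matrix of pairwise cosine similarities.
    sim = x @ np.swapaxes(x, -1, -2)
    n = sim.shape[-1]
    # Exclude the diagonal (each token compared with itself, always ~1) so
    # that exactly n * (n - 1) values are averaged, e.g. 49 * 48.
    off_diag = ~np.eye(n, dtype=bool)
    return sim[:, off_diag].mean(axis=-1)
```

Because the diagonal is excluded, identical tokens give 1.0 and mutually orthogonal tokens give 0.0. If the diagonal were included instead, the mean m over distinct pairs would be reported as m + (1 - m)/n, which slightly inflates the value for short sequences.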
I have a similar problem. Would you be willing to share the DeiT-24 model so that we can reproduce the result?
Hi, I think your steps are correct. I downloaded the SWIN-BASE7-224 model, and its cosine similarity for each block is [0.26, 0.27, 0.64, 0.7]. I also notice that for small networks the similarity increases with depth, but it does not become very large. I am checking the similarity for other recent models as well and will add a discussion to a new version of our draft. Thanks for your findings; we will update the draft before July to discuss them for more models (e.g. different architectures, attention types, and numbers of layers). For now, our conclusions are: 1) it is true that the cosine similarity in shallow networks is not that large; 2) for other architectures, the cosine similarity is also not that large; 3) the similarity increases as the network goes deeper.
P.S. The DeiT-B24 reported in our paper was trained with CutMix only, not with Mixup + CutMix as in DeiT; we will add this detail to a later version of our draft. I also tested a DeiT-B24 model trained with Mixup + CutMix, and its cosine similarities are [0.43, 0.45, 0.35, 0.27, 0.27, 0.27, 0.28, 0.29, 0.31, 0.31, 0.32, 0.32, 0.35, 0.35, 0.37, 0.37, 0.39, 0.41, 0.48, 0.52, 0.61, 0.61, 0.77, 0.80], which are not as large as for the previous model. I am looking into this further.
Thanks for your suggestions again! @Andy1621 @freeman-1995
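Whether the similarity "increases as the network goes deeper" can be summarised with a single number instead of being judged by eye. The sketch below is a suggestion, not the method used in the paper: it computes Spearman's rank correlation between layer index and per-layer similarity, giving tied values their average rank. A value near +1 means the similarity rises almost monotonically with depth; a value near 0 means there is no consistent trend. The function names `average_ranks`, `pearson` and `depth_trend` are hypothetical.

```python
from math import sqrt

def average_ranks(values):
    """Return 1-based ranks; tied values share the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over a run of equal values starting at position i.
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        mean_rank = (i + j) / 2 + 1  # positions i..j are 0-based
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks

def pearson(xs, ys):
    """Pearson correlation of two equal-length sequences."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    if vx == 0 or vy == 0:
        raise ValueError("correlation is undefined for constant input")
    return cov / sqrt(vx * vy)

def depth_trend(similarities):
    """Spearman rank correlation between layer index and similarity."""
    layers = list(range(len(similarities)))
    return pearson(average_ranks(layers), average_ranks(similarities))
```

For example, the SWIN-BASE7-224 values [0.26, 0.27, 0.64, 0.7] above are strictly increasing, so `depth_trend` returns exactly 1.0 for them. Lists that first fall and then rise, as in the per-layer results posted in this thread, land somewhere in between.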
I have met this problem too. I calculated the cosine similarity the same way as @freeman-1995, on DeiT-base with the official pretrained weights, and got the following cosine similarities:
layer:0, sim:0.36432449519634247
layer:1, sim:0.3189731175079942
layer:2, sim:0.31578341498970985
layer:3, sim:0.2938582096248865
layer:4, sim:0.2848378960043192
layer:5, sim:0.2654615337960422
layer:6, sim:0.26314068026840687
layer:7, sim:0.2667409982532263
layer:8, sim:0.2723067970946431
layer:9, sim:0.28992879623547196
layer:10, sim:0.45002553705126047
layer:11, sim:0.5617700926959515
I can't see the similarity increasing as the layers go deeper, as reported in your paper.
Hi, would you be willing to share your code for computing the cosine similarity? I can't reproduce the results. @freeman-1995 @ChuanyangZheng
Hi, I trained a vit_base_patch32 model at resolution 224 on ImageNet, reaching 73.38% validation accuracy. I then dumped all transformer blocks' outputs and calculated the cosine similarity between them as described in the paper, but I cannot reproduce the paper's result. Here is my result (left: layer depth, 0 to 11; right: cosine similarity):
0 ---> 0.5428178586547845
1 ---> 0.6069238793659271
2 ---> 0.30199843167793006
3 ---> 0.26388993740273486
4 ---> 0.26132955026320265
5 ---> 0.24258930458215844
6 ---> 0.20970458839482967
7 ---> 0.21119057468517677
8 ---> 0.22155304307901189
9 ---> 0.23545575648548187
10 ---> 0.2329663067175004
11 ---> 0.22496230679589768