Hello! Thanks for the awesome work and thanks for posting the code! I am reading the code and have a minor question about the `Attention` module in `evit.py`: if there is a distillation token, won't it be included in `cls_attn`, since we only skip the cls_token itself? Where am I understanding this wrong? Thanks for helping in advance :)
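For context, a runnable sketch of the indexing I am asking about, assuming an attention tensor of shape `(B, heads, N, N)` with the class token at index 0 and the distillation token at index 1; this is my paraphrase of the idea, not a verbatim quote of `evit.py`:

```python
import torch

attn = torch.randn(2, 3, 10, 10).softmax(dim=-1)  # (B, heads, N, N); token 0 = cls, token 1 = distill
# Slicing from index 1 skips only the class token itself, so a distillation
# token sitting at index 1 is still counted among the "image" tokens here.
cls_attn = attn[:, :, 0, 1:].mean(dim=1)          # (B, N-1) -- includes the distillation column
```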
Hi @SammyCui! Thanks for your interest in this work! I didn't perform the experiments with the distillation loss. If I were to incorporate the distillation loss into EViT, I would compute the attention score as the average of the attention from the class token to the image tokens and the attention from the distillation token to the image tokens, because the distillation token itself acts like a class token when computing the distillation loss. Suppose the first token is the class token and the second token is the distillation token (assuming there are two loss functions, $L_{CE}$ and $L_{teacher}$, as in Figure 2 of the paper); I would compute something like this:
```python
cls_attn = attn[:, :, 0, 2:].mean(dim=1)      # attention from the class token to the image tokens
distill_attn = attn[:, :, 1, 2:].mean(dim=1)  # attention from the distillation token to the image tokens
attn_score = (cls_attn + distill_attn) / 2    # average the two scores
_, idx = torch.topk(attn_score, left_tokens, dim=1, largest=True, sorted=True)
```
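For concreteness, here is a minimal, self-contained sketch of that selection step. The function name `select_tokens_with_distill` and the tensor shapes are my own illustration rather than code from the EViT repository; it assumes `attn` has shape `(B, num_heads, N, N)` with the class token at index 0 and the distillation token at index 1:

```python
import torch

def select_tokens_with_distill(x, attn, left_tokens):
    """Keep the class/distillation tokens plus the `left_tokens` image tokens
    that receive the highest attention, averaged over the two special tokens.

    x:    (B, N, C) token embeddings; token 0 = class, token 1 = distillation
    attn: (B, num_heads, N, N) attention weights from the current layer
    """
    cls_attn = attn[:, :, 0, 2:].mean(dim=1)      # (B, N-2): class -> image tokens
    distill_attn = attn[:, :, 1, 2:].mean(dim=1)  # (B, N-2): distill -> image tokens
    attn_score = (cls_attn + distill_attn) / 2    # fuse the two rankings

    _, idx = torch.topk(attn_score, left_tokens, dim=1, largest=True, sorted=True)

    image_tokens = x[:, 2:]                                # (B, N-2, C)
    index = idx.unsqueeze(-1).expand(-1, -1, x.size(-1))   # (B, left_tokens, C)
    kept = torch.gather(image_tokens, dim=1, index=index)  # (B, left_tokens, C)

    # Reattach the two special tokens in front of the selected image tokens.
    return torch.cat([x[:, :2], kept], dim=1)

# Quick shape check: B=2, heads=3, N=10, C=8, keep 4 image tokens -> (2, 6, 8)
x = torch.randn(2, 10, 8)
attn = torch.randn(2, 3, 10, 10).softmax(dim=-1)
print(select_tokens_with_distill(x, attn, 4).shape)
```

Note that EViT itself additionally fuses the non-selected (inattentive) tokens into a single extra token weighted by their attention scores; I have omitted that here to keep the indexing change clear.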
Hi @youweiliang, thank you so much for the detailed reply! That clears up all my confusion!
Warmest regards :)