youweiliang / evit

Python code for ICLR 2022 spotlight paper EViT: Expediting Vision Transformers via Token Reorganizations
Apache License 2.0

Is the distillation token included in attentive patch selection? #15

Closed. xuanmingcui closed this issue 1 year ago.

xuanmingcui commented 1 year ago

Hello! Thanks for the awesome work and for posting the code! I am reading the code and have a minor question about the Attention module in evit.py:

        left_tokens = N - 1
        if self.keep_rate < 1 and keep_rate < 1 or tokens is not None:  # double check the keep rate
            left_tokens = math.ceil(keep_rate * (N - 1))
            if tokens is not None:
                left_tokens = tokens
            if left_tokens == N - 1:
                return x, None, None, None, left_tokens
            assert left_tokens >= 1
            cls_attn = attn[:, :, 0, 1:]  # [B, H, N-1] --------> _we skip the first (cls_token) and take the rest?_
            cls_attn = cls_attn.mean(dim=1)  # [B, N-1]
            _, idx = torch.topk(cls_attn, left_tokens, dim=1, largest=True, sorted=True)

Does this mean that if there is a distillation token, it will be included in cls_attn (since we only skip the cls_token itself)? Am I misunderstanding something?
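
For concreteness, here is a toy check of what I mean (not code from the repo; it just assumes a DeiT-style layout where index 0 is the class token and index 1 is the distillation token):

    import torch

    # toy tensors: 1 image, 1 head, N = 4 tokens laid out as [cls, distill, patch_1, patch_2]
    attn = torch.softmax(torch.randn(1, 1, 4, 4), dim=-1)  # [B, H, N, N]

    # the slice 1: covers indices 1..3, so the distillation token's column
    # is taken right alongside the two patch columns
    cls_attn = attn[:, :, 0, 1:]
    print(cls_attn.shape)  # torch.Size([1, 1, 3]) -> distill token + 2 patches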

Thanks in advance for your help :)

youweiliang commented 1 year ago

Hi @SammyCui! Thanks for your interest in this work! I didn't perform experiments with the distillation loss. If I were to incorporate the distillation loss into EViT, I would compute the attention score as the average of the attention from the class token to the image tokens and the attention from the distillation token to the image tokens, because the distillation token itself acts like a class token in computing the distillation loss. Suppose the first token is the class token and the second token is the distillation token (assuming there are two loss functions, $L_{CE}$ and $L_{teacher}$, as in Figure 2 of the paper); then I would compute something like this:

    cls_attn = attn[:, :, 0, 2:].mean(dim=1)      # the attention from the class token to the image tokens
    distill_attn = attn[:, :, 1, 2:].mean(dim=1)  # the attention from the distillation token to the image tokens
    attn_score = (cls_attn + distill_attn) / 2    # average over them
    _, idx = torch.topk(attn_score, left_tokens, dim=1, largest=True, sorted=True)
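
To make that concrete, here is a runnable end-to-end sketch of the idea (the helper name select_tokens_with_distill and the gather step are only illustrative, not code from this repo; it assumes attn has shape [B, H, N, N] with the class token at index 0 and the distillation token at index 1):

    import math
    import torch

    def select_tokens_with_distill(x, attn, keep_rate):
        """x: [B, N, C] tokens laid out as [cls, distill, image_1, ..., image_{N-2}];
        attn: [B, H, N, N] attention weights from the same block."""
        B, N, C = x.shape
        left_tokens = math.ceil(keep_rate * (N - 2))  # number of image tokens to keep

        cls_attn = attn[:, :, 0, 2:].mean(dim=1)      # [B, N-2], class token -> image tokens
        distill_attn = attn[:, :, 1, 2:].mean(dim=1)  # [B, N-2], distillation token -> image tokens
        attn_score = (cls_attn + distill_attn) / 2    # average of the two scores

        _, idx = torch.topk(attn_score, left_tokens, dim=1, largest=True, sorted=True)  # [B, left_tokens]
        index = (idx + 2).unsqueeze(-1).expand(-1, -1, C)         # shift past the two special tokens
        kept_image_tokens = torch.gather(x, dim=1, index=index)   # [B, left_tokens, C]

        # keep the class token, the distillation token, and the selected image tokens
        return torch.cat([x[:, :2], kept_image_tokens], dim=1)

    # example: 2 images, 5 heads, 2 special tokens + 14 image tokens, keep half of the image tokens
    x = torch.randn(2, 16, 8)
    attn = torch.softmax(torch.randn(2, 5, 16, 16), dim=-1)
    print(select_tokens_with_distill(x, attn, keep_rate=0.5).shape)  # torch.Size([2, 9, 8])
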
xuanmingcui commented 1 year ago

Hi @youweiliang, thank you so much for the detailed reply! That resolves all my confusion!

Warmest regards :)