dhkim0225 / 1day_1paper

read 1 paper everyday (only weekday)
54 stars 1 forks source link

[74] Not All Patches are What You Need: Expediting Vision Transformers via Token Reorganizations (EVIT) #103

Open dhkim0225 opened 2 years ago

dhkim0225 commented 2 years ago

paper code

attention 좀 더 효율적으로 수행하자. (필요한 것만 쓰자!) image

Token Reorganization

image token 들을 identify (background or object) 하고, fusing 하는 방법. image

Attentive Token Identification

n 을 ViT 의 encoder 의 input token 개수라 하자. [CLS] token 과 나머지 token 간의 관계는 일반적으로 사용하는 attention 에서 값들을 가져올 수 있다. 관계가 많이 연결되는 애들이 중요한 애들 아닐까? 하는 motivation!

일반적으로 [CLS] token 구하는 식 한 번만 더 보고 가자. x_class == [CLS] token a == attention vector image

attentive 를 구하기 위해서, attn = mean(attn) 을 수행해 준다. (attention head 는 12 개니까 평균을 내준다.) image 이 값을 갖고, top-k 개를 attentive 로 둔다.

이것 만으로는 부족하다. DeiT-S 에서 (4, 7, 10) layer 에서 inattentive token 들을 지워나가니, acc 가 확확 떨어지더라. image 그래서 혼합하는 방법을 생각해 냈다.

InAttentive Token Fusion

image 그냥, inattentive 한 애들은 weighted average 를 해서 다음 layer 로 넘겨준다. 즉, block 지날 때마다 patch 가 줄어드는 거다.

code

from https://github.com/youweiliang/evit/blob/0999f090edbcb6dea095546b5faeb2750beaf88b/vision_transformer.py#L307-L314

            cls_attn = attn[:, :, 0, 1:]  # [B, H, N-1]
            cls_attn = cls_attn.mean(dim=1)  # [B, N-1]
            _, idx = torch.topk(cls_attn, left_tokens, dim=1, largest=True, sorted=True)  # [B, left_tokens]
            # cls_idx = torch.zeros(B, 1, dtype=idx.dtype, device=idx.device)
            # index = torch.cat([cls_idx, idx + 1], dim=1)
            index = idx.unsqueeze(-1).expand(-1, -1, C)  # [B, left_tokens, C]

            return x, index, idx, cls_attn, left_tokens

from https://github.com/youweiliang/evit/blob/0999f090edbcb6dea095546b5faeb2750beaf88b/vision_transformer.py#L350-L358

            if self.fuse_token:
                compl = complement_idx(idx, N - 1)  # [B, N-1-left_tokens]
                non_topk = torch.gather(non_cls, dim=1, index=compl.unsqueeze(-1).expand(-1, -1, C))  # [B, N-1-left_tokens, C]

                non_topk_attn = torch.gather(cls_attn, dim=1, index=compl)  # [B, N-1-left_tokens]
                extra_token = torch.sum(non_topk * non_topk_attn.unsqueeze(-1), dim=1, keepdim=True)  # [B, 1, C]
                x = torch.cat([x[:, 0:1], x_others, extra_token], dim=1)
            else:
                x = torch.cat([x[:, 0:1], x_others], dim=1)

필자 의견

  1. global 하게 보고, top-k 를 뽑아내는 게 성능이 더 좋지 않을까 생각해 본다.
    1. pruning 도 global 하게 pruning 하는게 잘 되지 않았는가. (structured 든, unstructured 든)
  2. hierarchical transformer 에 대한 성능이 어떨 지 궁금하다. (swin 등)

Result

visualize

inattentive token 들을 visualize 하면 다음과 같다. image

ImageNet

모델 별 성능. image

DeIT-S 에 inattentive fusion 하냐, 안하냐 에 따른 차이 image

pretrained DeiT-S 를 oracle 로 두어서 실험해 봄. 일종의 distillation 처럼 생각할 수 있음. DeiT-S 가 무슨 token 이 중요한 지만 뽑아서 알려주는 것임. image

Dynamic ViT 와의 성능 비교. pretrained ==> model initialize 를 pretrained 로 했다는 뜻 image