long8v / PTIR

Paper Today I Read
19 stars 0 forks source link

[147] Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers #159

Open long8v opened 7 months ago

long8v commented 7 months ago
image

paper, code

TL;DR

Details

some notation

Relevancy initialization

relevancy map을 초기화 / 업데이트 할 거임

image

SA 전에는 서로 상호작용이 없어서 $R^{ii}$, $R^{tt}$는 identity. $R^{it}$는 zero tensor.

Relevancy update rules

attention map A를 가지고 relavancy를 update할 것임 전작에 따라 head 간 평균을 구하고 gradient를 사용

image

여기서 $\delta A$는 우리가 시각화하고 싶은 class t에 대한 output인 $y_t$를 A로 미분한 것. 평균을 취하기 전에 positive만 남겨줌(clamp)(이에 대한 이유는 딱히 없고 전작을 따라줌)

image image

self attention에 대한 relevance 업데이트 방식은 아래와 같음 여기서 s는 query token, q는 key token임.

여기서 $R^{xx}$는 두개로 분리할 수 있는데 처음에 초기화한 $I$랑 $I$를 뺀 residual인 $\hat{R}^{xx}$임. $\hat{R}^{xx}$는 gradient를 사용하기 때문에 숫자가 절대적으로 작음. 이를 해결하기 위해 row의 합이 1이 되도록 정규화 해줌.

image

co-attention / cross-attention의 경우 update rule을 아래와 같이 정의해줌

image

Obtaining classification relevancies

[CLS] 토큰의 row에 해당하는 relevancy map을 보면 되는데 text 에 대한걸 보려면 $R^{tt}$의 첫번째 row를 보면 되고 image에 대한걸 보려면 $R^{ti}$의 첫번째 row를 보면 됨

Adaptation to attention type

image

Result

image image image image image
long8v commented 7 months ago

see more CLIP score

원래 논문은 CLIP은 다루지 않았는데 누가 기말과제로 올려놨다는 듯. 로직은 대충 이렇다.

long8v commented 5 months ago

더 이해하기 쉬운 pseudo-code

    def interpret(self, image, texts, model, CLS_idx, device):
        batch_size = 1
        inputs = self.preprocess(text=texts, images=image, padding="max_length", return_tensors="pt")
        inputs = inputs.to(device)
        outputs = model(**inputs, output_attentions=True)
        clip_score = outputs.logits_per_image
        image_attn_blocks = outputs.vision_model_output.attentions
        text_attn_blocks = outputs.text_model_output.attentions
        index = [i for i in range(batch_size)]
        model.zero_grad()

        num_tokens = text_attn_blocks[0].shape[-1]
        R_text = torch.eye(num_tokens, num_tokens, dtype=text_attn_blocks[0].dtype).to(device)
        R_text = R_text.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
        for i, attn_map in enumerate(text_attn_blocks):
            attn_map_grad = torch.autograd.grad(logits, [attn_map], retain_graph=True)[0].detach()
            attn_map = attn_map.detach()
            attn_map = attn_map.reshape(-1, cam.shape[-1], cam.shape[-1])
            attn_map_grad = attn_map_grad.reshape(-1, grad.shape[-1], grad.shape[-1])
            attn_map = attn_map * attn_map_grad
            attn_map = attn_map.reshape(batch_size, -1, attn_map.shape[-1], attn_map.shape[-1])
            attn_map = attn_map.clamp(min=0).mean(dim=1) 
            R_text = R_text + torch.bmm(cam, R_text)
        text_relevance = R_text
        return R_text[CLS_idx, 1:CLS_idx]