long8v / PTIR

Paper Today I Read
19 stars 0 forks source link

🍅 짭짤이 논문 모아보기 (XAI) #167

Closed long8v closed 1 month ago

long8v commented 3 months ago

Network Dissection

Quantifying Interpretability of Deep Visual Representations

https://arxiv.org/pdf/1704.05796.pdf

unit이 어떻게 이렇게 label이 달렸나가 궁금했는데 선행연구(https://arxiv.org/pdf/1412.6856.pdf)에서 part에 대한 annotation을 단 거 였음..

image
long8v commented 3 months ago

ALBEF

image image image
long8v commented 3 months ago

Label Words are Anchors: An Information Flow Perspective for Understanding In-Context Learning

In-Context Learning 에서 shallow layer에서 Label이 관련된 정보를 aggregate하고 LLM의 Final prediction에서 다시 그 Label들을 참고해서 예측한다는 관찰

image
long8v commented 3 months ago

Contrastive Explanations for Model Interpretability

image

import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveExplanation(nn.Module):
    def __init__(self, model):
        super(ContrastiveExplanation, self).__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, fact_class, foil_class):
        # 입력 표현 생성
        hiddens = self.model.encoder(input_ids, attention_mask)[0]

        # 최종 예측을 위한 선형 변환
        logits = self.model.classifier(hiddens[:, 0, :])

        # Fact와 Foil 클래스에 해당하는 가중치 추출
        W_fact = self.model.classifier.weight[fact_class]
        W_foil = self.model.classifier.weight[foil_class]

        # 대조 방향 계산
        u = W_fact - W_foil

        # 대조 공간으로 투영
        P_u = u.unsqueeze(0) @ u.unsqueeze(1) / torch.norm(u)**2
        h_contrast = hiddens @ P_u.squeeze(1)

        # 토큰별 중요도 계산
        token_importance = torch.sum(hiddens * h_contrast, dim=-1)
        token_importance = token_importance.masked_fill(attention_mask == 0, float('-inf'))
        token_importance = F.softmax(token_importance, dim=-1)

        return token_importance

# 설명 생성 예시
model = ContrastiveExplanation(pretrained_model)
input_ids = torch.tensor([[101, 2023, 2003, 1037, 3231, 102]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]])
fact_class = torch.tensor(0)
foil_class = torch.tensor(1)

token_importance = model(input_ids, attention_mask, fact_class, foil_class)
print(token_importance)
long8v commented 3 months ago

TIFA: Accurate and Interpretable Text-to-Image Faithfulness Evaluation with Question Answering

link

image image image
long8v commented 2 months ago

VIEScore: Towards Explainable Metrics for Conditional Image Synthesis Evaluation

이미지와 생성된 이미지가 있을 때 score와 함께 설명까지 내보낼 수 있는 task를 제안

image image

생성 모델에 특화된 느낌.

image

perceptual quality와 semantic consistency로 나누어서 측정. 나로 따지면 semantic consistency만 관심 있음.

image

CLIPScore가 GPT4-V보다 human correlation이 매우 낮았다고 함. 일단 CLIP이 synthetic을 못본 이유도 있을 것 같고, 여기서 설명하기에는 text 전체를 abstraction으로 봐서 그렇다고 하넹

long8v commented 2 months ago

Attention is not Explanation

RNN에서의 attention score를 보고 해석하는건 한계가 있다. 가령 gradient based와 Attention based는 양상이 다르고, output이 똑같지만 attention map은 분포가 매우 다른 counter attention map 같은걸 만들 수 있음. 결론 부분만 보면 Recurrent가 아닌 구조에서 leave-one-out (LOO) 방법론은 gradient based 방법론이랑 Correlation이 높다고 함.

image

RNN이 아닌 다른 FFN에서는 이런 경향이 안일어났다고 함

image
long8v commented 2 months ago

Optimizing Relevance Maps of Vision Transformers Improves Robustness

image

ViT가 분류할 때 spurious cue에 집중(배경을 보고 분류)하는 경향이 있었고, 이를 해결하기 위해 relevance map을 활용해서 foreground에 집중해서 분류하도록 했더니 ood 성능이 좋아졌다는 내용