Closed long8v closed 6 months ago
In-Context Learning 에서 shallow layer에서 Label이 관련된 정보를 aggregate하고 LLM의 Final prediction에서 다시 그 Label들을 참고해서 예측한다는 관찰
import torch
import torch.nn as nn
import torch.nn.functional as F
class ContrastiveExplanation(nn.Module):
def __init__(self, model):
super(ContrastiveExplanation, self).__init__()
self.model = model
def forward(self, input_ids, attention_mask, fact_class, foil_class):
# 입력 표현 생성
hiddens = self.model.encoder(input_ids, attention_mask)[0]
# 최종 예측을 위한 선형 변환
logits = self.model.classifier(hiddens[:, 0, :])
# Fact와 Foil 클래스에 해당하는 가중치 추출
W_fact = self.model.classifier.weight[fact_class]
W_foil = self.model.classifier.weight[foil_class]
# 대조 방향 계산
u = W_fact - W_foil
# 대조 공간으로 투영
P_u = u.unsqueeze(0) @ u.unsqueeze(1) / torch.norm(u)**2
h_contrast = hiddens @ P_u.squeeze(1)
# 토큰별 중요도 계산
token_importance = torch.sum(hiddens * h_contrast, dim=-1)
token_importance = token_importance.masked_fill(attention_mask == 0, float('-inf'))
token_importance = F.softmax(token_importance, dim=-1)
return token_importance
# 설명 생성 예시
model = ContrastiveExplanation(pretrained_model)
input_ids = torch.tensor([[101, 2023, 2003, 1037, 3231, 102]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]])
fact_class = torch.tensor(0)
foil_class = torch.tensor(1)
token_importance = model(input_ids, attention_mask, fact_class, foil_class)
print(token_importance)
이미지와 생성된 이미지가 있을 때 score와 함께 설명까지 내보낼 수 있는 task를 제안
생성 모델에 특화된 느낌.
perceptual quality와 semantic consistency로 나누어서 측정. 나로 따지면 semantic consistency만 관심 있음.
CLIPScore가 GPT4-V보다 human correlation이 매우 낮았다고 함. 일단 CLIP이 synthetic을 못본 이유도 있을 것 같고, 여기서 설명하기에는 text 전체를 abstraction으로 봐서 그렇다고 하넹
RNN에서의 attention score를 보고 해석하는건 한계가 있다. 가령 gradient based와 Attention based는 양상이 다르고, output이 똑같지만 attention map은 분포가 매우 다른 counter attention map 같은걸 만들 수 있음. 결론 부분만 보면 Recurrent가 아닌 구조에서 leave-one-out (LOO) 방법론은 gradient based 방법론이랑 Correlation이 높다고 함.
RNN이 아닌 다른 FFN에서는 이런 경향이 안일어났다고 함
ViT가 분류할 때 spurious cue에 집중(배경을 보고 분류)하는 경향이 있었고, 이를 해결하기 위해 relevance map을 활용해서 foreground에 집중해서 분류하도록 했더니 ood 성능이 좋아졌다는 내용
Network Dissection
Quantifying Interpretability of Deep Visual Representations
https://arxiv.org/pdf/1704.05796.pdf
unit이 어떻게 이렇게 label이 달렸나가 궁금했는데 선행연구(https://arxiv.org/pdf/1412.6856.pdf)에서 part에 대한 annotation을 단 거 였음..