long8v / PTIR

Paper Today I Read
19 stars 0 forks source link

[146] Transformer Interpretability Beyond Attention Visualization #158

Open long8v opened 5 months ago

long8v commented 5 months ago
image

paper, code

TL;DR

Details

Method

image

transformer의 각 attention head에 대해서 LRP-based relevance를 구하고 이걸 gradient랑 결합해서 class-specific visualization을 하는게 목표

Relevance and gradients

chain rule로 n번째 레이어의 input x의 index j인 $xj^{(n)}$의 $y{t}$ (class t에 대한 model의 output)의 gradient는 아래와 같이 정의됨

image

여기서 $L^{(n)}(X, Y)$를 두개의 텐서 X, Y에 대한 layer operation이라고 정의하면 두 텐서는 feature map / weight가 되고 이걸 Deep Taylor Decomposition을 따르도록 하면 relevance는 아래와 같이 구해진다.

image

deep taylor decomposition은 taylor 근사를 통해 relevance를 구하는 방식 http://arxiv.org/abs/1512.02479

image

걍 gradient로 저 output을 근사한다~ 정도만 이해

conservative rule에 따라 n 번째의 레이어의 relevance의 합과 (n - 1)번째의 relevance의 합은 같아져야함

image

이것도 위의 논문에서 나오는데 f(x)가 relevance의 sum이랑 같아지는걸 의미.

image

LRP 논문에서는 activation으로 ReLU를 상정했기 때문에 positive value만 보도록 함

image

그런데 GeLU 같은걸 쓰게 되면 negative도 나올 수 있게 됨 그래서 이걸 걍 postivie subset인 애들에 대해서 구하는걸로 변경 (...? 이게 뭐가 달라지는지 모르겠음)

image image

그리고 맨 처음 초기화는 class t에 대한 one-hot 벡터로 설정

non parametric relevance propagation

transformer에는 두개의 feature map tensor가 mixing되는 두개의 연산이 있음 (<=> 아까는 weight와 feature map이었는데 이와는 다름) (1) skip connection (2) matrix multiplication임

두개의 tensor u, v가 있을 때 두 operator를 아래와 같이 정의함

image

앞에껀 u에 대한 relevance score / 뒤에껀 v에 대한 relevance score skip connection에서 relevance score가 너무 커지는 경향성이 있었음. 이를 해결하기 위해 normalization을 추가해줌

image

Relevance and gradient diffusion

image

self-attention연산에 대해서 위의 절차로 구하게 되면 아래와 같이 됨

image

$A^{(b)}$ : b번째 block의 attention map $E_h$: heads dimension 에 대해 평균 gradient의 positive 부분만 남김.

rollout의 경우 단순히 attention map에 대해서 iterative하게 곱하는 형식

image

Obtaining the image relevance map

결과적으로 나오는건 s x s의 matrix C 각 row는 해당 토큰이 다른 토큰에 대한 relevance map 이 연구에서는 분류 모델에 대해서만 집중했으니까 [CLS]토큰에 대한 relevance score만 구함 ViT의 경우에는 sequence 길이 s에서 [CLS] 는 빼고 $\sqrt{s-1} \times \sqrt{s-1}$ 이렇게 resize 한다음에 interpolate해서 구함

Result

중요도가 높다고 한 애들을 masking하는 방식으로 해서 top-1 accuracy가 어떻게 바뀌는지 확인. positive는 중요한 애들을 점점 지워나가는 것(낮으면 좋음)