long8v / PTIR

Paper Today I Read
19 stars 0 forks source link

[44] Context-Aware Scene Graph Generation With Seq2Seq Transformers #50

Open long8v opened 1 year ago

long8v commented 1 year ago

image

paper

TL;DR

Details

Architecture

image

Object Encoder

그냥 트랜스포머 인코더. 근데 input으로 뭘 넣어줬다는지 잘 모르겠음. 그냥 visual feature map이려나? $X_b$는 b번째 트랜스포머 block의 output

Relationship Decoder

contextualized object features $XB\in \mathbb{R}^{N\times D}$(N은 object 개수고 D는 임베딩 차원인듯)와 그 전 step까지 예측된 relationship $\hat Y{1:m}$을 받아서 m(+1)번째 relationship을 뽑는 일을 함.

이때 decoder의 input은 subject의 contextualized embedding과 object의 contextualized embedding, 이전에 뽑힌 relation에 대한 임베딩값을 concat해서 들어감. $(X_B[i], E[r], X_B[j])$ 그러니까 이전에 예측한 걸 임베딩해서 넣어주면 다음거가 나오는 특이한 구조임. concat한걸 D차원을 ffn 하고 self-attention, cross-attention을 통과함. 처음에는 그냥 D차원짜리 <SOS>를 넣어줌. cross-attention의 경우에 decoder의 self-attention으로 나온 $Y_k$와 encoder에서 나온 $X_B$랑 걸어줘서 나옴. image

마지막 K번째 decoder layer의 output $Y_K$를 가지고 다음 relationship triplet을 예측함. 모든 남은 pair에 대해서 아래와 같이 예측함. 그리고 softmax가 가장 높은 것이 선택됨. image

$i$ : subject indices, $j$ : object indicies

Training scheme

Reinforcement Learning

1) training시에는 input history를 GT로 받지만(teacher-forcing) inference에서는 그렇지 않음 2) cross entropy loss와 recall 사이의 gap이 있음. -> 디코딩 할 때 강화학습 요소를 추가하자. recall과 mRecall은 반대로 움직이는 성향이 있음. 그래서 alpha 추가하여 reward로 정의함. image

image

여기서 action은 모든 pair에 대해서 logit 값이 나왔을 때 어떤걸 선택할지. state는 m개를 선택한 상태. RL 적용하니 greedy decoding보다 나았다.

Expreiments

image

Qualitative Results

image

independent하게 예측하는 것보다 gt 맞출 확률이 높아졌다.