long8v / PTIR

Paper Today I Read
19 stars 0 forks source link

[43] Relation Transformer Network #49

Open long8v opened 1 year ago

long8v commented 1 year ago
image

paper

TL;DR

Details

why two-stage?

이전에 읽었던 one-stage와 달리 bbox가 input으로 들어감! 즉 최종 output에 bbox가 있냐 없냐의 차이.

Architecture

image

Problem Decomposition

Object Detection

VGG16 backbone의 Faster RCNN을 쓸거다. node $ni$가 주어졌을 때 spatial embdding인 bbox coordinate을 뽑고, VGG16의 가장 위 레이어에서 feature map $I{feature}$를 뽑아 4096 차원의 feature vector $v_i$들을 뽑는다.
$o_i^{init}\in\mathbb{R}^C$의 경우 (C는 # of classes), GloVE 임베딩을 사용해서 초기화해줬다.

Encoder N2N Attention

object들의 context를 학습하는 것은 object detection 분 아니라 relation 분류를 할 때도 도움이 된다. 이러한 목적으로 object 들을 트랜스포머 인코더에 넣는다. 이때 들어가는 input은 아래와 같다. $v_i$는 image feature vector $s_i$는 object detection에서 나온 class label에 대한 GloVe 벡터 $b_i$는 bounding box다.

image

그리고 아래의 네트워크를 태운다. 그리고 마지막 레이어에서 해당 bbox에 대한 분류를 하게 된다.

image

또 $f_i^{final}$은 decoder cross-attention으로 들어가게 된다.

Decoder Edge Positional Encoding

transformer decoder에 Edge query를 넣어주고 싶은데, 사실 edge에 대해서는 순서가 없기 때문에 PE를 넣기가 애매하다.

image

그래서 위와 같이하면 뭐가 subject고 object인지 알 수 있게 된다. (?? 사실 식이 이해가 잘 안됨)

Decoder E2N Attention

Edge Queries내 $e_{ij}$랑 위의 부수적인 벡터

image image

edge간 self-attention 하는건 성능 향상에 도움이 안돼서 바로 cross-attention하는 방식으로 진행이 되었다.

Directed Relation Prediction Module(RPM)

relation은 방향이 있어서, 위에서 rich한 임베딩들을 가져와서 아래와 같이 directional relational embedding을 만들어주었다.

image

그리고 RPM(relation prediction module)이라고 하는 모듈에 위의 임베딩을 넣어줘서 최종 relation을 예측한다.

image

LayerNorm -> Linear -> ReLU -> Linear(최종 relation categories) -> softmax 취해주기

frequency에 대한 값도 넣어줬다.

image

Implementation Details

object detector에서 NMS(IoU > 0.3)를 통해 나온 top 64개의 object label을 사용했고, relation classification의 계산비용을 줄이기 위해 노드 페어 중 바운딩 박스가 겹치는 부분만 고려했다!(엄청 큰 inductive bias..)

Result

Visual Genome

image

Qualitative

image

(b)는 N2N attention heatmap. 한 object가 다른 object에 얼마나 영향을 줬냐. (c)는 E2N attention heatmap. object 들이 relation에 얼마나 영향을 줬냐.

on을 of로 틀리는 경우가 많았는데 보면 Face, of, Woman처럼 of가 더 자연스러운 경우가 많았다. -> multi-predict가 필요한 이유인듯

image

decoder 그냥 transformer꺼 쓰는 것(self-attention 들어간거)보다 우리꺼 쓰는게 나았다.

image

decoder에서 각 feature를 뺐을 때의 성능 drop은 아래와 같았다. frequency가 좀 쎈듯.

각각에 대한 저자들의 설명 image