long8v / PTIR

Paper Today I Read
19 stars 0 forks source link

[35] RelTR: Relation Transformer for Scene Graph Generation #40

Open long8v opened 2 years ago

long8v commented 2 years ago

image

paper, code

TL;DR

Details

전체적으로 3개의 아키텍쳐로 구성되어 있는데 A) feature encoder extracting the visual feature context -> Z B) entity decoder capturing visual feature context -> Q_e C) triplet decoder with subject and object branches -> Q_s, Q_o, E_t 1) subject and object queries DETR의 object query 같은 것임. d차원의 임베딩. triplet에 대한 거랑 전혀 상관없음. triplet을 표현하기 위해서 triplet encoding이라는 것이 따로 존재함.

2) Coupled Self-Attention(CSA) triplet encoding과 같은 크기의 subject encoding, object encoding을 추가함. 그리고 아래의 연산으로 CSA를 구함. image

3) Decoupled Visual Attention(DVA) visual feature에 집중하는 feature context Z를 만듦. 여기서 decouple이란 단어는 해당 object가 subject인지 object인지는 상관없기 때문임. subject branch와 object branch 둘다에서 아래와 같은 DVA 연산이 일어남. image

4) Decoupled Entity Attention(DEA) entity detection과 triplet detection 사이를 메꿔주는 역할을 한다. entity를 따로 주는 이유는 SPO 관계에 대한 제약이 없으므로 더 나은 localization을 할 것으로 기대한다. image

DEA의 ouput은 FFN을 통해서 최종 원하는 output으로 만들어진다. bbox regression은 아래 FFN을 통해서, center, width, height를 예측하게 한다.
image

DVA 레이어에서 나온 subject에 대한 attention heatmap M_s와 object attention heatmap M_o가 concat된 뒤, 아래 convolutional mask head를 통해 spatial feature vector로 바뀌게 된다. -> visual encoder가 각 S, O 뽑을 때 어딜 봤는지를 보고 relation 구함. image

Set Prediction Loss for Triplet Detection

triplet prediction을 함 y_sub, c_prd, y_obj. 이때 y_sub, y_obj는 bbox와 class 두개로 이루어짐. 이 triplet에서 GT triplet(<background-no relation-background>로 pad됨)과 bipartite matching을 통해 구해지는데 구할 때의 matching cost는 1) subject cost 2) object cost 3) predicate cost로 구성됨. subjet / object cost들은 class loss와 bbox loss로 구성되는데, class cost는 아래와 같고 image bbox cost는 아래와 같음 image

triplet prediction cost는 여기에 predicate에 대한 class cost까지 합쳐서 구해짐. 이후 이 cost들로 hungarian method로 bipartite matching을 한 뒤, loss가 구해짐. 몇 번 학습하면 아래와 같이 의미없는 triplet을 내뿜는데 이를 방지하기 위해서 추가적인 IoU based 룰을 추가함. image

sub / obj class가 background고 IoU btw. gt and pred가 threshold 이상일 때, subject나 object에 대한 loss를 부가하지 않도록 함.
image

최종적인 loss는 아마 DETR에서 나왔을 entity에 대한 loss와 triplet loss를 합쳐서 계산.

Dataset

Implementation Details

Results

image

잡 생각/질문