
[ICLR 2024] Adaptive Self-training Framework for Fine-grained Scene Graph Generation

This repository is the official code for the ICLR 2024 paper Adaptive Self-training Framework for Fine-grained Scene Graph Generation.

Overview


Addressing the issue: Inherent Noise in Dataset

Proposed Framework (ST-SGG)

Installation

INSTALL.md

Dataset

DATASET.md

Pre-trained Object Detector

As the ST-SGG framework is composed of two stages, please download the pre-trained detector from the following source. Following most previous studies, we use Faster R-CNN as the detector. For the SGDet task, the detector trained on the VG-50 dataset extracts 80 proposals, each containing the object class distribution and visual features. Please put the pre-trained detector checkpoint (vg_faster_det.pth) in the Pretrained_model directory.
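As a minimal sketch, placing the checkpoint could look like the following, assuming vg_faster_det.pth was downloaded to the repository root:

```shell
# Create the directory expected by the run scripts
mkdir -p Pretrained_model
# Move the downloaded detector checkpoint into place
# (guarded in case the download has not been done yet)
if [ -f vg_faster_det.pth ]; then
  mv vg_faster_det.pth Pretrained_model/
fi
```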

Training

Here we mainly describe how to run the Motif model on the PredCls task; the other models (e.g., VCTree, BGNN, and HetSGG) can be run in the same way. Likewise, the SGCls and SGDet tasks are executed by changing predcls to sgcls or sgdet in the shell script paths below. To train with multiple GPUs, change the variable mutli_gpu=false to mutli_gpu=true in each shell script. To re-train a model with ST-SGG, set the pre-trained model path (PRETRAINED_WEIGHT_PATH) in the shell script.
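Concretely, the two changes above amount to editing shell variables like the following in the chosen run script (the variable names follow the scripts as described; the checkpoint path shown here is only a placeholder):

```shell
# Enable multi-GPU training (the scripts spell the variable "mutli_gpu")
mutli_gpu=true
# Checkpoint from the pretraining stage, used to initialize ST-SGG re-training
# (placeholder path -- point it at your own pretrained model)
PRETRAINED_WEIGHT_PATH="checkpoints/motif-predcls-pretrain/model_final.pth"
```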

Vanilla

# Pretrain the model
bash run_shell/motif/predcls/vanilla/pretrain.sh
# Re-train with ST-SGG framework
bash run_shell/motif/predcls/vanilla/train_stsgg.sh

Resampling (bilvl)

# Pretrain the model
bash run_shell/motif/predcls/bilvl/pretrain.sh
# Re-train with ST-SGG framework
bash run_shell/motif/predcls/bilvl/train_stsgg.sh

Debiasing - Internal transfer (IE-Trans)

Instead of running the relabeling code for I-Trans (i.e., pretrain.sh, internal.sh, external.sh, and combine.sh), you can download the dataset file with I-Trans already applied; see DATASET.md.

# Pretrain the model for relabeling
bash run_shell/motif/predcls/ie_trans/relabel/pretrain.sh
# Internal Transfer / External Transfer
bash run_shell/motif/predcls/ie_trans/relabel/internal.sh
bash run_shell/motif/predcls/ie_trans/relabel/external.sh
# In combine.sh, external transfer is excluded
bash run_shell/motif/predcls/ie_trans/relabel/combine.sh
# Train the I-Trans model
bash run_shell/motif/predcls/ie_trans/train.sh
# Re-train the I-Trans model with ST-SGG framework
bash run_shell/motif/predcls/ie_trans/train_stsgg.sh

For the SGCls and SGDet tasks, there is no need to execute the pretrain.sh, internal.sh, external.sh, and combine.sh scripts, as they are used only for dataset pre-processing.

# Train the I-Trans model
bash run_shell/motif/sgdet/ie_trans/train.sh
# Re-train the I-Trans model with ST-SGG framework
bash run_shell/motif/sgdet/ie_trans/train_stsgg.sh

Evaluation

# Evaluate the trained model
bash run_shell/evaluation.sh

If you want to evaluate specific checkpoints, put the iteration numbers of the trained models in the checkpoint_list variable in run_shell/evaluation.sh.
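For instance, evaluating the checkpoints saved at iterations 20000 and 30000 might be configured as below (the iteration numbers and the exact value format are assumptions; check run_shell/evaluation.sh for the convention it expects):

```shell
# Iteration numbers of the trained checkpoints to evaluate
# (illustrative values)
checkpoint_list="20000 30000"
```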

For evaluating our pre-trained models, please refer to MODEL.md

Directory Structure for Shell files

run_shell
├── evaluation.sh
├── evaluation4pretrained_model.sh
├── motif
│   ├── predcls
│   │   ├── vanilla
│   │   │   ├── pretrain.sh
│   │   │   └── train_stsgg.sh
│   │   ├── bilvl
│   │   │   ├── pretrain.sh
│   │   │   └── train_stsgg.sh
│   │   └── ie_trans
│   │       ├── relabel
│   │       │   ├── combine.sh
│   │       │   ├── external.sh
│   │       │   ├── internal.sh
│   │       │   └── pretrain.sh
│   │       ├── train.sh
│   │       └── train_stsgg.sh
│   ├── sgcls
│   │   ...
│   └── sgdet
│       ...
└── bgnn
    ├── predcls
    │   ...

Citation

@article{kim2024adaptive,
  title={Adaptive Self-training Framework for Fine-grained Scene Graph Generation},
  author={Kim, Kibum and Yoon, Kanghoon and In, Yeonjun and Moon, Jinyoung and Kim, Donghyun and Park, Chanyoung},
  journal={arXiv preprint arXiv:2401.09786},
  year={2024}
}

Acknowledgement

The code is developed on top of Scene-Graph-Benchmark.pytorch, IE-Trans, and BGNN.