Welcome to the official repository for the method presented in "LAVT: Language-Aware Vision Transformer for Referring Image Segmentation."
Code in this repository is written using PyTorch and is organized in the following way (assuming the working directory is the root directory of this repository):
- ./lib contains files implementing the main network. Within ./lib, _utils.py defines the highest-level model, which incorporates the backbone network defined in backbone.py and the simple mask decoder defined in mask_predictor.py (a schematic sketch of how these pieces fit together follows this list). segmentation.py provides the model interface and initialization functions.
- ./bert contains files migrated from Hugging Face Transformers v3.0.2, which implement the BERT language model. We used Transformers v3.0.2 during development, but it had a bug that would appear when using DistributedDataParallel. Therefore we maintain a copy of the relevant source files in this repository. This way, the bug is fixed and the code in this repository is self-contained.
- ./train.py is invoked to train the model.
- ./test.py is invoked to run inference on the evaluation subsets after training.
- ./refer contains data pre-processing code and is also where data should be placed, including the images and all annotations. It is cloned from refer.
- ./data/dataset_refer_bert.py is where the dataset class is defined.
- ./utils.py defines functions that track training statistics and setup functions for DistributedDataParallel.
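The following is a minimal, illustrative sketch of how the pieces listed above fit together. The class and argument names are hypothetical and are not the repository's actual API; see lib/_utils.py and lib/segmentation.py for the real definitions.

```python
# Illustrative only: hypothetical names; the real model lives in lib/_utils.py.
import torch.nn as nn

class LAVTSketch(nn.Module):
    """Highest-level model: language-aware vision backbone + simple mask decoder."""
    def __init__(self, backbone, mask_decoder, bert_model):
        super().__init__()
        self.backbone = backbone     # Swin-based vision backbone (lib/backbone.py)
        self.decoder = mask_decoder  # simple mask decoder (lib/mask_predictor.py)
        self.bert = bert_model       # BERT language model (./bert)

    def forward(self, image, input_ids, attention_mask):
        # Per-token language features from BERT (Transformers v3.0.2-style output tuple).
        lang_feats = self.bert(input_ids, attention_mask=attention_mask)[0]
        # Multi-scale visual features fused with language inside the backbone.
        vis_feats = self.backbone(image, lang_feats, attention_mask)
        # Per-pixel mask logits.
        return self.decoder(vis_feats)
```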
Updates:

April 13th, 2023. Using the Dice loss instead of the cross-entropy loss can improve results. We will add code and release weights when we get a chance.

June 21st, 2022. Uploaded the training logs and trained model weights of lavt_one.
June 9th, 2022. Added a more efficient implementation of LAVT. To use it, specify --model as lavt_one (lavt is still valid for specifying the old model); the rest of the configuration stays unchanged. With this version, DistributedDataParallel needs to be applied only once. Applying it twice (on the standalone language model and on the main branch), as done in the old implementation, led to low GPU utilization, which slowed down training. We recommend training this model on 8 GPUs (and, as before, with batch size 32).

The code has been verified to work with PyTorch v1.7.1 and Python 3.7.
Clone this repository.
Change directory to the root of this repository.
Create a new Conda environment with Python 3.7 then activate it:
conda create -n lavt python==3.7
conda activate lavt
Install PyTorch v1.7.1 with a CUDA version that works on your cluster/machine (CUDA 10.2 is used in this example):
conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.2 -c pytorch
Install the packages in requirements.txt via pip:
pip install -r requirements.txt
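As an optional sanity check of the environment (a minimal sketch, assuming the conda environment above is active), you can verify the installed versions and CUDA visibility:

```python
# Optional sanity check: verify the installed versions and CUDA visibility.
import torch
import torchvision

print(torch.__version__)          # expected: 1.7.1
print(torchvision.__version__)    # expected: 0.8.2
print(torch.cuda.is_available())  # should be True on a CUDA-enabled machine
```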
Follow the instructions in the ./refer directory to set up subdirectories and download annotations. This directory is a git clone (minus two data files that we do not need) from the refer public API.

Download images from COCO. Please use the first download link, 2014 Train images [83K/13GB], and extract the downloaded train_2014.zip file to ./refer/data/images/mscoco/images.
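Optionally, a small sketch to confirm the images are where the data loading code will look for them; it assumes the zip extracts into a train2014 subfolder under the path above.

```python
# Optional check that the COCO train2014 images are in place (path per the instructions above).
import os

image_dir = "./refer/data/images/mscoco/images/train2014"  # assumes the zip extracted to a train2014 folder
if os.path.isdir(image_dir):
    print(f"Found {len(os.listdir(image_dir))} files in {image_dir}")
else:
    print(f"Missing directory: {image_dir}")
```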
Create the ./pretrained_weights directory, where we will be storing the weights:

mkdir ./pretrained_weights

Download the pre-trained classification weights of the Swin Transformer and put the .pth file in ./pretrained_weights. These weights are needed for training to initialize the model.
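Optionally, you can peek inside the downloaded checkpoint before training. The filename below is the one referenced by the training commands further down; the structure printed (e.g., a top-level 'model' entry) is what typical Swin classification releases contain, stated here as an assumption rather than a guarantee.

```python
# Optional: peek inside the downloaded Swin checkpoint.
import torch

path = "./pretrained_weights/swin_base_patch4_window12_384_22k.pth"
ckpt = torch.load(path, map_location="cpu")
print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys())[:5])  # typical Swin releases store the state_dict under a 'model' key
```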
Create the ./checkpoints directory, where we will be storing the weights:

mkdir ./checkpoints

Download the trained LAVT weights using the links in the table below and put them in ./checkpoints.

RefCOCO | RefCOCO+ | G-Ref (UMD) | G-Ref (Google) |
---|---|---|---|
log \| weights | log \| weights | log \| weights | log \| weights |
Note that the validation scores recorded in the training logs are not identical to those obtained by running test.py, because only one out of multiple annotated expressions is randomly selected and evaluated for each object during training. But these numbers give a good idea about the test performance, and the two should be fairly close.

For training, we use DistributedDataParallel from PyTorch. The released lavt weights were trained using 4 x 32G V100 cards (max mem on each card was about 26G). The released lavt_one weights were trained using 8 x 32G V100 cards (max mem on each card was about 13G); more cards were used simply to accelerate training.
To run on 4 GPUs (with IDs 0, 1, 2, and 3) on a single node:
mkdir ./models
mkdir ./models/refcoco
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train.py --model lavt --dataset refcoco --model_id refcoco --batch-size 8 --lr 0.00005 --wd 1e-2 --swin_type base --pretrained_swin_weights ./pretrained_weights/swin_base_patch4_window12_384_22k.pth --epochs 40 --img_size 480 2>&1 | tee ./models/refcoco/output
mkdir ./models/refcoco+
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train.py --model lavt --dataset refcoco+ --model_id refcoco+ --batch-size 8 --lr 0.00005 --wd 1e-2 --swin_type base --pretrained_swin_weights ./pretrained_weights/swin_base_patch4_window12_384_22k.pth --epochs 40 --img_size 480 2>&1 | tee ./models/refcoco+/output
mkdir ./models/gref_umd
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train.py --model lavt --dataset refcocog --splitBy umd --model_id gref_umd --batch-size 8 --lr 0.00005 --wd 1e-2 --swin_type base --pretrained_swin_weights ./pretrained_weights/swin_base_patch4_window12_384_22k.pth --epochs 40 --img_size 480 2>&1 | tee ./models/gref_umd/output
mkdir ./models/gref_google
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train.py --model lavt --dataset refcocog --splitBy google --model_id gref_google --batch-size 8 --lr 0.00005 --wd 1e-2 --swin_type base --pretrained_swin_weights ./pretrained_weights/swin_base_patch4_window12_384_22k.pth --epochs 40 --img_size 480 2>&1 | tee ./models/gref_google/output
- --model specifies the model version: both lavt and lavt_one are valid. See the updates above.
- --dataset selects the dataset: refcoco, refcoco+, and refcocog.
- --splitBy specifies the partition of refcocog (G-Ref): umd identifies the UMD partition and google identifies the Google partition.
- --model_id names the run: the training log will be saved as ./models/[args.model_id]/output and the best checkpoint will be saved as ./checkpoints/model_best_[args.model_id].pth.
- --swin_type selects the Swin Transformer variant: tiny, small, base, or large. The default is base.
- Note that currently the ./models/[args.model_id] directory needs to be created manually via mkdir before running train.py. This is because we use tee to redirect stdout and stderr to ./models/[args.model_id]/output for logging. This is a nuisance and should be resolved in the future, e.g., with a proper logger or a bash script for initiating training.

For RefCOCO/RefCOCO+, run one of:
python test.py --model lavt --swin_type base --dataset refcoco --split val --resume ./checkpoints/refcoco.pth --workers 4 --ddp_trained_weights --window12 --img_size 480
python test.py --model lavt --swin_type base --dataset refcoco+ --split val --resume ./checkpoints/refcoco+.pth --workers 4 --ddp_trained_weights --window12 --img_size 480
- --split is the evaluation subset: choose from val, testA, and testB.

For G-Ref (UMD)/G-Ref (Google), run one of:
python test.py --model lavt --swin_type base --dataset refcocog --splitBy umd --split val --resume ./checkpoints/gref_umd.pth --workers 4 --ddp_trained_weights --window12 --img_size 480
python test.py --model lavt --swin_type base --dataset refcocog --splitBy google --split val --resume ./checkpoints/gref_google.pth --workers 4 --ddp_trained_weights --window12 --img_size 480
- --splitBy specifies the partition to evaluate: umd or google.
- --split is the evaluation subset: val and test for the UMD partition, and only val for the Google partition.

Test results of the released lavt weights:

Dataset | P@0.5 | P@0.6 | P@0.7 | P@0.8 | P@0.9 | Overall IoU | Mean IoU |
---|---|---|---|---|---|---|---|
RefCOCO val | 84.46 | 80.90 | 75.28 | 64.71 | 34.30 | 72.73 | 74.46 |
RefCOCO test A | 88.07 | 85.17 | 79.90 | 68.52 | 35.69 | 75.82 | 76.89 |
RefCOCO test B | 79.12 | 74.94 | 69.17 | 59.37 | 34.45 | 68.79 | 70.94 |
RefCOCO+ val | 74.44 | 70.91 | 65.58 | 56.34 | 30.23 | 62.14 | 65.81 |
RefCOCO+ test A | 80.68 | 77.96 | 72.90 | 62.21 | 32.36 | 68.38 | 70.97 |
RefCOCO+ test B | 65.66 | 61.85 | 55.94 | 47.56 | 27.24 | 55.10 | 59.23 |
G-Ref val (UMD) | 70.81 | 65.28 | 58.60 | 47.49 | 22.73 | 61.24 | 63.34 |
G-Ref test (UMD) | 71.54 | 66.38 | 59.00 | 48.21 | 23.10 | 62.09 | 63.62 |
G-Ref val (Goog.) | 71.16 | 67.21 | 61.76 | 51.98 | 27.30 | 60.50 | 63.66 |
- We have validated LAVT on RefCOCO with multiple runs. The overall IoU on the val set generally lies in the range of 72.73±0.5%.
Test results of the more efficient implementation (lavt_one):

Dataset | P@0.5 | P@0.6 | P@0.7 | P@0.8 | P@0.9 | Overall IoU | Mean IoU |
---|---|---|---|---|---|---|---|
RefCOCO val | 85.87 | 82.13 | 76.64 | 65.45 | 35.30 | 73.50 | 75.41 |
RefCOCO test A | 88.47 | 85.63 | 80.57 | 68.84 | 35.71 | 75.97 | 77.31 |
RefCOCO test B | 80.20 | 76.49 | 70.34 | 60.12 | 34.94 | 69.33 | 71.86 |
RefCOCO+ val | 76.19 | 72.27 | 66.82 | 56.87 | 30.15 | 63.79 | 67.65 |
RefCOCO+ test A | 82.50 | 79.44 | 74.00 | 63.27 | 31.99 | 69.79 | 72.53 |
RefCOCO+ test B | 68.03 | 63.35 | 57.29 | 47.92 | 26.98 | 56.49 | 61.22 |
G-Ref val (UMD) | 75.82 | 71.06 | 63.99 | 52.98 | 27.31 | 64.02 | 67.41 |
G-Ref test (UMD) | 76.12 | 71.13 | 64.58 | 53.62 | 28.03 | 64.49 | 67.45 |
G-Ref val (Goog.) | 72.57 | 68.65 | 63.09 | 53.33 | 28.14 | 61.31 | 64.84 |
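For reference, the columns in the tables above (P@0.5 through P@0.9, Overall IoU, and Mean IoU) follow the standard referring-segmentation definitions. The sketch below is an illustrative implementation of those definitions, not the exact evaluation code in test.py.

```python
# Illustrative metric computation for referring segmentation (not the repo's exact code).
import numpy as np

def summarize(pred_masks, gt_masks, thresholds=(0.5, 0.6, 0.7, 0.8, 0.9)):
    ious, total_inter, total_union = [], 0, 0
    for pred, gt in zip(pred_masks, gt_masks):       # boolean arrays of the same shape
        inter = np.logical_and(pred, gt).sum()
        union = np.logical_or(pred, gt).sum()
        ious.append(inter / union if union > 0 else 0.0)
        total_inter += inter
        total_union += union
    ious = np.asarray(ious)
    precision_at = {t: float((ious > t).mean()) for t in thresholds}            # fraction of samples with IoU > t
    overall_iou = total_inter / total_union if total_union > 0 else 0.0         # cumulative intersection / union
    mean_iou = float(ious.mean())                                               # average per-sample IoU
    return precision_at, overall_iou, mean_iou
```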
You can run inference on any image-text pair and visualize the result by running the script ./demo_inference.py.
Have fun!
If you find LAVT useful in your work, please consider citing:

@inproceedings{yang2022lavt,
title={LAVT: Language-Aware Vision Transformer for Referring Image Segmentation},
author={Yang, Zhao and Wang, Jiaqi and Tang, Yansong and Chen, Kai and Zhao, Hengshuang and Torr, Philip HS},
booktitle={CVPR},
year={2022}
}
We appreciate all contributions. It helps the project if you could
Code in this repository is built upon several public repositories. Some of these repositories in turn adapt code from OpenMMLab and TorchVision. We'd like to thank the authors/organizations of these repositories for open sourcing their projects.
GNU GPLv3