Wentong Li*, Yuqian Yuan*, Song Wang, Wenyu Liu, Dongqi Tang, Jian Liu, Jianke Zhu, and Lei Zhang
Paper | NeurIPS2023
We propose a method named APro, designed to generate precise soft pseudo labels online for unlabeled regions within segmentation networks.
This branch focuses on weakly box-supervised instance segmentation. It is built upon the SOLOv2 and Mask2Former frameworks and follows the BoxInstSeg codebase. We conduct multiple instance segmentation experiments on Pascal VOC and COCO to verify the effectiveness of APro.
🌟 Our APro method includes global affinity propagation and local affinity propagation. The code can be found in `apro`, and its usage is described below.
💡 APro can be plugged into existing segmentation networks for various tasks to achieve weakly-supervised segmentation with label-efficient sparse annotations.
This branch is built on MMDetection (v2.25.0). Please refer to Installation and Getting Started for installation details and basic usage. We also recommend referring to the official introduction of MMDetection.
First, compile the `gp_cuda` op for global affinity propagation:
cd apro/gp_cuda
python setup.py build develop
Then, import the global and local affinity propagation modules from `apro.apro`, and `MinimumSpanningTree` from `apro.gp_cuda.mst.mst`:
import torch  # needed for the loss computation below
from apro.apro import Global_APro, Local_APro
from apro.gp_cuda.mst.mst import MinimumSpanningTree
global_apro = Global_APro()
local_apro = Local_APro(kernel_size=5, zeta_s=0.15)  # set kernel_size and zeta_s
mst = MinimumSpanningTree(Global_APro.norm2_distance)  # edge weights from Global_APro.norm2_distance
First, build a minimum spanning tree based on the input image.
img_mst_tree = mst(image)
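To illustrate what the minimum spanning tree step does, here is a pure NumPy sketch that treats each pixel as a node, connects 4-neighbours, weights each edge by the L2 distance between pixel values, and runs Kruskal's algorithm. It is not the repository's CUDA `MinimumSpanningTree`: the 4-connectivity and the L2 weighting (suggested by the name `norm2_distance`) are assumptions, and `grid_edges`, `kruskal_mst` and `build_grid_mst` are hypothetical helpers.

```python
import numpy as np


def grid_edges(x):
    """Return (u, v, w) for 4-connected neighbours of an (H, W[, C]) array.

    u, v are flat pixel indices; w is the L2 distance between the two
    pixels' C-dimensional values (assumed weighting for this sketch).
    """
    h, w = x.shape[:2]
    idx = np.arange(h * w).reshape(h, w)
    x = x.reshape(h, w, -1).astype(np.float64)
    # Horizontal neighbours: (i, j) -- (i, j + 1)
    u_h, v_h = idx[:, :-1].ravel(), idx[:, 1:].ravel()
    w_h = np.linalg.norm(x[:, :-1] - x[:, 1:], axis=-1).ravel()
    # Vertical neighbours: (i, j) -- (i + 1, j)
    u_v, v_v = idx[:-1, :].ravel(), idx[1:, :].ravel()
    w_v = np.linalg.norm(x[:-1, :] - x[1:, :], axis=-1).ravel()
    return (np.concatenate([u_h, u_v]), np.concatenate([v_h, v_v]),
            np.concatenate([w_h, w_v]))


def kruskal_mst(num_nodes, u, v, w):
    """Kruskal's algorithm with union-find; returns a list of (u, v, w)."""
    parent = list(range(num_nodes))

    def find(a):
        while parent[a] != a:
            parent[a] = parent[parent[a]]  # path halving
            a = parent[a]
        return a

    tree = []
    for k in np.argsort(w, kind="stable"):  # cheapest edges first
        ru, rv = find(int(u[k])), find(int(v[k]))
        if ru != rv:  # edge joins two components, so keep it
            parent[ru] = rv
            tree.append((int(u[k]), int(v[k]), float(w[k])))
            if len(tree) == num_nodes - 1:
                break
    return tree


def build_grid_mst(x):
    """MST over the pixel grid of an (H, W[, C]) image or feature map."""
    u, v, w = grid_edges(x)
    return kruskal_mst(x.shape[0] * x.shape[1], u, v, w)
```

For a 2x2 image whose left column is black and right column is white, the two zero-weight vertical edges are kept first and a single weight-1 edge joins the two columns, so the tree crosses the colour boundary only once.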
Then, call `global_apro` to obtain the soft pseudo label `soft_pseudo`:
soft_pseudo = global_apro(mask_pred, image, img_mst_tree, sigma=0.01)
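A rough sketch of how a tree can turn predictions into a dense soft pseudo label: every node gathers the predictions of all other nodes, weighted by an affinity computed along the unique tree path between them. We assume here that the path distance is the largest edge weight on that path and that the affinity is `exp(-distance / sigma)`; both are illustrative choices, not a statement of how `Global_APro` is implemented. `tree_max_edge` and `global_propagate` are hypothetical helpers, and the brute-force O(N²) pairwise pass is only meant for tiny graphs.

```python
from collections import defaultdict

import numpy as np


def tree_max_edge(num_nodes, tree):
    """For every pair (i, j), the largest edge weight on the tree path i -> j."""
    adj = defaultdict(list)
    for a, b, w in tree:
        adj[a].append((b, w))
        adj[b].append((a, w))
    out = np.zeros((num_nodes, num_nodes))
    for src in range(num_nodes):
        # Depth-first walk from src, carrying the max edge seen so far.
        stack = [(src, -1, 0.0)]
        while stack:
            node, par, m = stack.pop()
            out[src, node] = m
            for nb, w in adj[node]:
                if nb != par:
                    stack.append((nb, node, max(m, w)))
    return out


def global_propagate(prob, tree, sigma):
    """Normalised affinity-weighted average of prob over all tree nodes."""
    p = np.asarray(prob, dtype=np.float64).ravel()
    dist = tree_max_edge(p.size, tree)
    aff = np.exp(-dist / sigma)  # assumed affinity form
    return (aff @ p) / aff.sum(axis=1)
```

With a small `sigma`, nodes separated by a strong edge barely influence each other, so each region is averaged on its own; a very large `sigma` pushes every output toward the global mean.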
You can also further refine `soft_pseudo` with a deep feature map `feat`:
soft_pseudo = global_apro(soft_pseudo, feat, img_mst_tree, sigma=0.07)
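The call above reuses `img_mst_tree` together with `feat`. One reading of this, which we assume for the sketch below, is that the tree topology comes from the image while the edge weights are recomputed from the deep features. `reweight_tree` is a hypothetical helper that does this for a channels-last `(H, W, C)` feature map, with trees given as lists of `(u, v, w)` flat-index edges.

```python
import numpy as np


def reweight_tree(tree, feat):
    """Keep the tree topology, recompute edge weights from a feature map.

    tree: list of (u, v, w) with u, v flat indices into an H x W grid.
    feat: (H, W, C) array; the new weight is the L2 distance between the
    two nodes' C-dimensional feature vectors (assumed weighting).
    """
    flat = feat.reshape(feat.shape[0] * feat.shape[1], -1).astype(np.float64)
    return [(u, v, float(np.linalg.norm(flat[u] - flat[v]))) for u, v, _ in tree]
```

The reweighted tree would then drive the same kind of propagation as the image-guided pass, so semantic similarity rather than raw colour decides how far labels spread.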
The global loss can then be computed; taking box supervision as an example:
loss_global_term = torch.abs(soft_pseudo - mask_pred) * box_mask_target
box_regions = box_mask_target.sum((1, 2, 3)).clamp(min=1)
loss_global_term = loss_global_term.sum((1, 2, 3)) / box_regions
loss_global = loss_global_term
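`box_mask_target` is not defined in the snippet above. Under box supervision it is typically a binary mask that is 1 inside each ground-truth box; the NumPy sketch below assumes that meaning, an `(N, 1, H, W)` layout matching the `sum((1, 2, 3))` reductions, and integer boxes with exclusive `x2`/`y2`. `boxes_to_mask` and `box_normalised_l1` are hypothetical helpers that mirror the loss lines above.

```python
import numpy as np


def boxes_to_mask(boxes, height, width):
    """Rasterise (x1, y1, x2, y2) boxes into an (N, 1, H, W) binary mask.

    Assumes integer pixel coordinates with exclusive x2 / y2 (a convention
    chosen for this sketch).
    """
    boxes = np.asarray(boxes, dtype=np.int64)
    mask = np.zeros((len(boxes), 1, height, width), dtype=np.float32)
    for n, (x1, y1, x2, y2) in enumerate(boxes):
        mask[n, 0, y1:y2, x1:x2] = 1.0
    return mask


def box_normalised_l1(soft_pseudo, mask_pred, box_mask_target):
    """Per-instance mean of |soft_pseudo - mask_pred| inside each box."""
    term = np.abs(soft_pseudo - mask_pred) * box_mask_target
    # Clamp at 1 so an empty box does not divide by zero.
    box_regions = np.clip(box_mask_target.sum(axis=(1, 2, 3)), 1, None)
    return term.sum(axis=(1, 2, 3)) / box_regions
```

Dividing by the box area makes small and large instances contribute on the same scale, instead of large boxes dominating the loss.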
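For local affinity propagation, call `local_apro` with the image, the mask prediction, and the box mask:
soft_pseudo = local_apro(image, mask_pred, box_mask_target)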
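For intuition about local propagation, the sketch below refines a probability map by averaging it over a `kernel_size x kernel_size` window, weighting each neighbour by a colour affinity `exp(-||I_i - I_j||^2 / zeta_s^2)`. The Gaussian form, the edge padding and the `iters` parameter are assumptions for illustration; the sketch ignores `box_mask_target` and is not the repository's `Local_APro`.

```python
import numpy as np


def local_propagate(image, prob, kernel_size=5, zeta_s=0.15, iters=1):
    """Refine prob by colour-affinity-weighted averaging in a local window.

    image: (H, W, C) array with values roughly in [0, 1]; prob: (H, W).
    Affinity (assumed): exp(-||I_i - I_j||^2 / zeta_s^2) for j in the window.
    """
    h, w = prob.shape
    r = kernel_size // 2
    base = image.astype(np.float64)
    img = np.pad(base, ((r, r), (r, r), (0, 0)), mode="edge")
    out = prob.astype(np.float64)
    for _ in range(iters):
        p = np.pad(out, r, mode="edge")
        num = np.zeros((h, w))
        den = np.zeros((h, w))
        # Visit every offset in the window; (r, r) is the pixel itself.
        for dy in range(kernel_size):
            for dx in range(kernel_size):
                diff = img[dy:dy + h, dx:dx + w] - base
                aff = np.exp(-np.sum(diff ** 2, axis=-1) / zeta_s ** 2)
                num += aff * p[dy:dy + h, dx:dx + w]
                den += aff  # >= 1 thanks to the centre term
        out = num / den
    return out
```

Because the affinity collapses across strong colour edges, labels diffuse within homogeneous regions but not across object boundaries.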
The local loss is computed in the same way, again taking box supervision as an example:
loss_local_term = torch.abs(mask_pred - soft_pseudo) * box_mask_target
loss_local_term = loss_local_term.sum((1, 2, 3)) / box_regions
loss_local = loss_local_term
Backbone | Epoch | Models | AP | AP50 | AP75 |
---|---|---|---|---|---|
ResNet-50 | 36 | model | 38.4 | 65.4 | 39.8 |
ResNet-101 | 36 | model | 40.5 | 67.9 | 42.6 |
Backbone | Epoch | Models | AP | AP50 | AP75 |
---|---|---|---|---|---|
ResNet-50 | 36 | model | 32.9 | 55.2 | 33.6 |
ResNet-101 | 36 | model | 34.3 | 57.0 | 35.3 |
Backbone | Epoch | Models | AP | AP50 | AP75 |
---|---|---|---|---|---|
ResNet-50 | 50 | model | 42.3 | 70.6 | 44.5 |
ResNet-101 | 50 | model | 43.6 | 72.0 | 45.7 |
Swin-L | 50 | model | 49.6 | 77.6 | 53.1 |
Backbone | Epoch | Models | AP | AP50 | AP75 |
---|---|---|---|---|---|
ResNet-50 | 50 | model | 36.1 | 62.0 | 36.7 |
ResNet-101 | 50 | model | 38.0 | 63.6 | 38.7 |
Swin-L | 50 | model | 41.0 | 68.3 | 41.9 |
This branch is built on BoxInstSeg and MMDetection.
@inproceedings{APro,
title={Label-efficient Segmentation via Affinity Propagation},
author={Li, Wentong and Yuan, Yuqian and Wang, Song and Liu, Wenyu and Tang, Dongqi and Liu, Jian and Zhu, Jianke and Zhang, Lei},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023}
}