Official implementation of ScribblePrompt: Fast and Flexible Interactive Segmentation for Any Biomedical Image, accepted at ECCV 2024
Hallee E. Wong, Marianne Rakic, John Guttag, Adrian V. Dalca
ScribblePrompt is an interactive segmentation tool that enables users to segment unseen structures in medical images using scribbles, clicks, and bounding boxes.
We provide checkpoints for two versions of ScribblePrompt:
ScribblePrompt-UNet with an efficient fully-convolutional architecture
ScribblePrompt-SAM based on the Segment Anything Model
Both models have been trained with iterative scribble, click, and bounding box interactions on a diverse collection of 65 medical imaging datasets with both real and synthetic labels.
We release MedScribble, a dataset of multi-annotator scribble annotations for diverse biomedical image segmentation tasks, under ./MedScribble. See its readme for more information and ./MedScribble/tutorial.ipynb for a preview of the data.
You can install scribbleprompt in two ways.
Option 1: install the package directly from GitHub with pip:
pip install git+https://github.com/halleewong/ScribblePrompt.git
Option 2: clone the repository, install the requirements, and add the repository to your Python path:
git clone https://github.com/halleewong/ScribblePrompt
python -m pip install -r ./ScribblePrompt/requirements.txt
export PYTHONPATH="$PYTHONPATH:$(realpath ./ScribblePrompt)"
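A quick way to sanity-check the installation is to import the two model classes from your environment:
python -c "from scribbleprompt import ScribblePromptUNet, ScribblePromptSAM; print('ok')"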
The following optional dependency is required to run the local demo:
pip install gradio==3.40.1
First, download the model checkpoints to ./checkpoints.
To run an interactive demo locally:
python demos/app.py
To instantiate ScribblePrompt-UNet and make a prediction:
from scribbleprompt import ScribblePromptUNet
sp_unet = ScribblePromptUNet()
mask = sp_unet.predict(
image, # (B, 1, H, W)
point_coords, # (B, n, 2)
point_labels, # (B, n)
scribbles, # (B, 2, H, W)
box, # (B, n, 4)
mask_input, # (B, 1, H, W)
) # -> (B, 1, H, W)
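As a minimal sketch of a single-click prediction (the dummy tensors are illustrative, and passing None for unused prompts is an assumption about the interface; PyTorch is used for the tensors):
import torch
from scribbleprompt import ScribblePromptUNet

sp_unet = ScribblePromptUNet()

# dummy grayscale image, min-max normalized to [0, 1], at 128 x 128
image = torch.rand(1, 1, 128, 128)

# one positive click at (x, y) = (64, 64); label 1 = foreground, 0 = background
point_coords = torch.tensor([[[64., 64.]]])   # (B, n, 2)
point_labels = torch.tensor([[1]])            # (B, n)

mask = sp_unet.predict(
    image,          # (B, 1, H, W)
    point_coords,   # (B, n, 2)
    point_labels,   # (B, n)
    None,           # scribbles (unused here)
    None,           # box (unused here)
    None,           # mask_input (no previous prediction)
)                   # -> (B, 1, H, W)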
To instantiate ScribblePrompt-SAM and make a prediction:
from scribbleprompt import ScribblePromptSAM
sp_sam = ScribblePromptSAM()
mask, img_features, low_res_logits = sp_sam.predict(
image, # (B, 1, H, W)
point_coords, # (B, n, 2)
point_labels, # (B, n)
scribbles, # (B, 2, H, W)
box, # (B, n, 4)
mask_input, # (B, 1, 256, 256)
) # -> (B, 1, H, W), (B, 16, 256, 256), (B, 1, 256, 256)
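Analogously, a sketch of a box-prompted prediction with ScribblePrompt-SAM (dummy tensors; the (x_min, y_min, x_max, y_max) box convention and the use of None for unused prompts are assumptions):
import torch
from scribbleprompt import ScribblePromptSAM

sp_sam = ScribblePromptSAM()

# dummy grayscale image in [0, 1] at 128 x 128
image = torch.rand(1, 1, 128, 128)

# one bounding box, assumed to be in (x_min, y_min, x_max, y_max) order
box = torch.tensor([[[32., 32., 96., 96.]]])   # (B, n, 4)

mask, img_features, low_res_logits = sp_sam.predict(
    image,   # (B, 1, H, W)
    None,    # point_coords (unused here)
    None,    # point_labels (unused here)
    None,    # scribbles (unused here)
    box,     # (B, n, 4)
    None,    # mask_input (no previous prediction)
)            # -> (B, 1, H, W), (B, 16, 256, 256), (B, 1, 256, 256)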
For best results, image should have spatial dimensions $(H,W) = (128,128)$ and pixel values min-max normalized to the $[0,1]$ range.
For ScribblePrompt-UNet, mask_input should be the logits from the previous prediction. For ScribblePrompt-SAM, mask_input should be low_res_logits from the previous prediction.
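Putting these notes together, a sketch of preprocessing an arbitrary 2D image and refining a prediction over two steps (the preprocess helper and click coordinates are illustrative; feeding low_res_logits back as mask_input follows the note above):
import torch
import torch.nn.functional as F
from scribbleprompt import ScribblePromptSAM

def preprocess(img_2d):
    # min-max normalize to [0, 1] and resize to 128 x 128
    img = (img_2d - img_2d.min()) / (img_2d.max() - img_2d.min() + 1e-8)
    img = img[None, None].float()   # (1, 1, H, W)
    return F.interpolate(img, size=(128, 128), mode="bilinear", align_corners=False)

sp_sam = ScribblePromptSAM()
image = preprocess(torch.rand(256, 256))   # stand-in for a real 2D scan slice

# step 1: initial prediction from one positive click
coords = torch.tensor([[[64., 64.]]])
labels = torch.tensor([[1]])
mask, img_features, low_res_logits = sp_sam.predict(image, coords, labels, None, None, None)

# step 2: add a negative click and feed the previous low-resolution logits back in
coords = torch.tensor([[[64., 64.], [100., 100.]]])
labels = torch.tensor([[1, 0]])
mask, img_features, low_res_logits = sp_sam.predict(image, coords, labels, None, None, low_res_logits)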
Code for ScribblePrompt-SAM builds on Segment Anything. Thanks to Meta AI for open-sourcing the model.
We use functions from voxsynth for applying random deformations during scribble simulation.
If you find our work or any of our materials useful, please cite our paper:
@inproceedings{wong2024scribbleprompt,
  title={ScribblePrompt: Fast and Flexible Interactive Segmentation for Any Biomedical Image},
  author={Hallee E. Wong and Marianne Rakic and John Guttag and Adrian V. Dalca},
  booktitle={European Conference on Computer Vision (ECCV)},
  year={2024},
}
Code for this project is released under the Apache 2.0 License.