halleewong / ScribblePrompt

[ECCV 2024] ScribblePrompt: Fast and Flexible Interactive Segmentation for Any Medical Image
http://scribbleprompt.csail.mit.edu/
Apache License 2.0

ScribblePrompt

Project Page | Paper | Online Demo | Video

Official implementation of ScribblePrompt: Fast and Flexible Interactive Segmentation for Any Biomedical Image, accepted at ECCV 2024.

Hallee E. Wong, Marianne Rakic, John Guttag, Adrian V. Dalca

Updates

Overview

ScribblePrompt is an interactive segmentation tool that enables users to segment unseen structures in medical images using scribbles, clicks, and bounding boxes.

Try ScribblePrompt

Models

We provide checkpoints for two versions of ScribblePrompt: ScribblePrompt-UNet and ScribblePrompt-SAM.

Both models were trained with iterative scribble, click, and bounding box interactions on a diverse collection of 65 medical imaging datasets with both real and synthetic labels.

MedScribble Dataset

We release MedScribble, a dataset of multi-annotator scribble annotations for diverse biomedical image segmentation tasks, under ./MedScribble. See the readme for more info and ./MedScribble/tutorial.ipynb for a preview of the data.

Installation

You can install scribbleprompt in two ways:

pip install git+https://github.com/halleewong/ScribblePrompt.git

The local demo requires one additional, optional dependency:

pip install gradio==3.40.1

Getting Started

First, download the model checkpoints to ./checkpoints.

To run an interactive demo locally:

python demos/app.py

To instantiate ScribblePrompt-UNet and make a prediction:

from scribbleprompt import ScribblePromptUNet

sp_unet = ScribblePromptUNet()

mask = sp_unet.predict(
    image,        # (B, 1, H, W) 
    point_coords, # (B, n, 2)
    point_labels, # (B, n)
    scribbles,    # (B, 2, H, W)
    box,          # (B, n, 4)
    mask_input,   # (B, 1, H, W)
) # -> (B, 1, H, W) 
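The shape comments above can be made concrete with a small sketch that builds each prompt as a NumPy array. Several details here are assumptions that this README does not state: that point coordinates are ordered (x, y), that label 1 means foreground and 0 means background, that scribble channel 0 holds positive strokes and channel 1 holds negative strokes, and that boxes are (x_min, y_min, x_max, y_max). Check the repository code before relying on any of them.

```python
import numpy as np

H, W = 128, 128  # spatial size of the image the prompts refer to

# Two clicks for a batch of one image.
# Assumption: coordinates are (x, y) in pixels; label 1 = foreground, 0 = background.
point_coords = np.array([[[64, 40], [10, 100]]], dtype=np.float32)  # (B=1, n=2, 2)
point_labels = np.array([[1, 0]], dtype=np.int64)                   # (B=1, n=2)

# Scribbles as two binary maps.
# Assumption: channel 0 = positive strokes, channel 1 = negative strokes.
scribbles = np.zeros((1, 2, H, W), dtype=np.float32)                # (B=1, 2, H, W)
scribbles[0, 0, 60, 30:90] = 1.0   # horizontal positive stroke on row 60
scribbles[0, 1, 110, 5:40] = 1.0   # horizontal negative stroke on row 110

# One bounding box.
# Assumption: the format is (x_min, y_min, x_max, y_max).
box = np.array([[[20, 30, 100, 95]]], dtype=np.float32)             # (B=1, n=1, 4)
```

If the model expects PyTorch tensors, each array can be converted with `torch.from_numpy`.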

To instantiate ScribblePrompt-SAM and make a prediction:

from scribbleprompt import ScribblePromptSAM

sp_sam = ScribblePromptSAM()

mask, img_features, low_res_logits = sp_sam.predict(
    image,        # (B, 1, H, W) 
    point_coords, # (B, n, 2)
    point_labels, # (B, n)
    scribbles,    # (B, 2, H, W)
    box,          # (B, n, 4)
    mask_input,   # (B, 1, 256, 256)
) # -> (B, 1, H, W), (B, 16, 256, 256), (B, 1, 256, 256)

For best results, image should have spatial dimensions $(H,W) = (128,128)$ and pixel values min-max normalized to the $[0,1]$ range.
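As a minimal sketch of that preprocessing, the helpers below resize a 2D image to 128 x 128 and min-max normalize it. The function names are hypothetical and not part of the scribbleprompt package, and nearest-neighbor resizing is an assumption made to keep the example dependency-free; bilinear interpolation is a common alternative.

```python
import numpy as np

def minmax_normalize(image, eps=1e-8):
    """Rescale pixel values to [0, 1]; a constant image maps to all zeros."""
    image = np.asarray(image, dtype=np.float32)
    lo, hi = image.min(), image.max()
    if hi - lo < eps:
        return np.zeros_like(image)
    return (image - lo) / (hi - lo)

def resize_nearest(image, size=(128, 128)):
    """Nearest-neighbor resize of a 2D array (hypothetical helper)."""
    image = np.asarray(image)
    h, w = image.shape
    rows = np.arange(size[0]) * h // size[0]  # source row for each output row
    cols = np.arange(size[1]) * w // size[1]  # source column for each output column
    return image[np.ix_(rows, cols)]

def prepare_image(image):
    """Return a float32 array of shape (1, 1, 128, 128) with values in [0, 1]."""
    out = minmax_normalize(resize_nearest(image))
    return out[None, None]  # add batch and channel axes

raw = np.arange(256 * 512).reshape(256, 512)  # stand-in for a real scan
x = prepare_image(raw)
```

Normalizing after resizing guarantees that the array passed to the model spans the full $[0,1]$ range.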

For ScribblePrompt-UNet, mask_input should be the logits from the previous prediction. For ScribblePrompt-SAM, mask_input should be low_res_logits from the previous prediction.
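The feedback of `mask_input` across interaction steps can be sketched as a loop. Everything below is illustrative: `refine`, `predict_fn`, and `dummy_predict` are hypothetical names, `predict_fn` stands in for a wrapper around the model that returns logits, and passing all clicks cumulatively at each step is an assumption rather than documented behavior.

```python
import numpy as np

def refine(predict_fn, image, click_steps):
    """Run one prediction per click, feeding each step's logits into the next.

    predict_fn is a hypothetical callable
    (image, point_coords, point_labels, mask_input) -> logits of shape (B, 1, H, W).
    click_steps is a list of (x, y, label) tuples, one per interaction step.
    """
    if not click_steps:
        raise ValueError("at least one click is required")
    coords, labels = [], []
    mask_input = None  # there is no previous prediction at the first step
    for x, y, label in click_steps:
        coords.append([x, y])
        labels.append(label)
        point_coords = np.array([coords], dtype=np.float32)  # (1, n, 2)
        point_labels = np.array([labels], dtype=np.int64)    # (1, n)
        mask_input = predict_fn(image, point_coords, point_labels, mask_input)
    return mask_input > 0  # logit 0 corresponds to probability 0.5

def dummy_predict(image, point_coords, point_labels, mask_input):
    """Stand-in for a model: marks the latest click on top of the previous logits."""
    if mask_input is None:
        logits = np.full(image.shape, -1.0, dtype=np.float32)
    else:
        logits = mask_input.copy()
    x, y = point_coords[0, -1].astype(int)
    logits[0, 0, y, x] = 5.0 if point_labels[0, -1] == 1 else -5.0
    return logits

image = np.zeros((1, 1, 128, 128), dtype=np.float32)
steps = [(10, 20, 1), (30, 40, 1), (10, 20, 0)]  # the third click corrects the first
mask = refine(dummy_predict, image, steps)
```

With ScribblePrompt-SAM, the value carried between steps would be `low_res_logits` rather than the full-resolution output.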

To Do

Acknowledgements

Citation

If you find our work or any of our materials useful, please cite our paper:

@article{wong2024scribbleprompt,
  title={ScribblePrompt: Fast and Flexible Interactive Segmentation for Any Biomedical Image},
  author={Hallee E. Wong and Marianne Rakic and John Guttag and Adrian V. Dalca},
  journal={European Conference on Computer Vision (ECCV)},
  year={2024},
}

License

Code for this project is released under the Apache 2.0 License.