chongzhou96 / EdgeSAM

Official PyTorch implementation of "EdgeSAM: Prompt-In-the-Loop Distillation for On-Device Deployment of SAM"
https://mmlab-ntu.com/project/edgesam/

EdgeSAM

Prompt-In-the-Loop Distillation for On-Device Deployment of SAM

Chong Zhou¹, Xiangtai Li¹, Chen Change Loy¹*, Bo Dai²

(*corresponding author)

¹S-Lab, Nanyang Technological University, ²Shanghai Artificial Intelligence Laboratory

[Paper] [Project Page] [Hugging Face Demo] [iOS App]

https://github.com/chongzhou96/EdgeSAM/assets/15973859/fe1cd104-88dc-4690-a5ea-ff48ae013db3

Watch the full live demo video: [YouTube] [Bilibili]

Overview

EdgeSAM is an accelerated variant of the Segment Anything Model (SAM), optimized for efficient execution on edge devices with minimal compromise in performance. It runs 40 times faster than the original SAM. It also outperforms MobileSAM: on edge devices it is 14 times as fast and improves the mIoU on COCO and LVIS by 2.3 and 3.2, respectively. EdgeSAM is also the first SAM variant that can run at over 30 FPS on an iPhone 14.

[Figure: speed and mIoU comparison]

In this figure, we show the encoder throughput of EdgeSAM compared with SAM and MobileSAM as well as the mIoU performance on the SA-1K dataset (sampled from SA-1B) with box and point prompts.

Approach

Our approach involves distilling the original ViT-based SAM image encoder into a purely CNN-based architecture, better suited for edge devices. We carefully benchmark various distillation strategies and demonstrate that task-agnostic encoder distillation fails to capture the full knowledge embodied in SAM. To overcome this bottleneck, we include both the prompt encoder and mask decoder in the distillation process, with box and point prompts in the loop, so that the distilled model can accurately capture the intricate dynamics between user input and mask generation.

[Figure: architecture]

Performance

| Method | Train Set | COCO AP | COCO APs | COCO APm | COCO APl | GFLOPs | MParam. | FPS iPhone 14 | FPS 2080 Ti | FPS 3090 |
|-------------|-----------|---------|----------|----------|----------|--------|---------|---------------|-------------|----------|
| SAM | SA-1B | 46.1 | 33.6 | 51.9 | 57.7 | 2734.8 | 641.1 | - | 4.3 | - |
| FastSAM | 2% SA-1B | 37.9 | 23.9 | 43.4 | 50.0 | 887.6 | 68.2 | - | - | 25.0* |
| MobileSAM | 1% SA-1B | 39.4 | 26.9 | 44.4 | 52.2 | 38.2 | 9.8 | 4.9 | 103.5 | 100.0* |
| EdgeSAM | 1% SA-1B | 42.2 | 29.6 | 47.6 | 53.9 | 22.1 | 9.6 | 38.7 | 164.3 | - |
| EdgeSAM-3x | 3% SA-1B | 42.7 | 30.0 | 48.6 | 54.5 | 22.1 | 9.6 | 38.7 | 164.3 | - |
| EdgeSAM-10x | 10% SA-1B | 43.0 | 30.3 | 48.9 | 55.1 | 22.1 | 9.6 | 38.7 | 164.3 | - |

In this table, we report the mask mAP on the COCO dataset. ViTDet-H, whose box mAP is 58.7, is used as the detector to provide box prompts. For speed benchmarking, we run both the encoder and the decoder (with a single prompt). FLOPs are calculated at the 1024x1024 input resolution. Numbers marked with * are copied from MobileSAM. 3x and 10x denote training with more data. Here, we do not apply an additional mask refinement iteration, per the setting of the original SAM paper.
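As a quick way to read the speed columns, the relative throughput can be computed directly from the FPS values above. The sketch below only divides numbers copied from the table (encoder plus decoder, single prompt), so the ratios need not match the headline figures in the Overview, which may have been measured differently. The helper name `speedup` is ours, for illustration only.

```python
# FPS values copied from the Performance table (encoder + decoder, single prompt).
FPS = {
    "SAM": {"2080 Ti": 4.3},
    "MobileSAM": {"iPhone 14": 4.9, "2080 Ti": 103.5},
    "EdgeSAM": {"iPhone 14": 38.7, "2080 Ti": 164.3},
}


def speedup(fast, slow, device):
    """Return how many times faster `fast` runs than `slow` on `device`."""
    return FPS[fast][device] / FPS[slow][device]


print(f"EdgeSAM vs MobileSAM on iPhone 14: {speedup('EdgeSAM', 'MobileSAM', 'iPhone 14'):.1f}x")  # 7.9x
print(f"EdgeSAM vs SAM on 2080 Ti: {speedup('EdgeSAM', 'SAM', '2080 Ti'):.1f}x")  # 38.2x
print(f"EdgeSAM vs MobileSAM on 2080 Ti: {speedup('EdgeSAM', 'MobileSAM', '2080 Ti'):.1f}x")  # 1.6x
```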

Installation

The code requires python>=3.8 and we use torch==2.0.0 and torchvision==0.15.1. Please refer to the official PyTorch installation instructions.

  1. Clone the repository locally:
git clone https://github.com/chongzhou96/EdgeSAM.git && cd EdgeSAM
  2. Install additional dependencies:
pip install -r requirements.txt
  3. Install EdgeSAM:
pip install -e .

Usage

  1. Download checkpoints (please refer to Checkpoints for more details about the PyTorch and CoreML checkpoints):
mkdir weights
wget -P weights/ https://huggingface.co/spaces/chongzhou/EdgeSAM/resolve/main/weights/edge_sam.pth
wget -P weights/ https://huggingface.co/spaces/chongzhou/EdgeSAM/resolve/main/weights/edge_sam_3x.pth
  2. You can easily incorporate EdgeSAM into your Python code with the following lines:
from edge_sam import SamPredictor, sam_model_registry
sam = sam_model_registry["edge_sam"](checkpoint="<path/to/checkpoint>")
predictor = SamPredictor(sam)
predictor.set_image(<your_image>)
masks, _, _ = predictor.predict(<input_prompts>)

Since EdgeSAM follows the same encoder-decoder architecture as SAM, their usage is very similar. One minor difference is that EdgeSAM can output 1, 3, or 4 mask candidates for each prompt, while SAM yields either 1 or 3 masks. For more details, please refer to the example Jupyter Notebook.
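To make the placeholders in the snippet above more concrete, here is a minimal sketch of a point-prompt call. It assumes EdgeSAM's predictor keeps SAM's conventions: an RGB image array of shape H x W x 3 (uint8), point coordinates in `(x, y)` pixel order, and the `point_coords` / `point_labels` keyword names. The helpers `build_point_prompt` and `segment_with_clicks` are hypothetical names introduced here, not part of the package; check the example Jupyter Notebook for the authoritative usage.

```python
def build_point_prompt(clicks):
    """Split (x, y, is_foreground) clicks into coordinate and label lists."""
    coords = [[float(x), float(y)] for x, y, _ in clicks]
    labels = [1 if is_foreground else 0 for _, _, is_foreground in clicks]
    return coords, labels


def segment_with_clicks(image, clicks, checkpoint="weights/edge_sam.pth"):
    """Run EdgeSAM on an RGB image array with point prompts and return the masks."""
    # Imported lazily so build_point_prompt works without EdgeSAM installed.
    import numpy as np
    from edge_sam import SamPredictor, sam_model_registry

    coords, labels = build_point_prompt(clicks)
    sam = sam_model_registry["edge_sam"](checkpoint=checkpoint)
    predictor = SamPredictor(sam)
    predictor.set_image(image)
    # Keyword names follow SAM's predictor; assumed unchanged in EdgeSAM.
    masks, _, _ = predictor.predict(
        point_coords=np.array(coords, dtype=np.float32),
        point_labels=np.array(labels, dtype=np.int64),
    )
    return masks
```

For example, `segment_with_clicks(image, [(320, 240, True), (50, 60, False)])` would pass one positive and one negative click.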

Train and Eval

Please refer to README_TRAIN.md for more details.

Web Demo

After installing EdgeSAM and downloading the checkpoints, you can start an interactive web demo with the following command:

python web_demo/gradio_app.py

By default, the demo is hosted on http://0.0.0.0:8080/ and expects edge_sam_3x.pth to be stored in the weights/ folder. You can change the default behavior by:

python web_demo/gradio_app.py --checkpoint [CHECKPOINT] --server-name [SERVER_NAME] --port [PORT]

Since EdgeSAM can run smoothly on a mobile phone, it's fine if you don't have a GPU.

We've deployed the same web demo in the Hugging Face Space [link]. However, since it uses the CPU as the backend and is shared by all users, the experience might not be as good as a local deployment. We really appreciate the Hugging Face team for supporting us with the GPU!

Speed up the web demo with ONNX backend

  1. Install onnxruntime with `pip install onnxruntime` if your machine doesn't have a GPU, or with `pip install onnxruntime-gpu` if it does (do not install both). Our implementation is tested with version 1.16.3.

  2. Download the ONNX models to the weights/ folder:

wget -P weights/ https://huggingface.co/spaces/chongzhou/EdgeSAM/resolve/main/weights/edge_sam_3x_encoder.onnx
wget -P weights/ https://huggingface.co/spaces/chongzhou/EdgeSAM/resolve/main/weights/edge_sam_3x_decoder.onnx
  3. Start the demo:
python web_demo/gradio_app.py --enable-onnx
  4. Navigate to http://0.0.0.0:8080 in your browser.
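If you want to confirm which backend your onnxruntime installation will actually use, a small check like the one below may help. It is a sketch that is independent of how `web_demo/gradio_app.py` loads its models: `choose_providers` and `load_session` are hypothetical helper names, while `get_available_providers` and `InferenceSession(..., providers=...)` are standard ONNX Runtime calls.

```python
def choose_providers(available):
    """Order execution providers: CUDA first if present, CPU as the fallback."""
    preferred = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    return [name for name in preferred if name in available]


def load_session(model_path):
    """Open an ONNX model with the best provider available on this machine."""
    # Imported lazily so choose_providers can be used without onnxruntime.
    import onnxruntime as ort

    providers = choose_providers(ort.get_available_providers())
    return ort.InferenceSession(model_path, providers=providers)
```

With the models downloaded in step 2, `load_session("weights/edge_sam_3x_encoder.onnx").get_providers()` reports the providers in use.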

CoreML / ONNX Export

CoreML

We provide a script that can export a trained EdgeSAM PyTorch model to two CoreML model packages, one for the encoder and another for the decoder. You can also download the exported CoreML models at Checkpoints.

For encoder:

python scripts/export_coreml_model.py [CHECKPOINT]

For decoder:

python scripts/export_coreml_model.py [CHECKPOINT] --decoder --use-stability-score

Since EdgeSAM doesn't perform knowledge distillation on the IoU token of the original SAM, its IoU predictions might not be reliable. Therefore, we use the stability score for mask selection instead. You can stick to the IoU predictions by removing --use-stability-score.
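For readers unfamiliar with the stability score, the sketch below follows the definition used in the original SAM code: the ratio between the mask area at a raised threshold and the area at a lowered threshold, so a mask whose boundary barely moves when the threshold shifts scores close to 1. It works on plain nested lists of mask logits. The default `threshold` and `offset` values are illustrative assumptions, and the export script may compute the score differently.

```python
def stability_score(logits, threshold=0.0, offset=1.0):
    """Area above (threshold + offset) divided by area above (threshold - offset)."""
    flat = [value for row in logits for value in row]
    high = sum(value > threshold + offset for value in flat)  # strict mask
    low = sum(value > threshold - offset for value in flat)   # loose mask
    return high / low if low else 0.0


def select_most_stable(candidates, threshold=0.0, offset=1.0):
    """Return the index of the candidate with the highest stability score, plus all scores."""
    scores = [stability_score(c, threshold, offset) for c in candidates]
    best = max(range(len(scores)), key=scores.__getitem__)
    return best, scores


# A confident mask (logits far from 0) and an uncertain one (logits near 0).
sharp = [[5.0, 5.0], [-5.0, -5.0]]
fuzzy = [[0.5, 2.0], [-0.5, -5.0]]
best, scores = select_most_stable([sharp, fuzzy])
print(best, scores)  # the sharp mask (index 0) wins
```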

The following shows the performance reports of the EdgeSAM CoreML models measured by Xcode on an iPhone 14 (left: encoder, right: decoder):

![xcode](https://github.com/chongzhou96/EdgeSAM/assets/15973859/8df54f76-24c9-4ad2-af6d-086b971d073b)

Known issues and model descriptions

As of `coremltools==7.1`, you may encounter an assertion error during the export, e.g., `assert len(inputs) <= 3 or inputs[3] is None`. One workaround is to comment out this assertion at the location given in the traceback, e.g., `/opt/anaconda3/envs/EdgeSAM/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/ops.py line 1573`.

Since CoreML doesn't support interpolation with dynamic target sizes, the converted CoreML models contain neither the pre-processing (resize, normalize, pad) nor the post-processing (resizing back to the original size).

The encoder takes a `1x3x1024x1024` image as input and outputs a `1x256x64x64` image embedding. The decoder then takes the image embedding together with point coordinates and point labels as input. The point coordinates follow the `(height, width)` format, with the top-left corner as `(0, 0)`. The choices of point labels are `0: negative point`, `1: positive point`, `2: top-left corner of box`, and `3: bottom-right corner of box`.
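Because the exported models leave out the pre-processing, the caller has to map images and prompts into the 1024x1024 input frame. The sketch below shows the geometry only, assuming the same resize-longest-side and pad-bottom/right convention as the original SAM; normalization and the actual pixel resampling are omitted. All function names are hypothetical. Points are scaled uniformly, so the helper works regardless of coordinate order.

```python
def resized_shape(height, width, long_side=1024):
    """Scale (height, width) so that the longer side equals `long_side`."""
    scale = long_side / max(height, width)
    return int(height * scale + 0.5), int(width * scale + 0.5), scale


def padding(height, width, long_side=1024):
    """Rows and columns of padding needed to reach long_side x long_side."""
    new_h, new_w, _ = resized_shape(height, width, long_side)
    return long_side - new_h, long_side - new_w  # (bottom, right), assumed as in SAM


def scale_point(point, scale):
    """Map a point from original-image pixels to the resized frame."""
    return tuple(coordinate * scale for coordinate in point)


def box_to_points(top_left, bottom_right):
    """Encode a box prompt as two points with labels 2 and 3."""
    return [top_left, bottom_right], [2, 3]


new_h, new_w, scale = resized_shape(480, 640)
print(new_h, new_w)       # 768 1024
print(padding(480, 640))  # (256, 0)
```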

ONNX

Similar to the CoreML export, you can use the following commands to export the encoder and the decoder to ONNX models respectively:

For encoder:

python scripts/export_onnx_model.py [CHECKPOINT]

For decoder:

python scripts/export_onnx_model.py [CHECKPOINT] --decoder --use-stability-score

Checkpoints

Please download the checkpoints of EdgeSAM from its Hugging Face Space (all the EdgeSAM variants only differ in the number of training images):

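| Model | COCO mAP | PyTorch | CoreML | ONNX |
|-------------|----------|----------|---------------------|---------------------|
| SAM | 46.1 | - | - | - |
| EdgeSAM | 42.1 | Download | [Encoder] [Decoder] | [Encoder] [Decoder] |
| EdgeSAM-3x | 42.7 | Download | [Encoder] [Decoder] | [Encoder] [Decoder] |
| EdgeSAM-10x | 43.0 | TBA | TBA | TBA |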

Note: You need to unzip the CoreML model packages before usage.
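If you prefer to unzip from Python rather than the command line, a small helper along these lines should work. It assumes the downloaded CoreML packages are ordinary `.zip` archives stored in one folder; `unzip_packages` is a hypothetical name, and no particular archive file names are assumed.

```python
import zipfile
from pathlib import Path


def unzip_packages(weights_dir):
    """Extract every .zip archive in `weights_dir` into that same folder."""
    extracted = []
    for archive in sorted(Path(weights_dir).glob("*.zip")):
        with zipfile.ZipFile(archive) as zf:
            zf.extractall(archive.parent)
        extracted.append(archive.name)
    return extracted
```

Calling `unzip_packages("weights")` returns the names of the archives it extracted.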

iOS App

We are planning to release the iOS app that we used in the live demo to the App Store. Please stay tuned!

Acknowledgements

This study is supported under the RIE2020 Industry Alignment Fund Industry Collaboration Projects (IAF-ICP) Funding Initiative, as well as cash and in-kind contribution from the industry partner(s). We are grateful to Han Soong Chong for his effort in the demonstration application.

We appreciate the following projects, which enable EdgeSAM: SAM, MobileSAM, FastSAM, TinyViT, and RepViT.

Citation

@article{zhou2023edgesam,
  title={EdgeSAM: Prompt-In-the-Loop Distillation for On-Device Deployment of SAM},
  author={Zhou, Chong and Li, Xiangtai and Loy, Chen Change and Dai, Bo},
  journal={arXiv preprint arXiv:2312.06660},
  year={2023}
}

License

This project is licensed under NTU S-Lab License 1.0. Redistribution and use should follow this license.