facebookresearch / segment-anything

The repository provides code for running inference with the SegmentAnything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0
47.9k stars 5.67k forks

How to use the ONNX model to run inference on an image and get all masks? #649

Open xpp726 opened 11 months ago

xpp726 commented 11 months ago

I want to get the masks of all objects in an image using the ONNX model, but the ONNX inference example does not do what I expected. If anyone knows how to use the ONNX model to get all the objects in an image, please show me how. Thank you!

913832344 commented 11 months ago

I have the same question!

zoldaten commented 9 months ago

  1. In segment-anything\scripts\export_onnx_model.py, change the --return-single-mask argument as follows:

    parser.add_argument(
        "--return-single-mask",
        default=False,
        #action="store_true",

  2. Convert the .pth model to ONNX (the vit_b model, for example):

    python scripts/export_onnx_model.py --checkpoint sam_vit_b_01ec64.pth --model-type vit_b --output vit_b.onnx

  3. Use it to extract multiple masks:

    import time

    import cv2
    import onnxruntime
    from onnxruntime.quantization import QuantType
    from onnxruntime.quantization.quantize import quantize_dynamic
    from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
    from segment_anything.utils.onnx import SamOnnxModel

    # Load the PyTorch SAM model from the checkpoint
    checkpoint = "sam_vit_b_01ec64.pth"
    model_type = "vit_b"
    sam = sam_model_registry[model_type](checkpoint=checkpoint)

    # Create an ONNX Runtime session for the exported model
    onnx_model_path = 'vit_b.onnx'
    ort_session = onnxruntime.InferenceSession(onnx_model_path)
    sam.to(device='cuda')

    start_time = time.time()

    # OpenCV loads images as BGR; SAM expects RGB
    img = cv2.imread("near.jpg")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(img)
    print(len(masks))

    print("--- %s seconds ---" % (time.time() - start_time))



See my repo with ONNX models: https://github.com/zoldaten/segment_anything_onnx_models

Altaflux commented 9 months ago

@zoldaten I do not seem to understand your example. You create an onnx inference session but never use it, and instead you use the normal sam model to get the masks.

Wasn't the request to use the ONNX model itself to get all the object masks?
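For reference, actually running the exported ONNX decoder for a single point prompt might look roughly like the sketch below. It assumes the input names used in the repository's ONNX example notebook (image_embeddings, point_coords, point_labels, mask_input, has_mask_input, orig_im_size), an image embedding of shape (1, 256, 64, 64) computed beforehand, and a point given in original-image pixels. build_decoder_inputs is a hypothetical helper, not part of segment_anything.

```python
import numpy as np


def build_decoder_inputs(image_embedding, point_xy, orig_hw, target_length=1024):
    """Build the input feed for the exported SAM ONNX decoder (one foreground point)."""
    h, w = orig_hw
    # Rescale the point into the resized frame (longest image side -> target_length).
    scale = target_length / max(h, w)
    new_h, new_w = int(h * scale + 0.5), int(w * scale + 0.5)
    x, y = point_xy
    # One real point (label 1) plus a padding point (label -1), since no box is given.
    coords = np.array([[[x * new_w / w, y * new_h / h], [0.0, 0.0]]], dtype=np.float32)
    labels = np.array([[1, -1]], dtype=np.float32)
    return {
        "image_embeddings": image_embedding.astype(np.float32),
        "point_coords": coords,
        "point_labels": labels,
        "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
        "has_mask_input": np.zeros(1, dtype=np.float32),
        "orig_im_size": np.array([h, w], dtype=np.float32),
    }


# Usage (not run here; needs onnxruntime and an exported decoder):
#   feed = build_decoder_inputs(embedding, (320, 240), img.shape[:2])
#   masks, scores, low_res_logits = ort_session.run(None, feed)
#   masks = masks > 0.0  # threshold the logits into boolean masks
```

This covers only one prompt. Getting all objects would mean repeating the call for many points and then filtering the results.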

zoldaten commented 9 months ago

@Altaflux you can compare the result with the original SAM inference:

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import time,cv2

start_time = time.time()

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
mask_generator = SamAutomaticMaskGenerator(sam)
img = cv2.imread("snapshot_5.jpg")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

masks = mask_generator.generate(img)

print("--- %s seconds ---" % (time.time() - start_time))
#--- 89.85638666152954 seconds ---

Then check len(masks).

heyoeyo commented 9 months ago

The SAM repo doesn't include support for exporting the image encoder to ONNX. However, there is a discussion of this in issue #16, and one of the users there has a repo that can apparently do this (along with a blog post explaining it in detail): https://github.com/AndreyGermanov/sam_onnx_full_export

I haven't seen anyone update the auto-mask generator to use the onnx models however, so that would still need to be done manually.
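As a rough idea of what doing it manually involves: the automatic mask generator prompts the model with a regular grid of points and then filters and de-duplicates the resulting masks. A minimal sketch of the first step, generating the grid of point prompts in pixel coordinates, is below. build_point_grid here is an illustrative stand-in written for this comment, and the filtering and de-duplication steps are left out.

```python
from typing import List, Tuple


def build_point_grid(points_per_side: int, width: int, height: int) -> List[Tuple[float, float]]:
    """Return points_per_side**2 (x, y) pixel coordinates spread evenly over the image."""
    # Cell centers of a regular grid: (i + 0.5) / n in normalized [0, 1] coordinates.
    centers = [(i + 0.5) / points_per_side for i in range(points_per_side)]
    # Row by row: y varies slowest, x varies fastest.
    return [(cx * width, cy * height) for cy in centers for cx in centers]


grid = build_point_grid(2, width=100, height=50)
print(grid)  # [(25.0, 12.5), (75.0, 12.5), (25.0, 37.5), (75.0, 37.5)]
```

Each grid point would then be sent to the ONNX decoder as its own prompt, reusing the same image embedding.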

Altaflux commented 9 months ago

Hey @heyoeyo, yep, I have been using the exported ONNX image encoder (the one you linked) and the ONNX decoder to try to create an auto-mask generator that supports the ONNX model. The problem I see is that the inputs the normal PyTorch model supports differ from the inputs the ONNX model supports.

At first glance it seems that the onnx model doesn't support batches, which I am having trouble understanding as it is not really documented anywhere.
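One way to check this without documentation is to ask ONNX Runtime which inputs the exported model declares. The sketch below uses InferenceSession.get_inputs(); describe_inputs is a hypothetical helper name. A fixed integer in a dimension means that size is baked into the export, while a string or None marks a dynamic dimension.

```python
def describe_inputs(session):
    """Return (name, shape, type) for every input declared by an ONNX Runtime session."""
    return [(arg.name, arg.shape, arg.type) for arg in session.get_inputs()]


# Usage (not run here; needs onnxruntime and an exported model file):
#   import onnxruntime
#   session = onnxruntime.InferenceSession("vit_b.onnx")
#   for name, shape, dtype in describe_inputs(session):
#       print(name, shape, dtype)
```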

heyoeyo commented 9 months ago

I haven't used ONNX with batches myself, but at least from that link it does look like the ONNX model expects a batch dimension. For example, the ONNX inference notebook prints out the input_tensor.shape showing (1, 3, 1024, 1024). It's hard to link to it directly, but it's the 'Out[14]:' block in the notebook. Although maybe it only supports single entries?

If it does only support single batch entries, I think you could loop over the batch items, since batches are meant to be processed by the model independently anyways.
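A minimal sketch of that loop, assuming each batch item has already been turned into its own input feed (a dict of input name to array with a batch dimension of 1). run_per_item is a hypothetical helper; session.run(None, feed) is the standard ONNX Runtime call and returns a list of all model outputs.

```python
from typing import Any, Dict, List, Sequence


def run_per_item(session, feeds: Sequence[Dict[str, Any]]) -> List[list]:
    """Run the session once per input feed and collect the outputs in order."""
    results = []
    for feed in feeds:
        # Passing None as the first argument requests every output of the model.
        results.append(session.run(None, feed))
    return results


# Usage (not run here; needs onnxruntime and an exported model):
#   outputs = run_per_item(ort_session, [feed_for_item_0, feed_for_item_1])
#   first_item_outputs = outputs[0]
```

This gives up whatever speedup true batching would provide, but the results should match processing the items together.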