facebookresearch / segment-anything

The repository provides code for running inference with the SegmentAnything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0
47.41k stars, 5.61k forks

Is it possible to run the ONNX decoder with multiple boxes? #308

Open VladMVLX opened 1 year ago

VladMVLX commented 1 year ago

I am experimenting with the ONNX decoder, trying to run it with multiple boxes like this:

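var point = index * 2;                        // two corner points per box
point_coords[0, point, 0] = rect.Left;        // top-left x
point_coords[0, point, 1] = rect.Top;         // top-left y
point_coords[0, point + 1, 0] = rect.Right;   // bottom-right x
point_coords[0, point + 1, 1] = rect.Bottom;  // bottom-right y
point_labels[0, point] = 2;                   // label 2: box top-left corner
point_labels[0, point + 1] = 3;               // label 3: box bottom-right corner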

where index is the index of a box and rect is that box in the "post-transform to longest side" coordinate space; point_coords and point_labels are the tensors passed to the inference session. When I run inference with a single box, everything works fine and the generated masks are good. Adding even one more box produces masks that are mostly black with some white artifacts.
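For reference, here is a minimal Python/NumPy sketch of the same single-prompt layout. It is a sketch under assumptions, not the asker's actual code: the helper names get_preprocess_shape and boxes_to_single_prompt are hypothetical, 1024 is assumed as SAM's default long side, and the rounding is meant to mirror SAM's ResizeLongestSide transform. It packs N boxes into one prompt of shape [1, 2N, 2] with labels [1, 2N], which is the layout that yields the degraded masks described above.

```python
import numpy as np


def get_preprocess_shape(old_h, old_w, long_side=1024):
    # Scale so the longer side becomes `long_side`, rounding to the nearest pixel
    # (assumed to match SAM's ResizeLongestSide rounding).
    scale = long_side * 1.0 / max(old_h, old_w)
    new_h, new_w = old_h * scale, old_w * scale
    return int(new_h + 0.5), int(new_w + 0.5)


def boxes_to_single_prompt(boxes_xyxy, orig_h, orig_w, long_side=1024):
    """Pack N boxes into ONE prompt: coords [1, 2N, 2], labels [1, 2N]."""
    # np.array copies, so the caller's boxes are never modified in place.
    boxes = np.array(boxes_xyxy, dtype=np.float32).reshape(-1, 2, 2)  # [N, corner, (x, y)]
    new_h, new_w = get_preprocess_shape(orig_h, orig_w, long_side)
    boxes[..., 0] *= new_w / orig_w  # x into the resized image
    boxes[..., 1] *= new_h / orig_h  # y into the resized image
    n = boxes.shape[0]
    # Order: box0 top-left, box0 bottom-right, box1 top-left, ...
    coords = boxes.reshape(1, 2 * n, 2)
    # Labels 2 (top-left) and 3 (bottom-right), repeated once per box.
    labels = np.tile(np.array([2, 3], dtype=np.float32), n).reshape(1, 2 * n)
    return coords, labels
```

For a 480x640 image, the box [10, 20, 110, 220] becomes the corner points (16, 32) and (176, 352), because the scale factor is 1024 / 640 = 1.6.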

Does the ONNX decoder support multi-box input at the moment? If it does, how should I label the boxes?

shubhamtyagii commented 10 months ago

Yes, you can give multiple input bboxes to the ONNX mask decoder. There are two ways to do it:

  1. Give all the bboxes in a single batch element by making your point tensor of shape [1, 2N, 2], where N is the number of bboxes (each bbox contributes two corner points). The label tensor then has shape [1, 2N]. For a single bbox it holds [2., 3.], where 2 is the label for the top-left corner and 3 is the label for the bottom-right corner. For N bboxes, that pair repeats N times. Inference runs with this approach, but the results are unexpected, and I don't think it is the right way. All points in one batch element are treated as a single prompt, so the decoder tries to predict one mask for all the boxes together.
  2. Send the bboxes as a batch. For this, point_coords and point_labels need a dynamic batch dimension, so export the ONNX model with this dynamic_axes dictionary: dynamic_axes = { "point_coords": {0: "batch_size", 1: "num_points"}, "point_labels": {0: "batch_size", 1: "num_points"} }. Then pass each bbox and its labels as a separate element of the batch, so point_coords has shape [N, 2, 2] and point_labels has shape [N, 2]. You get one set of masks per bbox. This works like a charm.
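The batched layout from option 2 could be built as an onnxruntime feed dictionary, roughly as follows. This is a hedged sketch. The input names (image_embeddings, point_coords, point_labels, mask_input, has_mask_input, orig_im_size) are assumed to match SAM's ONNX example export. The helper boxes_to_batched_inputs is hypothetical. The boxes are assumed to already be in the resized "longest side" space. It also assumes that the exported graph broadcasts the single [1, 1, 256, 256] mask_input across the N boxes, which I have not verified for every export.

```python
import numpy as np


def boxes_to_batched_inputs(boxes_resized_xyxy, image_embedding, orig_h, orig_w):
    """One box per batch element: coords [N, 2, 2], labels [N, 2]."""
    coords = np.array(boxes_resized_xyxy, dtype=np.float32).reshape(-1, 2, 2)
    n = coords.shape[0]
    # Each batch element gets its own (top-left=2, bottom-right=3) label pair.
    labels = np.tile(np.array([[2, 3]], dtype=np.float32), (n, 1))
    return {
        "image_embeddings": image_embedding,  # output of the image encoder
        "point_coords": coords,
        "point_labels": labels,
        # No prior mask: zero mask with has_mask_input = 0 (assumed to broadcast).
        "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
        "has_mask_input": np.zeros(1, dtype=np.float32),
        "orig_im_size": np.array([orig_h, orig_w], dtype=np.float32),
    }


# Usage, assuming an onnxruntime.InferenceSession named `session` that wraps a
# model exported with the batched dynamic_axes above:
#   masks, scores, low_res_logits = session.run(None, inputs)
#   binary_masks = masks > 0.0  # mask logits thresholded at 0, one entry per box
```

Because every box is its own batch element, the decoder never mixes the corner points of different boxes into one prompt. That is the difference from option 1.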