facebookresearch / segment-anything

The repository provides code for running inference with the Segment Anything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0

Multiple boxes and points as input #620

Open shrutichakraborty opened 11 months ago

shrutichakraborty commented 11 months ago

Hi all!

I'd like to use multiple boxes and multiple points as input to predict the masks. However, I'm getting a shape error when I try that. Here is the code I have been trying:

```python
import torch

# Assumes `predictor` is a SamPredictor that has already had set_image(image) called,
# `input_box` holds the box coordinates (XYXY, original-image pixels), and
# `input_point` / `input_label` hold the clicked points and their 0/1 labels (or None)
input_box = torch.tensor(input_box)
input_box = predictor.transform.apply_boxes_torch(input_box, image.shape[:2])
if input_point is not None:
    input_point = torch.as_tensor(input_point, dtype=torch.float)
    input_label = torch.as_tensor(input_label, dtype=torch.int)

    input_point = predictor.transform.apply_coords_torch(input_point, image.shape[:2])
    # input_label = predictor.transform.apply_coords_torch(input_label, image.shape[:2])
    print("labels_torch:", input_label.shape)
    # Add a leading batch dimension: points become 1 x N x 2, labels 1 x N
    input_point, input_label = input_point[None, :, :], input_label[None, :]
    print("coords_torch:", input_point.shape)
    print("labels_torch:", input_label.shape)

masks, _, _ = predictor.predict_torch(
    point_coords=input_point,
    point_labels=input_label,
    boxes=input_box,
    multimask_output=False,
)
```

The error message I get is:

[Screenshot of the error message]

I am using the `predict_torch` method, so I had a look at the `predictor.py` file, which requires `point_labels` to be a BxN torch tensor and `point_coords` to be a BxNx2 tensor. Here I'm not sure what B is, but I assume N is the number of points clicked. As I am calling the `predict_torch` method directly without first using the `predictor.predict` method, I also made sure to convert `input_label` and `input_point` to tensors, and the shapes I get are `torch.Size([1, 5])` and `torch.Size([1, 5, 2])` respectively.

Can someone help me out? Thanks!
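For reference, a minimal sketch of the point-prompt shapes described above, assuming B counts independent prompts (each gets its own mask output) and N counts the points within one prompt. The click coordinates and labels are made up for illustration:

```python
import torch

# Five made-up clicks (x, y) in original-image pixels, with their labels
clicks = [[100, 200], [150, 220], [300, 80], [310, 95], [50, 400]]
labels = [1, 1, 0, 1, 0]  # 1 = foreground, 0 = background

coords = torch.as_tensor(clicks, dtype=torch.float)  # N x 2
label_t = torch.as_tensor(labels, dtype=torch.int)   # N

# One prompt (B = 1) that contains all N = 5 points
coords_b = coords[None, :, :]  # B x N x 2 -> 1 x 5 x 2
labels_b = label_t[None, :]    # B x N     -> 1 x 5
print(coords_b.shape, labels_b.shape)  # torch.Size([1, 5, 2]) torch.Size([1, 5])

# Two independent prompts (B = 2) would instead stack two N-point sets
two_prompts = torch.stack([coords, coords + 10])  # 2 x 5 x 2
```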

heyoeyo commented 11 months ago

As-is, I don't think the code supports having more than 1 box for a single input prompt. Technically the model can handle it, but it requires modifying the original code base.

The main sticking point is how the box coordinates are reshaped in the prompt encoder, which assumes you're providing box inputs of shape Bx4. If you instead provide 'N' boxes as a single input (i.e. shape Nx4), this gets misinterpreted as Bx4, i.e. N separate prompts of one box each. That eventually causes a mismatch in the 'B' dimension when the box points get appended to the other points you provide, which don't have a matching 'B' dimension. This is the error message you're getting about not being able to concatenate the embeddings: you have '1' batch of points/labels, but '2' boxes, which are interpreted as 2 batches of 1 box.
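A shape-only sketch of that failure, with zero tensors standing in for the real embeddings. `append_box_embeddings` is a hypothetical helper that imitates the concatenation described above; it is not the prompt encoder's actual code:

```python
import torch

embed_dim = 256  # size of SAM's prompt embedding vectors

def append_box_embeddings(point_emb, box_emb):
    """Mimic the concatenation step: box corners are appended after the points along dim=1.
    Returns None (instead of raising) when the batch dimensions disagree."""
    try:
        return torch.cat([point_emb, box_emb], dim=1)
    except RuntimeError as err:
        print("concatenation failed:", err)
        return None

point_emb = torch.zeros(1, 5, embed_dim)  # 1 prompt (B = 1) with 5 points

# Two boxes given as a 2 x 4 tensor are read as B = 2 prompts with 2 corners each
mismatch = append_box_embeddings(point_emb, torch.zeros(2, 2, embed_dim))
# A single box (1 x 4) gives B = 1, so it lines up with the points
single = append_box_embeddings(point_emb, torch.zeros(1, 2, embed_dim))
print(single.shape)  # torch.Size([1, 7, 256])
```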

It's easy enough to fix the problem by modifying the return statement of the box embedding function (`_embed_boxes` in the prompt encoder):

```python
return corner_embedding.reshape(1, -1, self.embed_dim)
```

This just reshapes the embedded box points to a shape of 1x2Nx256, where the first 1 is the B dimension, 2N is the number of box corner points (two per box, for N boxes) and 256 is the size of the embedding vectors. However, this assumes that a single batch of boxes is provided with a shape of Nx4 (where N is the number of boxes). Technically this 'breaks' the original assumptions about how the model functions, so other code (like the example notebooks) will probably fail if you make this modification! It may also give nonsense results, since the model may not have been trained to make sense of multi-box prompts.
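The effect of that one-line change on the shapes can be checked with random tensors standing in for the real embeddings. This is a shape-only sketch: it does not load the model, and it says nothing about whether the model produces sensible masks for such prompts:

```python
import torch

embed_dim = 256
num_boxes = 2

# Box corner embeddings as the unmodified function returns them: N x 2 x C
corner_embedding = torch.randn(num_boxes, 2, embed_dim)

# Modified return value: all corners folded into one prompt (B = 1)
merged = corner_embedding.reshape(1, -1, embed_dim)  # 1 x 2N x C

point_embeddings = torch.randn(1, 5, embed_dim)        # 1 prompt with 5 points
sparse = torch.cat([point_embeddings, merged], dim=1)  # 1 x (5 + 2N) x C
print(merged.shape, sparse.shape)  # torch.Size([1, 4, 256]) torch.Size([1, 9, 256])
```

The reshape keeps each box's corners adjacent (box 0's top-left and bottom-right, then box 1's), so the corner ordering is preserved; only the batch grouping changes.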

shrutichakraborty commented 11 months ago

Thanks for getting back to me! It seems that using multiple boxes with points/labels set to None works, and using 1 box with multiple points/labels works too. BUT, when I try to use multiple boxes AND multiple points, it fails, which is bizarre because I am just sort of merging the capabilities of the two previous cases. It seems like the error comes from the shape of input_point and input_label... any ideas?

heyoeyo commented 11 months ago

> multiple boxes but I set points/labels as None that work

This works because if you don't provide any points, the `sparse_embeddings` will have a batch dimension based on the boxes only; there won't be any points to concatenate, and therefore no batch mismatch. However, you should find that the masks on the model output don't have the 'shape' you want. Instead of getting 1 mask for the combined boxes as a single input prompt, you'll get 1 mask for every box that you provide. Compare this to providing multiple points as a single prompt input (like the [1,5,2] example you mentioned): there you should get only one mask as output, not 5 (i.e. one for each point).

> use 1 box and multiple points/labels works too

This works because the code assumes you're providing boxes of the shape Bx4. If you only use a batch of 1 for the points, and also use only 1 box (i.e. input of shape: 1x4), then the batch dimensions will match (i.e. B = 1) and the concatenation will work without error.
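The three cases discussed so far can be summarized with a small hypothetical helper. `masks_returned` is not part of segment-anything; it only encodes the batch-size rules described in this thread, assuming `multimask_output=False` (one mask per prompt):

```python
from typing import Optional

def masks_returned(point_batch: Optional[int], num_boxes: Optional[int]) -> int:
    """Hypothetical helper encoding the batch rules described in this thread.

    point_batch: the B dimension of the point tensors (None if no points are given)
    num_boxes:   rows in the box tensor (None if no boxes); each row is read as its own prompt
    Returns the number of prompts, i.e. masks returned when multimask_output=False.
    """
    if point_batch is not None and num_boxes is not None and point_batch != num_boxes:
        raise ValueError(f"batch mismatch: {point_batch} point prompt(s) vs {num_boxes} box prompt(s)")
    if point_batch is not None:
        return point_batch
    if num_boxes is not None:
        return num_boxes
    raise ValueError("no point or box prompts given")

print(masks_returned(None, 2))  # boxes only: one mask per box -> 2
print(masks_returned(1, 1))     # 1 box + a 1 x N point prompt -> 1
print(masks_returned(1, None))  # 1 x 5 points, no box -> 1
try:
    masks_returned(1, 2)        # 2 boxes + a 1 x N point prompt
except ValueError as err:
    print(err)
```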

heyoeyo commented 11 months ago

I should add, if your goal is to process the boxes independently but share the points between them, then that is possible. You just have to duplicate the points/labels to match the number of boxes before passing them as inputs to the predictor, something like:

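```python
num_boxes = 2  # number of boxes in input_box, i.e. input_box.shape[0]
input_point = input_point.repeat((num_boxes, 1, 1))  # 1 x N x 2 -> num_boxes x N x 2
input_label = input_label.repeat((num_boxes, 1))     # 1 x N     -> num_boxes x N
```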

This way should work without the error, but it's not a single prompt anymore so you're gonna get multiple masks (one for each box).
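The same idea can be wrapped in a small function that also checks its inputs. `share_points_across_boxes` is a hypothetical helper, sketched under the assumption that the points/labels arrive as a single 1 x N prompt and the boxes as a K x 4 tensor; the example values are made up:

```python
import torch

def share_points_across_boxes(points: torch.Tensor, labels: torch.Tensor, boxes: torch.Tensor):
    """Hypothetical helper: pair one shared set of points with every box.

    points: 1 x N x 2, labels: 1 x N, boxes: K x 4 (one box per independent prompt).
    Returns points K x N x 2 and labels K x N, so all three inputs share B = K.
    """
    if points.shape[0] != 1 or labels.shape[0] != 1:
        raise ValueError("expected a single (B = 1) set of points/labels")
    num_boxes = boxes.shape[0]
    return points.repeat((num_boxes, 1, 1)), labels.repeat((num_boxes, 1))

points = torch.rand(1, 5, 2) * 1024
labels = torch.tensor([[1, 1, 0, 1, 0]], dtype=torch.int)
boxes = torch.tensor([[10.0, 20.0, 300.0, 400.0], [350.0, 50.0, 700.0, 600.0]])
p, l = share_points_across_boxes(points, labels, boxes)
print(p.shape, l.shape)  # torch.Size([2, 5, 2]) torch.Size([2, 5])
```

With inputs shaped this way, mask i in the output should correspond to box i combined with the shared points.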

shrutichakraborty commented 11 months ago

Thanks a lot! I will try that. Is there a way to specify negative regions (background) using boxes?

heyoeyo commented 11 months ago

> Is there a way to specify negative regions (background) using boxes?

No, not directly. The model has 5 different ways of encoding points: one special 'not a point' type, plus 4 others (background points, foreground points, top-left box corner points and bottom-right box corner points). In theory, the 'background point' embedding could be added to the top-left & bottom-right embeddings to make something like a 'background top-left/bottom-right corner' point embedding, but this would require modifying the original code, and may not make any sense since the model probably wasn't trained to support this.
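A rough sketch of that hypothetical modification, using randomly initialized stand-ins for the learned embeddings rather than SAM's trained weights. The pattern of adding a learned type embedding to each corner's positional encoding mirrors how SAM's prompt encoder embeds box corners; `embed_box_corners` and its `as_background` flag are made up for illustration, and since the model was never trained on 'background corners', real outputs from such a change are unlikely to be meaningful:

```python
import torch
import torch.nn as nn

embed_dim = 256
BACKGROUND, FOREGROUND, TOP_LEFT, BOTTOM_RIGHT = 0, 1, 2, 3

# Randomly initialized stand-ins for the 4 learned point-type embeddings
# (SAM's trained weights are NOT loaded here)
point_embeddings = nn.ModuleList([nn.Embedding(1, embed_dim) for _ in range(4)])
# The fifth, special "not a point" type (listed for completeness, unused below)
not_a_point_embed = nn.Embedding(1, embed_dim)

@torch.no_grad()
def embed_box_corners(pos_enc: torch.Tensor, as_background: bool = False) -> torch.Tensor:
    """pos_enc: K x 2 x C positional encodings of each box's (top-left, bottom-right) corners.

    as_background=True applies the hypothetical, untrained 'background corner' idea:
    the background-point embedding is added on top of each corner's type embedding.
    """
    out = pos_enc.clone()
    out[:, 0, :] += point_embeddings[TOP_LEFT].weight[0]
    out[:, 1, :] += point_embeddings[BOTTOM_RIGHT].weight[0]
    if as_background:
        out += point_embeddings[BACKGROUND].weight[0]
    return out

pos = torch.randn(3, 2, embed_dim)  # 3 boxes, 2 corners each
fg_corners = embed_box_corners(pos)
bg_corners = embed_box_corners(pos, as_background=True)
```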

Another way to specify negative regions would be to use a mask prompt (with the negative regions being excluded by the provided mask); however, the mask prompt input seems very finicky. If you do want to try that approach, there are a few other issues describing attempts to use it properly: #360, #169, #242