facebookresearch / sam2

The repository provides code for running inference with the Meta Segment Anything Model 2 (SAM 2), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.

Can I use multiple boxes and multiple points together for image segmentation? #235

Open squashking opened 2 months ago

squashking commented 2 months ago

Hi, as far as I know, with SAM1, we cannot use both multiple boxes and multiple points at the same time to segment one image. Can we do this using SAM2? If so, can anyone please share the code? Appreciate it!

YangJae96 commented 2 months ago

Hi. I am also wondering if multiple boxes in the first frame can be given as a prompt. Have you figured it out?

AlexMcClay commented 2 months ago

So far, from what I've tested, you can use one or the other, but not both together.

I tried doing multiple boxes and multiple points, and it didn't work, though I could probably do more testing. (See the sketch below for the combination that does work out of the box.)
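
For what it's worth, the stock SAM2ImagePredictor does accept a single box together with multiple points in one call; it's multiple boxes per prompt that the default code can't handle. A minimal sketch of that supported combination, assuming predictor is a SAM2ImagePredictor with an image already set via set_image (the coordinates here are made up):

import numpy as np

# One box plus several foreground points for the same object
masks, scores, logits = predictor.predict(
    point_coords=np.array([[120, 150], [200, 210]]),  # Kx2 (x, y) pixel coords
    point_labels=np.array([1, 1]),                    # 1 = foreground, 0 = background
    box=np.array([100, 100, 320, 280]),               # a single XYXY box
    multimask_output=False,
)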

heyoeyo commented 2 months ago

As @AlexMcClay mentioned, the default setup of the SAM code (both v2 and v1) only allows for 1 box per prompt. That's because of the reshaping step inside the box embedding function, which converts inputs into a Bx2x2 shape, leaving no 'N' dimension for the number of boxes like there is with point prompts.

You can modify the box embedding function to accept BxNx4 shaped box inputs with something like:

def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:

    # If we get a 3D box input, assume shape: BxNx4
    # -> Otherwise assume original shape: Bx4, so add N dimension
    if boxes.ndim < 3:
        boxes = boxes.unsqueeze(1)
    num_boxes = boxes.shape[1]

    # Run original box embedding for each of the input boxes
    embeddings_list = []
    for box_idx in range(num_boxes):
        box = boxes[:, box_idx, :]
        box = box + 0.5  # Shift to center of pixel
        coords = box.reshape(-1, 2, 2)
        corner_embedding = self.pe_layer.forward_with_coords(
            coords, self.input_image_size
        )
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        embeddings_list.append(corner_embedding)

    # Combine all boxes on N dimension for output
    return torch.cat(embeddings_list, dim=1)

Then it should be possible to pass multiple boxes per prompt, using something like:

box1 = [10, 20, 30, 40]
box2 = [100, 200, 300, 400]
multibox_prompt = torch.tensor([box1, box2], dtype=torch.float32).unsqueeze(0)
# ^^^ Has shape BxNx4 (1x2x4 in this case)
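
With that change, the multi-box tensor can be passed straight to the prompt encoder. A minimal sketch, assuming prompt_encoder is an already-built SAMv1 PromptEncoder instance (its forward takes points, boxes and masks):

sparse_emb, dense_emb = prompt_encoder(
    points=None,
    boxes=multibox_prompt,  # BxNx4, handled by the modified _embed_boxes
    masks=None,
)
# sparse_emb now holds 2 corner tokens per box, i.e. shape Bx(2N)xC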

However, both SAMv2 and especially v1 seem to do poorly with multi-box prompts, probably because they aren't part of the original model/training. Using a single box always seems to work better; for example, here's 1 box vs. 2 boxes: [multibox_example image]. (To be fair, this is a simple shape to box. Maybe there are more complicated shapes that benefit from multi-box prompts?)

XUYU0205 commented 1 month ago

@heyoeyo Thank you very much for your answer. When I checked the source code, I found that for multiple box prompts, the (Nx4) input is reshaped to (Nx2x2), where N is the number of boxes. The boxes are converted into concat_points in the _predict() function and then passed into self.model.sam_prompt_encoder(). In the forward() of the PromptEncoder class, since the boxes have already been converted to points, they never go through the _embed_boxes() function. When I input multiple box prompts I do get a result, but it only segments one of the boxes, seemingly at random.
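
For reference, the box handling in question (roughly the code around line 393 of sam2_image_predictor.py; paraphrased from the upstream source, not copied exactly) looks like:

if boxes is not None:
    box_coords = boxes.reshape(-1, 2, 2)  # Nx4 -> Nx2x2: N becomes the batch dim, 1 box per prompt
    box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
    box_labels = box_labels.repeat(boxes.size(0), 1)  # labels 2/3 mark the two box corners
    # the box "points" are then merged with any point prompts into concat_points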

heyoeyo commented 1 month ago

@XUYU0205 You're right, it looks like the v2 code no longer uses the _embed_boxes function!

A similar change to support N boxes can still be made by adjusting the code that's interpreting the boxes as points. In some ways it's a bit simpler than changing the box embedding function, since it's just a matter of reshaping the boxes & labels to have shapes: Bx(2N)x2 and Bx(2N) respectively:

# Modified version of the box-as-points handling to support 'N' boxes
# -> Originally line 393 of sam2_image_predictor.py
if boxes is not None:

    # If the original box shape (Bx4) is given, force it to Bx1x4
    # -> Then we can interpret all inputs as shape: BxNx4
    if boxes.ndim == 2:
        boxes = boxes.unsqueeze(1)
    batch_size, num_boxes = boxes.shape[0:2]

    # Reshape coords to Bx(2N)x2 and labels to Bx(2N)
    # (the original code does the same, but always assumed N=1)
    box_coords = boxes.reshape(batch_size, 2 * num_boxes, 2)
    box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
    box_labels = box_labels.repeat(batch_size, num_boxes)

    # ... rest of the code is the same as the original ...
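
As a quick standalone sanity check of those shapes (no SAM code required; the example boxes are made up):

import torch

B, N = 1, 2
boxes = torch.tensor([[[10.0, 20.0, 30.0, 40.0],
                       [100.0, 200.0, 300.0, 400.0]]])  # BxNx4
box_coords = boxes.reshape(B, 2 * N, 2)                 # corner points, Bx(2N)x2
box_labels = torch.tensor([[2, 3]], dtype=torch.int).repeat(B, N)  # Bx(2N)
print(box_coords.shape)  # torch.Size([1, 4, 2])
print(box_labels)        # tensor([[2, 3, 2, 3]], dtype=torch.int32)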