squashking opened this issue 2 months ago
Hi. I am also wondering if multiple boxes in the first frame can be given as a prompt. Have you figured it out?
So far, from what I've tested, you can either use multiple points or a single box, but not multiple boxes.
I tried doing multiple boxes and multiple points, and it didn't work. Though I could probably do more testing.
As @AlexMcClay mentioned, the default setup of the SAM code (both v2 and v1) only allows for 1 box per prompt. That's because of the reshaping step inside the box embedding function, which converts inputs into a shape of Bx2x2, which doesn't allow for providing 'N' for the number of boxes like there is with point prompts.
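For reference, the unmodified function (paraphrased from the prompt encoder in the segment_anything repo, so treat it as approximate) looks roughly like the sketch below; the `reshape(-1, 2, 2)` is what hard-codes the one-box-per-prompt assumption:

```python
# Approximate paraphrase of the original SAM (v1) PromptEncoder._embed_boxes
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
    boxes = boxes + 0.5  # Shift to center of pixel
    coords = boxes.reshape(-1, 2, 2)  # Bx4 -> Bx2x2, i.e. exactly one box per prompt
    corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
    corner_embedding[:, 0, :] += self.point_embeddings[2].weight  # top-left corner
    corner_embedding[:, 1, :] += self.point_embeddings[3].weight  # bottom-right corner
    return corner_embedding
```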
You can modify the box embedding function to accept BxNx4 shaped box inputs with something like:
```python
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:

    # If we get a 3D box input, assume shape: BxNx4
    # -> Otherwise assume original shape: Bx4, so add N dimension
    if boxes.ndim < 3:
        boxes = boxes.unsqueeze(1)
    num_boxes = boxes.shape[1]

    # Run original box embedding for each of the input boxes
    embeddings_list = []
    for box_idx in range(num_boxes):
        box = boxes[:, box_idx, :]
        box = box + 0.5  # Shift to center of pixel
        coords = box.reshape(-1, 2, 2)
        corner_embedding = self.pe_layer.forward_with_coords(
            coords, self.input_image_size
        )
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        embeddings_list.append(corner_embedding)

    # Combine all boxes on N dimension for output
    return torch.cat(embeddings_list, dim=1)
```
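(Note that this just runs the original single-box embedding once per box and concatenates the results along the N dimension, so no model weights are touched; only the prompt encoder's input handling changes.)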
Then it should be possible to pass multiple boxes per prompt, using something like:
```python
box1 = [10, 20, 30, 40]
box2 = [100, 200, 300, 400]
multibox_prompt = torch.tensor([box1, box2]).unsqueeze(0)
# ^^^ Has shape: BxNx4, 1x2x4 in this case
```
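If it helps, here's a rough sanity check of the shapes (a hypothetical snippet, assuming a SAM v1 model loaded as `sam` with the patched `_embed_boxes` above; the coordinates aren't transformed to the model's input resolution, so this only checks shapes):

```python
# Hypothetical shape check: with 2 boxes, the sparse embedding should have
# 2 corner tokens per box, i.e. B x (2N) x embed_dim
sparse_emb, dense_emb = sam.prompt_encoder(
    points=None,
    boxes=multibox_prompt.float().to(sam.device),
    masks=None,
)
print(sparse_emb.shape)  # expected: torch.Size([1, 4, 256])
```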
However, both SAMv2 and especially v1 seem to do poorly with multi-box prompts, probably because they aren't part of the original model/training. It always seems better to use a single box. For example, here's 1 box vs. 2 boxes: (To be fair, this is a simple shape to box. Maybe there are more complicated shapes that benefit from multi-box prompts?)
@heyoeyo Thank you very much for your answer. When I checked the source code, I found that for multiple box prompts, the boxes (Nx4) are converted to (Nx2x2), where N is the number of boxes. The boxes are converted to concat_points in the _predict() function and then passed into self.model.sam_prompt_encoder(). In forward() of the PromptEncoder class, since the boxes have already been converted to points, they never pass through the _embed_boxes() function. When I input multiple box prompts, I still get a result, but it only segments one of the boxes, seemingly at random.
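For context, the box handling inside _predict (paraphrased roughly from sam2_image_predictor.py, not exact code) looks something like this, which is why _embed_boxes is bypassed:

```python
# Rough paraphrase of the box handling inside SAM2ImagePredictor._predict
if boxes is not None:
    box_coords = boxes.reshape(-1, 2, 2)  # Nx4 -> Nx2x2
    box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
    box_labels = box_labels.repeat(boxes.size(0), 1)
    # Boxes are merged into the "points" input of the prompt encoder,
    # so PromptEncoder._embed_boxes is never called
    concat_points = (box_coords, box_labels)
```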
@XUYU0205 You're right, it looks like the v2 code no longer uses the _embed_boxes function! A similar change to support N boxes can still be made by adjusting the code that interprets the boxes as points. In some ways it's a bit simpler than changing the box embedding function, since it's just a matter of reshaping the boxes & labels to have shapes Bx(2N)x2 and Bx(2N) respectively:
```python
# Modified version of box embedding to support 'N' boxes
# -> Originally line 393 of sam2_image_predictor.py
if boxes is not None:

    # If original box shape (Bx4) is given, force it to Bx1x4
    # -> Then we can interpret all inputs as shape: BxNx4
    if boxes.ndim == 2:
        boxes = boxes.unsqueeze(1)
    batch_size, num_boxes = boxes.shape[0:2]

    # Reshape coords to Bx(2N)x2 and labels to Bx(2N)
    # (original code does the same, but always assumed N=1)
    box_coords = boxes.reshape(batch_size, 2 * num_boxes, 2)
    box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
    box_labels = box_labels.repeat(batch_size, num_boxes)

    # ... rest of code is the same as original ...
```
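For anyone who wants to sanity-check the reshaping without touching the SAM2 code, here's a minimal standalone snippet (just torch, with made-up box values) showing the coordinate and label shapes that result for B=1, N=2:

```python
import torch

# BxNx4 box prompt: 1 batch, 2 boxes
boxes = torch.tensor([[[10., 20., 30., 40.],
                       [100., 200., 300., 400.]]])
batch_size, num_boxes = boxes.shape[0:2]

# Same reshaping as the modified code above
box_coords = boxes.reshape(batch_size, 2 * num_boxes, 2)
box_labels = torch.tensor([[2, 3]], dtype=torch.int).repeat(batch_size, num_boxes)

print(box_coords.shape)  # torch.Size([1, 4, 2])  -> Bx(2N)x2
print(box_labels)        # tensor([[2, 3, 2, 3]]) -> Bx(2N)
```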
Hi, as far as I know, with SAM1, we cannot use both multiple boxes and multiple points at the same time to segment one image. Can we do this using SAM2? If so, can anyone please share the code? Appreciate it!