facebookresearch / segment-anything

The repository provides code for running inference with the Segment Anything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0

What is the shape of the encoder's boxes? #739

Open raoxinyu4977 opened 5 months ago

raoxinyu4977 commented 5 months ago

I attempted to set the shape of the encoder's input boxes to (4, 10, 4), representing (bs, num_boxes, 4), where the last dimension holds the two (x, y) box corners. However, during the operation:

def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
    """Embeds box prompts."""
    boxes = boxes + 0.5  # Shift to center of pixel
    coords = boxes.reshape(-1, 2, 2)
    corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
    corner_embedding[:, 0, :] += self.point_embeddings[2].weight
    corner_embedding[:, 1, :] += self.point_embeddings[3].weight
    return corner_embedding

the boxes tensor is reshaped to (bs*num_boxes, 2, 2), and the output corner_embedding has shape (bs*num_boxes, 2, 256). However, in the forward function, the sparse tensor sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) has a shape of (10, 0, 256). When concatenating with sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1), the dimensions don't align.
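For reference, a minimal sketch that reproduces this mismatch by calling the prompt encoder directly (the constructor values below are the defaults from build_sam.py; the rest is illustrative):

import torch
from segment_anything.modeling.prompt_encoder import PromptEncoder

# Constructor values match the defaults used in build_sam.py
encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)

boxes = torch.rand(4, 10, 4) * 1024  # (bs, num_boxes, 4) -- not the expected (bs, 4)
sparse, dense = encoder(points=None, boxes=boxes, masks=None)
# Fails: _embed_boxes flattens to (bs*num_boxes, 2, 256) = (40, 2, 256),
# which can't be concatenated with the (bs, 0, 256) sparse_embeddings tensor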

heyoeyo commented 5 months ago

The expected shape for box inputs is Bx4. If you're using points & labels, the shapes should be BxNx2 and BxN respectively, where B is the batch size and N is the number of points. So as-is, the model doesn't support multiple box prompts for a single mask in the way that it supports multiple points (i.e. there is no 'N' component in the box shape).
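As a concrete reference for these shapes, here is a short sketch calling the prompt encoder directly (same illustrative encoder setup as above):

import torch
from segment_anything.modeling.prompt_encoder import PromptEncoder

encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)

boxes = torch.rand(2, 4) * 1024      # Bx4: exactly one box per prompt
coords = torch.rand(2, 5, 2) * 1024  # BxNx2: N points per prompt
labels = torch.ones(2, 5)            # BxN: 1 = foreground, 0 = background

sparse, dense = encoder(points=(coords, labels), boxes=boxes, masks=None)
print(sparse.shape)  # (2, 7, 256): 5 point embeddings + 2 box-corner embeddings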

In theory you could modify the code to support having 'N' boxes, by generating the corner_embedding output for each of the 'N' boxes (10 in the example you gave) and concatenating them together into a single corner_embedding, as sketched below. I think the result should have a shape of Bx(2N)x(embed_dim) to work with the existing code. Though it's unclear how well this would work, since it's not part of the original model's behavior/training (but it may be worth trying).
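An untested sketch of what that modification might look like in the prompt_encoder script, accepting boxes of shape (B, N, 4) (the method name _embed_boxes_multi is hypothetical):

# Hypothetical N-box variant of PromptEncoder._embed_boxes (untested)
def _embed_boxes_multi(self, boxes: torch.Tensor) -> torch.Tensor:
    """Embeds (B, N, 4) box prompts as (B, 2N, embed_dim) corner tokens."""
    b, n, _ = boxes.shape
    boxes = boxes + 0.5  # Shift to center of pixel
    coords = boxes.reshape(b * n, 2, 2)
    corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
    corner_embedding[:, 0, :] += self.point_embeddings[2].weight
    corner_embedding[:, 1, :] += self.point_embeddings[3].weight
    # Fold the N boxes back into the batch: (B*N, 2, C) -> (B, 2N, C)
    return corner_embedding.reshape(b, 2 * n, -1)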

raoxinyu4977 commented 5 months ago

Thanks, understood.

wu2233 commented 1 week ago

Hello, I've encountered the same issue as you. Could you please share how you resolved it? I would greatly appreciate your assistance. Thank you.

heyoeyo commented 1 week ago

There's some code on the SAMv2 issue board that provides support for having multiple boxes for a single prompt. That code references changes to the newer code base, but the equivalent code for SAMv1 can be found in the prompt_encoder script. Though it seems both SAMv2 and SAMv1 perform poorly when using more than 1 box.

wu2233 commented 1 week ago

Thank you very much! I will try it.

wu2233 commented 1 week ago

Hello, I've encountered a new issue. For example, in one training batch I have 2 images; the first image has 2 bounding boxes, and the second image has 3 bounding boxes. In this case, how should I conduct the training? Can the program understand the correspondence between images and bounding boxes within a batch?

heyoeyo commented 1 week ago

In general, if you have differently shaped data, it needs to be processed in separate batches. In this case, if you had multiple images with 2 bounding boxes you could batch all of them together, and likewise for images with 3 bounding boxes, for example:
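One hedged way to do that grouping (dataset here stands in for any iterable of (image, boxes) pairs):

import torch
from collections import defaultdict

# Hypothetical bucketing by box count so each batch has a uniform shape
buckets = defaultdict(list)
for image, boxes in dataset:  # boxes: (num_boxes, 4), varying per image
    buckets[boxes.shape[0]].append((image, boxes))

for num_boxes, samples in buckets.items():
    images = torch.stack([img for img, _ in samples])     # (B, C, H, W)
    batched_boxes = torch.stack([b for _, b in samples])  # (B, num_boxes, 4)
    # ...run the training step on this uniformly shaped batch...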

Alternatively, the SAM model includes a not_a_point_embed embedding that can be used to pad the prompts, so you could use this to make the 2-box prompt tensors the same shape as the 3-box prompts.

Each box prompt adds two 'points' to the prompt tensor, so I think to pad a 2-box prompt to match the shape of a 3 box prompt, you'd need to do something like:

# Pad 2-box prompt encoding to match 3-box encoding shape.
# not_a_point_embed is an nn.Embedding(1, embed_dim), so take its
# weight and broadcast to (B, 2, 256) before concatenating token-wise
pad_embed = predictor.model.prompt_encoder.not_a_point_embed.weight
pad_embed = pad_embed.reshape(1, 1, -1).expand(sparse_embeddings.shape[0], 2, -1)
sparse_embeddings = torch.cat([sparse_embeddings, pad_embed], dim=1)

This would require modifying the sparse_embeddings that are generated by the prompt encoder (which normally happens inside the predict function).
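Concretely, that could look something like the following sketch, which bypasses predict and calls the modules directly. It assumes predictor.set_image(...) has already been run and that the prompt encoder has been modified to accept multi-box inputs as discussed earlier in the thread; boxes_2 is a hypothetical (B, 2, 4) tensor of 2-box prompts:

# Encode a 2-box prompt, pad it, and run the mask decoder manually
sparse_embeddings, dense_embeddings = predictor.model.prompt_encoder(
    points=None, boxes=boxes_2, masks=None  # boxes_2: (B, 2, 4) -> (B, 4, 256)
)

pad = predictor.model.prompt_encoder.not_a_point_embed.weight
pad = pad.reshape(1, 1, -1).expand(sparse_embeddings.shape[0], 2, -1)
sparse_embeddings = torch.cat([sparse_embeddings, pad], dim=1)  # (B, 6, 256)

low_res_masks, iou_predictions = predictor.model.mask_decoder(
    image_embeddings=predictor.features,  # stored by predictor.set_image(...)
    image_pe=predictor.model.prompt_encoder.get_dense_pe(),
    sparse_prompt_embeddings=sparse_embeddings,
    dense_prompt_embeddings=dense_embeddings,
    multimask_output=False,
)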