xinghaochen / TinySAM

Official PyTorch implementation of "TinySAM: Pushing the Envelope for Efficient Segment Anything Model"
Apache License 2.0

Benchmarking Batch Inference #7

Closed merveenoyan closed 11 months ago

merveenoyan commented 11 months ago

Hello 🙌🏼 I'm running some benchmarks on TinySAM and trying to benchmark batch inference, but I've hit a wall. All inputs are torch tensors with the shapes expected by the docstrings in the inference code:

point_prompt.shape   # torch.Size([4, 1, 2]), BxNx2
input_label.shape    # torch.Size([4, 1]), BxN
batched_image.shape  # torch.Size([4, 3, 1024, 1024]), BxCxHxW
predictor.set_torch_image(batched_image, original_image_size=batched_image[0, 0, :, :].shape)  # goes well

# this fails
with torch.no_grad():
    _, _, _ = predictor.predict_torch(
        point_coords=point_prompt,
        point_labels=input_label,
    )

I don't have a lot of time to debug this, as I've already tried a couple of things. I feel like I'm missing a step; can you let me know if so? I can post a full trace if you want, but I really think I'm missing a step and that's why it errors out.

Gaffey commented 11 months ago

Hi merveenoyan, thanks for trying to benchmark TinySAM. The original interface of SAM does not support batch inference over multiple images, and we've followed that design. From the docstrings of set_torch_image and predict_torch, you can see that:

def set_torch_image(
    self,
    transformed_image: torch.Tensor,
    original_image_size: Tuple[int, ...],
) -> None:
    """
    Arguments:
      transformed_image (torch.Tensor): The input image, with shape
        1x3xHxW, which has been transformed with ResizeLongestSide.
    """

in which the image shape is fixed to 1x3xHxW, i.e. a single image per call. And in predict_torch:

def predict_torch(
    self,
    point_coords: Optional[torch.Tensor],
    point_labels: Optional[torch.Tensor],
    boxes: Optional[torch.Tensor] = None,
    mask_input: Optional[torch.Tensor] = None,
    return_logits: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Predict masks for the given input prompts, using the currently set image.
    Input prompts are batched torch tensors and are expected to already be
    transformed to the input frame using ResizeLongestSide.

    Arguments:
      point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
        model. Each point is in (X,Y) in pixels.
      point_labels (torch.Tensor or None): A BxN array of labels for the
        point prompts. 1 indicates a foreground point and 0 indicates a
        background point.
    """

There is a batch dimension for point_coords and point_labels, but it refers to multiple prompts on a single image. So the batch inference of SAM/TinySAM only supports multiple prompts for one image, not multiple images.
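For example, here is a minimal sketch of the supported usage, reusing the tensors from the snippet above (the (1024, 1024) size comes from that snippet, and the three return names follow SAM's convention):

import torch

# Supported usage: one image per set_torch_image call (1x3xHxW);
# the prompt batch dimension B means "B prompts on this one image".
single_image = batched_image[0:1]  # shape (1, 3, 1024, 1024)
predictor.set_torch_image(single_image, original_image_size=(1024, 1024))

with torch.no_grad():
    # Four prompts (B=4), each with a single point (N=1), on one image.
    masks, iou_predictions, low_res_masks = predictor.predict_torch(
        point_coords=point_prompt,  # (4, 1, 2)
        point_labels=input_label,   # (4, 1)
    )
# masks has a leading dimension of 4: one output per prompt.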

As for the benchmark, I think it is possible to evaluate SAM and TinySAM under the same batch settings since they share the same interface, so there is no need to implement batch inference for multiple images.
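Concretely, a like-for-like benchmark could keep the prompts batched per image and loop over the images one at a time, along these lines (a sketch only: the timing code is illustrative, and the synchronize calls assume a CUDA device):

import time
import torch

torch.cuda.synchronize()  # drop the synchronize calls when running on CPU
start = time.perf_counter()
for i in range(batched_image.shape[0]):
    # One image per forward pass, as the interface expects.
    predictor.set_torch_image(
        batched_image[i : i + 1],  # (1, 3, 1024, 1024)
        original_image_size=(1024, 1024),
    )
    with torch.no_grad():
        masks, iou_predictions, low_res_masks = predictor.predict_torch(
            point_coords=point_prompt[i : i + 1],  # (1, 1, 2)
            point_labels=input_label[i : i + 1],   # (1, 1)
        )
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
print(f"{elapsed:.3f}s for {batched_image.shape[0]} images")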

merveenoyan commented 11 months ago

@Gaffey thanks a lot for the swift response! I'll keep this in mind 🙌🏼