IDEA-Research / Grounded-Segment-Anything

Grounded SAM: Marrying Grounding DINO with Segment Anything & Stable Diffusion & Recognize Anything - Automatically Detect, Segment and Generate Anything
https://arxiv.org/abs/2401.14159
Apache License 2.0

Some images are segmented perfectly, whilst some are completely wrong. #313

Closed stefanjaspers closed 1 year ago

stefanjaspers commented 1 year ago

This issue is on SAM's side, but since the image goes through GroundingDINO first, I figured it would be best to post it here.

I'm building a book recognition app that can take pictures of books or select a picture from a photo gallery. For the image segmentation part, I copied most of the code in this repository, except that I'm providing base64-decoded input, since the image arrives in base64-encoded format from a Flutter app.
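The base64 handling itself is roughly the following (a minimal sketch; the function name is illustrative, not the app's actual one):

    import base64
    import io

    def decode_base64_image(encoded_image: str) -> io.BytesIO:
        # Decode the base64 string sent by the Flutter app into an in-memory
        # buffer; get_sam_output below reads the raw bytes back out via
        # image.getvalue().
        return io.BytesIO(base64.b64decode(encoded_image))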

import cv2
import numpy as np
import torch

# Segment Anything.
from segment_anything import build_sam, SamPredictor

class SegmentAnythingService:
    def __init__(self) -> None:
        pass

    def get_sam_output(self, sam_checkpoint, image, image_pil, boxes_filt):
        predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device="cpu"))

        image_bytes = image.getvalue()

        nparr = np.frombuffer(image_bytes, np.uint8)

        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

        predictor.set_image(image)

        size = image_pil.size
        H, W = size[1], size[0]
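        # Scale the normalized (cx, cy, w, h) boxes from GroundingDINO to pixels
        # and convert them to (x1, y1, x2, y2) corner format.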
        for i in range(boxes_filt.size(0)):
            boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
            boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
            boxes_filt[i][2:] += boxes_filt[i][:2]

        boxes_filt = boxes_filt.cpu()
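        # Map the pixel-space boxes into the resized coordinate frame that SAM expects.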
        transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device="cpu")

        masks, _, _ = predictor.predict_torch(
            point_coords = None,
            point_labels = None,
            boxes = transformed_boxes.to(device="cpu"),
            multimask_output = False,
        )

        return masks
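For reference, masks returned this way can be written out as binary PNGs for inspection; a minimal sketch, assuming the (N, 1, H, W) boolean tensor that predict_torch returns:

    import cv2
    import numpy as np

    def save_masks_for_debugging(masks) -> None:
        # masks: boolean tensor of shape (N, 1, H, W) as returned by get_sam_output.
        for i, mask in enumerate(masks):
            mask_img = (mask[0].cpu().numpy() * 255).astype(np.uint8)
            cv2.imwrite(f"image_{i}.png", mask_img)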

For debugging purposes, I save all the generated masks to check how they look. For example, let's take this image:

IMG_4275

This goes through SAM perfectly and gives back the following masks:

image_0 image_1 image_2

However, if I take the following picture:

IMG_4274

This gets completely messed up and I get corrupted masks that represent nothing:

image_0 image_1 image_2

Both pictures were taken using an iPhone 13 Pro Max, but that's probably irrelevant. I've been dealing with this issue for a week now. I would like every image to be segmented properly, but it just won't happen. It's random for every image: some segment perfectly, some don't, and I don't see why.

Does anyone have tips and/or solutions?

rentainhe commented 1 year ago

I've uploaded your demo image to the segment-anything online demo, and it seems that with the proper box prompt it can segment the books perfectly in your worse case:

image

Could you provide a visualization of your box prompts for us?

stefanjaspers commented 1 year ago

Thank you for your reply. I tried the demo too and it recognized all books perfectly as well.

By box prompt visualization, do you mean using the boxes_filt that gets returned from get_grounding_output and drawing the bounding boxes on the original image with that data? I'm kind of new to this. Thanks in advance!

rentainhe commented 1 year ago

Yes, I think some of the box prompts are not right; you should visualize them to see what the problem is.

stefanjaspers commented 1 year ago

I wrote the code after get_grounding_output with some help from GPT-4, and the object detection works fine.

The code:

        # Run Grounding DINO model.
        boxes_filt, pred_phrases = grounding_dino_service.get_grounding_output(
            model,
            image_tensor,
            config["text_prompt"],
            config["box_threshold"],
            config["text_threshold"],
            config["device"],
        )

        # Convert tensor to numpy if it's not already
        if isinstance(boxes_filt, torch.Tensor):
            boxes_filt = boxes_filt.cpu().numpy()

        fig, ax = plt.subplots(1)
        ax.imshow(image_pil)

        # Image dimensions
        width, height = image_pil.size

        for i, box in enumerate(boxes_filt):
            # Convert from center_x, center_y, w, h to x1, y1, x2, y2 and scale
            center_x, center_y, w, h = box
            x1 = (center_x - w / 2) * width
            y1 = (center_y - h / 2) * height
            x2 = (center_x + w / 2) * width
            y2 = (center_y + h / 2) * height

            rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=1, edgecolor='r', facecolor='none')
            ax.add_patch(rect)

            # Add the predicted phrase as a label
            ax.text(x1, y1, pred_phrases[i], color='r')

        plt.show()
image

It did rotate the image 90 degrees though, but I doubt that could be part of the problem? Either way, it seems GroundingDINO isn't the issue.

rentainhe commented 1 year ago

I have no idea why this happened, but I think the box prompts are correct~ you can send them to SAM for generating masks

stefanjaspers commented 1 year ago

I think I solved it. As I mentioned before, I found it weird that the image got rotated 90 degrees, so I updated the load_image function so that the image always keeps its original orientation. Now the image doesn't rotate 90 degrees anymore and, guess what, SAM works for this image. I'm going to do some more testing, but it looks like this weird "bug" is fixed for me for now.
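For reference, one way to make the image loading respect the orientation the photo was taken in is to apply the EXIF orientation tag when loading with Pillow; a minimal sketch (the function name here is illustrative, ImageOps.exif_transpose is the standard Pillow helper):

    from PIL import Image, ImageOps

    def load_image_with_orientation(image_path: str) -> Image.Image:
        # Apply the EXIF orientation tag (if present) so the pixel data matches
        # how the photo was actually taken, then strip the tag.
        image_pil = Image.open(image_path)
        return ImageOps.exif_transpose(image_pil).convert("RGB")

Whatever orientation handling is used, the same pixels should go to both GroundingDINO and predictor.set_image, so that the box prompts and the image SAM sees line up.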

Thank you for thinking along with me! 😄