yformer / EfficientSAM

EfficientSAM: Leveraged Masked Image Pretraining for Efficient Segment Anything

Does the EfficientSAM model not support bounding-box prompts as input? #40

Closed: duxuan11 closed this issue 2 months ago

duxuan11 commented 9 months ago

Does the EfficientSAM model not support multiple bounding boxes as input prompts?

yjh0410 commented 7 months ago

@duxuan11 Dear friend, although the official code does not provide a bbox-prompt example, you can follow the approach of the SAM project: convert each bbox into two points and set their labels to 2 (top-left corner) and 3 (bottom-right corner). Below, after a short sketch of the conversion, I provide a full example in which two bboxes (xyxy format) prompt EfficientSAM to segment two objects. You can adapt this code to your own needs, but please do not copy and paste it directly, because I slightly modified the project's file structure.

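To make the label convention concrete, the core conversion looks roughly like this (a minimal sketch; the tensor shapes assume EfficientSAM's forward signature of points [bs, num_queries, num_pts, 2] and labels [bs, num_queries, num_pts], as used in the full script below):

import torch

# one box in xyxy format (the values are placeholders)
box_xyxy = [85.76, 196.63, 469.76, 543.61]
points = torch.tensor(box_xyxy).view(1, 1, 2, 2)  # [bs=1, num_queries=1, 2 corner points, (x, y)]
labels = torch.tensor([[[2, 3]]])                 # 2 = top-left corner, 3 = bottom-right corner
# predicted_logits, predicted_iou = model(image[None], points, labels)

And here is the full script:
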
import cv2
from torchvision import transforms
import torch
import numpy as np
import argparse
import os

from models.build_efficient_sam import efficient_sam_model_registry

parser = argparse.ArgumentParser(description=("Runs automatic mask generation on an input image or directory of images, "
                                              "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, "
                                              "as well as pycocotools if saving in RLE format."),
                                              )

parser.add_argument("--input", type=str, required=True,
                    help="Path to either a single input image or folder of images.",
                    )

parser.add_argument("--output", type=str, required=True,
                    help=("Path to the directory where masks will be output. Output will be either a folder "
                          "of PNGs per image or a single json with COCO-style masks."),
                    )

parser.add_argument("--model-type", type=str, required=True,
                    help="The key of the EfficientSAM model to load from the model registry.",
                    )

parser.add_argument("--checkpoint", type=str, required=True,
                    help="The path to the SAM checkpoint to use for mask generation.",
                    )

parser.add_argument("--device", type=str, default="cuda",
                    help="The device to run generation on.")

parser.add_argument("--convert-to-rle", action="store_true",
                    help=("Save masks as COCO RLEs in a single json instead of as a folder of PNGs. "
                          "Requires pycocotools."),
                    )

parser.add_argument("--show", action="store_true",
                    help=("To show the segmentation results on the input image."),
                    )

def main(args):
    device = torch.device(args.device)

    # Build the EfficientSAM model.
    model = efficient_sam_model_registry[args.model_type](checkpoint=args.checkpoint)
    model = model.to(device)
    model.eval()

    # load the input image and convert it to an RGB tensor in [0, 1]
    sample_image_np = cv2.imread(args.input)
    sample_image_np = cv2.cvtColor(sample_image_np, cv2.COLOR_BGR2RGB)
    sample_image_tensor = transforms.ToTensor()(sample_image_np)

    # bboxes of the sample
    bboxes = [[ 85.7600, 196.6265, 469.7600, 543.6144],
              [236.8000,  82.8916, 325.1200, 441.4458]]

    # convert the bboxes into the point prompts
    num_queries = len(bboxes)
    input_points = torch.as_tensor(bboxes).unsqueeze(0)      # [bs, num_queries, 4], bs = 1
    input_points = input_points.view(-1, num_queries, 2, 2)  # [bs, num_queries, num_pts, 2]
    input_labels = torch.tensor([2, 3])  # top-left, bottom-right
    input_labels = input_labels[None, None].repeat(1, num_queries, 1) # [bs, num_queries, num_pts]

    print(f"Running inference with {num_queries} box prompts ...")
    with torch.no_grad():
        predicted_logits, predicted_iou = model(
            sample_image_tensor[None, ...].to(device),
            input_points.to(device),
            input_labels.to(device),
        )
    sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
    predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)
    # [bs, num_queries, num_candidate_masks, img_h, img_w]
    predicted_logits = torch.take_along_dim(
        predicted_logits, sorted_ids[..., None, None], dim=2
    )
    masks = torch.ge(predicted_logits, 0).cpu().detach().numpy()
    masks = masks[0, :, 0, :, :]  # [num_queries, img_h, img_w]

    # overlay the predicted masks on the input image (back to BGR for OpenCV)
    masked_image_np = cv2.cvtColor(sample_image_np, cv2.COLOR_RGB2BGR)
    for i in range(num_queries):
        mask = masks[i]
        color = np.random.randint(0, 255, size=3)
        # [H, W] -> [H, W, 3]
        mask = np.repeat(mask[..., None], 3, axis=-1)
        mask_rgb = mask * color * 0.6
        inv_alph_mask = 1 - mask * 0.6
        masked_image_np = (masked_image_np * inv_alph_mask + mask_rgb).astype(np.uint8)

    if args.show:
        cv2.imshow("masked image", masked_image_np)
        cv2.waitKey(0)

    # save the results
    os.makedirs(args.output, exist_ok=True)
    cv2.imwrite(os.path.join(args.output, "result.png"), masked_image_np)

if __name__ == "__main__":
    args = parser.parse_args()
    np.random.seed(12)

    main(args)
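
For reference, the script can then be launched roughly like this (the script name, checkpoint path, and registry key are placeholders; substitute whatever matches your own checkout):

python run_box_prompt.py \
    --input data/images/ex1.jpg \
    --output outputs/efficient_sam/ \
    --model-type <registry_key> \
    --checkpoint weights/efficient_sam.pt \
    --show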