IDEA-Research / Grounded-Segment-Anything

Grounded SAM: Marrying Grounding DINO with Segment Anything & Stable Diffusion & Recognize Anything - Automatically Detect, Segment and Generate Anything
https://arxiv.org/abs/2401.14159
Apache License 2.0

Inquiry about code/result differences between SAM and GSAM #494

Open · TikaToka opened this issue 2 months ago


Hello, thank you for sharing this amazing work!

I am trying to adapt the GSAM code as a base model, but I have a question.

```python
import torch
from transformers import SamModel, SamProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

# `image`, `preds` and `sampled_points` are defined earlier in my pipeline (not shown);
# sampled_points[idx] holds the prompt points for object idx
sam_masks = []
for idx in range(preds.shape[0]):
    sam_inputs = sam_processor(image, input_points=[sampled_points[idx]], return_tensors="pt").to(device)

    with torch.no_grad():
        sam_outputs = sam_model(**sam_inputs)
        print(sam_outputs)
        print(sam_outputs.pred_masks.cpu().shape)

    # upsample the low-resolution masks back to the original image size
    sam_masks.append(sam_processor.image_processor.post_process_masks(
        sam_outputs.pred_masks.cpu(), sam_inputs["original_sizes"].cpu(), sam_inputs["reshaped_input_sizes"].cpu()
    ))
```
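The five dimensions of each `sam_mask` in the Hugging Face path can be traced with dummy tensors. This is a sketch under assumptions: it assumes `pred_masks` is laid out as (batch, point_batch, masks_per_prompt, 256, 256) with three masks per prompt (multimask output is the default in `SamModel`), the image size is hypothetical, and the upsample-and-threshold step is a simplified stand-in for `post_process_masks`, which additionally removes SAM's padding.

```python
import torch
import torch.nn.functional as F

# Hypothetical original image size.
orig_h, orig_w = 480, 640

# Dummy stand-in for `sam_outputs.pred_masks`, assuming the layout
# (batch=1, point_batch=1, masks_per_prompt=3, 256, 256).
pred_masks = torch.randn(1, 1, 3, 256, 256)

# Simplified post-processing for the single image: upsample to the original
# size and threshold at 0, giving one (point_batch, 3, H, W) tensor per image.
per_image = [
    F.interpolate(pred_masks[0], size=(orig_h, orig_w), mode="bilinear", align_corners=False) > 0.0
]

# Each item appended to sam_masks is this one-element list,
# so stacking it adds one more leading dimension.
sam_mask = torch.stack(per_image)
print(sam_mask.shape)  # torch.Size([1, 1, 3, 480, 640])
```

The leading 1s come from the per-image list and the point batch; the 3 is the multimask output.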

With this SAM code, each `sam_mask` has shape (1, 1, 3, h, w), so the total is (n, 1, 1, 3, h, w). However, with this code from GSAM:

```python
import cv2
import torch
import torchvision
from segment_anything import SamPredictor, build_sam

# load_image, load_model, generate_caption, generate_tags, get_grounding_output
# and check_caption are the helper functions from the GSAM demo script (not shown)
image_pil, im = load_image(rgb_path)
# load model
model = load_model(config_file, grounded_checkpoint, device=device)

caption = generate_caption(image_path, device=device)
# Currently ", " is better for detecting single tags
# while ". " is a little worse in some case
text_prompt = generate_tags(caption, split=split)

boxes_filt, scores, pred_phrases = get_grounding_output(
    model, im, text_prompt, box_threshold, text_threshold, device=device
)

print(boxes_filt, scores, pred_phrases)

# initialize SAM
# if use_sam_hq:
#     print("Initialize SAM-HQ Predictor")
#     predictor = SamPredictor(build_sam_hq(checkpoint=sam_hq_checkpoint).to(device))
# else:
predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
# reload the image as an RGB numpy array for SAM (this replaces the Grounding DINO tensor `im`)
im = cv2.imread(rgb_path)
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
predictor.set_image(im)

# convert Grounding DINO boxes from normalized (cx, cy, w, h) to pixel (x0, y0, x1, y1)
size = image_pil.size
H, W = size[1], size[0]
for i in range(boxes_filt.size(0)):
    boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
    boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
    boxes_filt[i][2:] += boxes_filt[i][:2]

boxes_filt = boxes_filt.cpu()
# use NMS to handle overlapped boxes
print(f"Before NMS: {boxes_filt.shape[0]} boxes")
nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
boxes_filt = boxes_filt[nms_idx]
pred_phrases = [pred_phrases[idx] for idx in nms_idx]
print(f"After NMS: {boxes_filt.shape[0]} boxes")
caption = check_caption(caption, pred_phrases)
print(f"Revise caption with number: {caption}")

# rescale the boxes to SAM's resized input resolution (longest side 1024)
transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, im.shape[:2]).to(device)

# one forward pass for all n boxes; multimask_output=False returns one mask per box
masks, _, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=False,
)
```

Here, each mask in `masks` has shape (1, h, w), so the total is (n, 1, h, w).
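How the `multimask_output` flag affects these shapes can be illustrated with dummy tensors. This sketch assumes `SamPredictor.predict_torch` returns masks laid out as (B, C, H, W), where C is 1 when `multimask_output=False` and 3 when it is `True`, along with (B, C) predicted IoU scores; all sizes and values are hypothetical.

```python
import torch

# Hypothetical sizes: n boxes prompted on one h x w image.
n, h, w = 4, 48, 64

# Dummy stand-ins for predict_torch outputs (assumed (B, C, H, W) masks
# and (B, C) IoU predictions).
masks_single = torch.zeros(n, 1, h, w, dtype=torch.bool)  # multimask_output=False
masks_multi = torch.zeros(n, 3, h, w, dtype=torch.bool)   # multimask_output=True
iou_multi = torch.rand(n, 3)

# Indexing one box drops only the box dimension.
print(masks_single[0].shape)  # torch.Size([1, 48, 64])
print(masks_multi[0].shape)   # torch.Size([3, 48, 64])

# With three candidates per box, one common choice is to keep the
# candidate with the highest predicted IoU, giving shape (n, h, w).
best = masks_multi[torch.arange(n), iou_multi.argmax(dim=1)]
print(best.shape)  # torch.Size([4, 48, 64])
```

All boxes go through one batched call here, whereas the Hugging Face loop runs one call per object and wraps each result in a per-image list.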

I wonder why there is a dimensional gap between SAM and GSAM, and whether there is a way to get (1, 1, 3, h, w). It looks like the `multimask_output=True` case; if that is right, the code might be:

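```python
# assumes masks came from predict_torch(..., multimask_output=True), i.e. shape (n, 3, h, w)
new = [torch.tensor([[mask]]) for mask in masks.cpu().tolist()]
```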

but I would like to confirm this.
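The proposed conversion can be checked on small dummy data. A sketch, assuming `masks` comes from `predict_torch(..., multimask_output=True)` and is therefore a boolean tensor of shape (n, 3, h, w); the sizes are hypothetical. Indexing with `None` adds the two leading singleton dimensions without the round trip through Python lists, which becomes slow at full image resolution.

```python
import torch

# Small dummy stand-in for predict_torch(..., multimask_output=True) output:
# boolean masks of shape (n, 3, h, w). Sizes are hypothetical.
n, h, w = 2, 4, 5
masks = torch.rand(n, 3, h, w) > 0.5

# The proposed conversion: round-trips through Python lists.
new = [torch.tensor([[mask]]) for mask in masks.cpu().tolist()]

# Equivalent, staying in torch: add two leading singleton dims per mask.
new_fast = [mask[None, None] for mask in masks.cpu()]

# Or keep everything in one tensor of shape (n, 1, 1, 3, h, w).
stacked = masks.cpu()[:, None, None]

print(new[0].shape, new_fast[0].shape, stacked.shape)
# torch.Size([1, 1, 3, 4, 5]) torch.Size([1, 1, 3, 4, 5]) torch.Size([2, 1, 1, 3, 4, 5])
```

Both list forms give one (1, 1, 3, h, w) tensor per box, matching the per-object shape from the Hugging Face path.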

Thank you in advance, and have a nice day!