IDEA-Research / Grounded-SAM-2

Grounded SAM 2: Ground and Track Anything in Videos with Grounding DINO, Florence-2 and SAM 2
https://arxiv.org/abs/2401.14159
Apache License 2.0

Predict from a video file #25

Open Masrur02 opened 3 weeks ago

Masrur02 commented 3 weeks ago

Hi, when I ran python grounded_sam2_local_demo.py, the result was good with the prompt text="car. road.":

(attached image: grounded_sam2_annotated_image_with_mask)

But when I modified the code to read frames from a video file in a loop:

import cv2
import torch
import numpy as np
import supervision as sv
from torchvision.ops import box_convert
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
import time
import os

# Environment settings
# Use bfloat16 only where supported

# Build SAM2 image predictor
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
sam2_predictor = SAM2ImagePredictor(sam2_model)

# Build Grounding DINO model
model_id = "IDEA-Research/grounding-dino-tiny"
device = "cuda" if torch.cuda.is_available() else "cpu"
grounding_model = load_model(
    model_config_path="grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py", 
    model_checkpoint_path="gdino_checkpoints/groundingdino_swint_ogc.pth",
    device=device
)

# Setup the input text prompt for Grounding DINO
text = "road. car."
output_dir = "test"
os.makedirs(output_dir, exist_ok=True)

# Capture video
video_path = 'notebooks/videos/indy.mp4'
cap = cv2.VideoCapture(video_path)
frame_num = 0

while cap.isOpened():
    start_time = time.time()
    ret, frame = cap.read()
    if not ret:
        break

    #time.sleep(0.1)

    # Convert the frame to the required format for processing
    image_source, image = load_image(frame)

    sam2_predictor.set_image(image_source)

    boxes, confidences, labels = predict(
        model=grounding_model,
        image=image,
        caption=text,
        box_threshold=0.35,
        text_threshold=0.25
    )

    # Process the box prompt for SAM2
    h, w, _ = frame.shape
    boxes = boxes * torch.Tensor([w, h, w, h])
    input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()

    # Enable mixed precision only for the specific block
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        if torch.cuda.get_device_properties(0).major >= 8:
            # Enable tfloat32 for Ampere GPUs
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

        # Perform SAM2 prediction within the mixed precision context
        masks, scores, logits = sam2_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=False,
        )

    # Post-process the output of the model to get the masks, scores, and logits for visualization
    if masks.ndim == 4:
        masks = masks.squeeze(1)

    confidences = confidences.numpy().tolist()
    class_names = labels
    class_ids = np.array(list(range(len(class_names))))

    labels = [
        f"{class_name} {confidence:.2f}"
        for class_name, confidence
        in zip(class_names, confidences)
    ]

    # Calculate FPS
    end_time = time.time()
    fps = 1 / (end_time - start_time)

    # Visualize image with supervision API
    detections = sv.Detections(
        xyxy=input_boxes,  # (n, 4)
        mask=masks.astype(bool),  # (n, h, w)
        class_id=class_ids
    )

    box_annotator = sv.BoxAnnotator()
    annotated_frame = box_annotator.annotate(scene=frame.copy(), detections=detections)

    label_annotator = sv.LabelAnnotator()
    annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)

    mask_annotator = sv.MaskAnnotator()
    annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
    mask_image_save_path = os.path.join(output_dir, f"{frame_num:04d}_mask.jpg")

    cv2.imwrite(mask_image_save_path, annotated_frame)
    print(f"FPS for frame {frame_num}: {fps:.2f}")

    frame_num += 1

cap.release()
cv2.destroyAllWindows()

the result became very bad:

(attached image: 0002_mask)

What is the reason? Can you please help?

TIA

Masrur02 commented 3 weeks ago

Moreover, is there any way to increase the FPS? TIA
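
One thing that stands out in the posted loop (a minimal sketch, not an answer from the maintainers): the TF32 backend flags are process-wide and only need to be set once, and the FPS number above also includes the annotation and cv2.imwrite steps, so it helps to time the model calls separately before optimizing anything. The timed helper below is a hypothetical name, not part of the repo.

import time
import torch

# One-time setup: these backend flags are process-wide, so they do not need to
# be re-applied on every frame inside the loop.
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

def timed(label, fn, *args, **kwargs):
    # Hypothetical helper: runs fn once and prints how long it took, so the
    # per-frame cost can be split between Grounding DINO, SAM 2, and the
    # annotation/imwrite steps.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = time.time()
    out = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    print(f"{label}: {(time.time() - t0) * 1000:.1f} ms")
    return out

# Example use inside the loop:
#   boxes, confidences, labels = timed("grounding dino", predict, model=grounding_model,
#                                      image=image, caption=text,
#                                      box_threshold=0.35, text_threshold=0.25)
#   masks, scores, logits = timed("sam2", sam2_predictor.predict,
#                                 point_coords=None, point_labels=None,
#                                 box=input_boxes, multimask_output=False)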

rentainhe commented 3 weeks ago

I have noticed that the confidence scores are different when using the same model on these two images; would you like to check whether the annotated frame is the same or not in your code?
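
One way to run that check (a minimal sketch, not from the thread): grounded_sam2_local_demo.py goes through load_image(), which, as shipped in the repo, opens the file with PIL in RGB order and applies the Grounding DINO transforms, whereas cv2.VideoCapture returns raw BGR frames. Writing the current frame to disk and reloading it through the same load_image() path removes that difference; tmp_path below is just a placeholder.

import cv2
from grounding_dino.groundingdino.util.inference import load_image

# Grab one frame the same way the video loop does (a BGR numpy array).
cap = cv2.VideoCapture("notebooks/videos/indy.mp4")
ret, frame = cap.read()
cap.release()
assert ret, "could not read a frame"

# Round-trip the frame through the same loader the single-image demo uses.
# cv2.imwrite expects BGR, so the file on disk has correct colors, and
# load_image() then reopens it in RGB and applies the same preprocessing
# as grounded_sam2_local_demo.py.
tmp_path = "/tmp/current_frame.jpg"  # placeholder path
cv2.imwrite(tmp_path, frame)
image_source, image = load_image(tmp_path)

print(type(image_source), image_source.shape, image.shape)
# If detections recover when (image_source, image) from this round-trip are fed
# to set_image()/predict() instead of the raw frame, the quality drop came from
# the frame preprocessing rather than from the model.

Converting in memory with cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) would avoid the disk round-trip, but the temporary file reproduces the demo's preprocessing exactly, which makes the comparison cleaner.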

Masrur02 commented 3 weeks ago

Hi, yes, this is essentially my question: why is the confidence score different, and why is the road not detected at all in the second image (even though it has a higher confidence)? Moreover, when I tried the tracking-with-continuous-ID example, the result was very good (in my per-frame output you can see the ego car is also detected as the road, which is very weird, but in the tracking example it is perfectly predicted as a car). However, the tracking example seems to require all the images from the video to be stored first. Is there a way to extract images one by one and use the tracking example? Mainly, I am planning to use Grounded-SAM-2 for tracking cars and roads from images captured by a camera, since your work seems very promising and interesting.

TIA
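
For reference, a minimal sketch (an assumption-laden example, not the repo's official workflow) of the "extract images one by one" step: the tracking example operates on a directory of JPEG frames, so a video file or camera stream can be dumped frame by frame into such a directory before running it. video_path and frames_dir below are placeholders.

import os
import cv2

video_path = "notebooks/videos/indy.mp4"      # or 0 for a live camera
frames_dir = "notebooks/videos/indy_frames"   # assumed output directory
os.makedirs(frames_dir, exist_ok=True)

cap = cv2.VideoCapture(video_path)
frame_idx = 0
while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    # The tracking example expects a folder of JPEG frames whose file names
    # sort in playback order, e.g. 00000.jpg, 00001.jpg, ...
    cv2.imwrite(os.path.join(frames_dir, f"{frame_idx:05d}.jpg"), frame)
    frame_idx += 1
cap.release()

print(f"wrote {frame_idx} frames to {frames_dir}")
# The tracking demo can then be pointed at frames_dir as its video directory.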