facebookresearch / sam2

The repository provides code for running inference with the Meta Segment Anything Model 2 (SAM 2), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.

Enhance `image_predictor_example.ipynb` with Interactive Point Addition Using Matplotlib #421

Open future-158 opened 1 week ago

future-158 commented 1 week ago

Issue: The current image_predictor_example.ipynb provides a good example of using the image predictor. However, it lacks interactivity; adding it would improve the user experience and make experimentation easier.

Proposed Enhancement: Introduce an interactive feature that allows users to add positive and negative points by clicking on the image.

This can be achieved with a simple Matplotlib-based script of fewer than 110 lines of code, eliminating the need for complex third-party tools for simple testing.

Benefits:

- Users can experiment with positive and negative point prompts directly on the image.
- Only Matplotlib is required, so no complex third-party annotation tools are needed for simple testing.
- It makes the example notebook more engaging and easier to explore.

Implementation: I have written a concise script that demonstrates this functionality. You can view the complete code in the following Gist.

Example Code Snippet:

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor
import requests

def load_image(url: str) -> Image.Image:
    # use a browser-like User-Agent header; some image hosts reject the default requests client
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/85.0.4183.83 Safari/537.36"
    }
    image = Image.open(requests.get(url, stream=True, headers=headers).raw)
    return image

# enable the interactive widget backend (requires the ipympl package)
%matplotlib widget

predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny")
# predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")

# load example image
url = "https://images.pexels.com/photos/529782/pexels-photo-529782.jpeg?auto=compress&cs=tinysrgb&w=800"
base_img = load_image(url)

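# set_image runs the image encoder once and caches the embedding; later predict()
# calls only run the lightweight prompt encoder and mask decoder, so clicks stay
# responsive (autocast here assumes a CUDA GPU)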
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(base_img)

img = np.array(base_img)
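# show the image and keep the AxesImage handle so the overlay can be refreshed on each click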
fig, ax = plt.subplots()
im = ax.imshow(img)

# Remove ticks
ax.set_xticks([])
ax.set_yticks([])

# Remove tick labels
ax.set_xticklabels([])
ax.set_yticklabels([])

# Remove axis labels
ax.set_xlabel("")
ax.set_ylabel("")

plt.tight_layout()

positive_points = []
negative_points = []
mask = None

def inference() -> np.ndarray:
    global mask
    # predictor.predict expects an Nx2 array of point coordinates and a matching label array
    point_coords = np.array([*positive_points, *negative_points])
    point_labels = np.array([1] * len(positive_points) + [0] * len(negative_points))

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        masks, _, _ = predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=False,
        )

    # masks has shape (1, H, W) with multimask_output=False; keep the single boolean mask
    mask = masks[0] > 0
    # overlay the mask as a semi-transparent blue tint on the original image
    blended = Image.blend(
        Image.new("RGB", base_img.size, (0, 0, 255)), base_img, alpha=0.5
    )
    composited = Image.composite(blended, base_img, Image.fromarray(mask)).convert(
        "RGB"
    )

    return np.array(composited)

def on_click(event):
    """
    Event handler for mouse click events on the plot.
    Parameters:
    - event: The mouse event.
    """
    if event.inaxes:
        x, y = event.xdata, event.ydata

        if event.button == 1:  # Left mouse button
            positive_points.append((x, y))
        elif event.button == 3:  # Right mouse button
            negative_points.append((x, y))

        new_rgb = inference()
        # draw a small green marker at each positive point and a red marker at each
        # negative point; clamp so clicks near the border don't produce negative indices
        for x, y in positive_points:
            x, y = int(x), int(y)
            new_rgb[max(y - 5, 0) : y + 5, max(x - 5, 0) : x + 5] = [0, 255, 0]

        for x, y in negative_points:
            x, y = int(x), int(y)
            new_rgb[max(y - 5, 0) : y + 5, max(x - 5, 0) : x + 5] = [255, 0, 0]
        im.set_data(new_rgb)
        fig.canvas.draw_idle()

cid = fig.canvas.mpl_connect("button_press_event", on_click)
plt.show()
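To try this in a notebook, install ipympl (for the %matplotlib widget backend), run the cell, then left-click to add positive points and right-click to add negative points; the overlay refreshes after every click. As a small optional follow-up (not part of the Gist above, and assuming the interactive session has produced a mask), the final selection could be kept for later use:

# save the last predicted boolean mask once the selection looks right
Image.fromarray(mask).save("mask.png")  # binary PNG preview
np.save("mask.npy", mask)               # raw boolean array for downstream use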

Conclusion: When I first ran the demo with my own image, I plotted the image with Plotly first (because it shows the mouse coordinates) and then updated the positive and negative points manually, one by one.
I think adding this interactive example would make image_predictor_example.ipynb more engaging. I'm happy to contribute it to the repository or provide further assistance if needed.