facebookresearch / sam2

The repository provides code for running inference with the Meta Segment Anything Model 2 (SAM 2), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0

Support for mask only input #147

Open FrankWJW opened 2 months ago

FrankWJW commented 2 months ago

I went through the notebook and only found instructions for point and bounding box inputs. However, according to the paper, it seems SAM2 also supports mask input. Will an example for using mask-only input be released in the future?

heyoeyo commented 2 months ago

A simple way to get this working is to use one of the masks output by the model itself as an input (I believe this is how the mask input is meant to be used, based on SAMv1).

Here's a simple example that runs the decoder twice, once to get a mask and then a second time using only that mask as an input prompt:

import numpy as np
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# File paths
img_path = "path/to/image.jpg"
model_path = "path/to/model.pth"
model_cfg = "model_config.yaml"

# Load model
sam2_model = build_sam2(model_cfg, model_path)
predictor = SAM2ImagePredictor(sam2_model)

# Encode image
image = Image.open(img_path)
image = np.array(image.convert("RGB"))
predictor.set_image(image)

# Run the model once to generate masks to use as inputs
_, _, ex_logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,
)

# Create an example mask input
mask_input = ex_logits[0, :, :]
mask_input = np.expand_dims(mask_input, 0) # Make shape: 1x256x256
print("Mask input shape:", mask_input.shape)
print("Mask min/max", mask_input.min(), mask_input.max())

# Run SAM with mask input only
masks, scores, logits = predictor.predict(mask_input=mask_input, multimask_output=True)

Obviously if you have a mask from somewhere else, you can use that as the input instead of first running the model with a point prompt as in this example. However, if the mask isn't an output from the model itself, you may need to scale its values to resemble logits (which tend to range from about -10 to +10).
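To make that scaling concrete, here's a rough sketch (my own untested helper, not part of the SAM 2 API) of how an external binary mask could be resized to the 256x256 low-resolution input size and mapped to roughly ±10 so it behaves like the logits the model produces; the exact scale value is an assumption:

```python
import numpy as np
from PIL import Image

def binary_mask_to_logits(binary_mask, low_res_size=256, logit_scale=10.0):
    """Convert a {0, 1} binary mask into a logit-like array shaped 1x256x256.

    Foreground pixels are mapped to +logit_scale and background to
    -logit_scale. Bilinear resizing softens the boundary, which loosely
    mimics how real model logits look near mask edges.
    """
    # Resize as a float image so interpolation gives soft edge values
    mask_img = Image.fromarray((binary_mask > 0).astype(np.float32))
    resized = np.array(
        mask_img.resize((low_res_size, low_res_size), Image.BILINEAR),
        dtype=np.float32,
    )
    # Map 0 -> -logit_scale and 1 -> +logit_scale
    logits = (2.0 * resized - 1.0) * logit_scale
    return logits[np.newaxis, :, :]  # shape: 1x256x256

# Hypothetical usage with a predictor set up as in the example above:
# mask_input = binary_mask_to_logits(my_binary_mask)
# masks, scores, logits = predictor.predict(mask_input=mask_input, multimask_output=True)
```

In practice it may work better to run this converted mask through the model once and then reuse the returned logits as the input for subsequent calls, since those are in the distribution the mask encoder was trained on.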

The mask input generally gives poor-quality results for me (compared to points or boxes). It seems to exist more to support training than as a useful feature on its own, but there may be some cases where it helps.

ronghanghu commented 2 months ago

Hi @FrankWJW, regarding using mask prompts:

OrangeSodahub commented 1 month ago

@ronghanghu Hi, I would like to know: for image segmentation, can mask_input only be the mask logits output by SAM itself? I need to use my own binary mask as the input. How can I do that? Thanks!