FrankWJW opened this issue 2 months ago
A simple way to get this working is to use one of the masks output by the model itself as an input (I believe this is how the mask input is meant to be used, based on SAMv1).
Here's a simple example that runs the decoder twice, once to get a mask and then a second time using only that mask as an input prompt:
import numpy as np
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
# File paths
img_path = "path/to/image.jpg"
model_path = "path/to/model.pth"
model_cfg = "model_config.yaml"
# Load model
sam2_model = build_sam2(model_cfg, model_path)
predictor = SAM2ImagePredictor(sam2_model)
# Encode image
image = Image.open(img_path)
image = np.array(image.convert("RGB"))
predictor.set_image(image)
# Run the model once to generate masks to use as inputs
_, _, ex_logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,
)
# Create an example mask input
mask_input = ex_logits[0, :, :]
mask_input = np.expand_dims(mask_input, 0) # Make shape: 1x256x256
print("Mask input shape:", mask_input.shape)
print("Mask min/max", mask_input.min(), mask_input.max())
# Run SAM with mask input only
masks, scores, logits = predictor.predict(mask_input=mask_input, multimask_output=True)
Obviously, if you have a mask from somewhere else, you can use that as the input rather than running the model with a throwaway point prompt as this example does. Though if you're not using an output from the model itself, you may need to scale the mask values to be more similar to logits (which tend to have a value range of around -10 to +10), for example as sketched below.
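If you start from your own binary mask, here's a rough sketch of that conversion (reusing the image and predictor from the code above; the 256x256 size matches the low-res masks the predictor returns, and the +/-8 values are just a guess at a logit-like scale, not anything specified by the SAM2 API):
import numpy as np
from PIL import Image
# Stand-in for an external binary mask of the same image (here just a filled box)
binary_mask = np.zeros(image.shape[:2], dtype=np.uint8)
binary_mask[300:450, 400:600] = 1
# Downscale to 256x256 and map {0, 1} to negative/positive logit-like values
low_res = Image.fromarray(binary_mask * 255).resize((256, 256), Image.NEAREST)
mask_logits = np.where(np.array(low_res) > 127, 8.0, -8.0).astype(np.float32)
mask_input = mask_logits[None, :, :]  # Shape: 1x256x256
# Run SAM with the converted mask as the only prompt
masks, scores, logits = predictor.predict(mask_input=mask_input, multimask_output=True)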
The mask input generally gives poor quality results for me (compared to points or boxes). It seems to be there to support training more so than as a useful feature on its own, but maybe there are some cases where it can help.
Hi @FrankWJW, regarding using mask prompts: for image prediction you can pass them via the mask_input argument of SAM2ImagePredictor.predict, like @heyoeyo mentioned above; for video prediction there is an add_new_mask method for it in the video predictor class SAM2VideoPredictor.
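For the video case, a minimal sketch (reusing the model_cfg and model_path placeholders from the code above, with a hypothetical directory of video frames; the mask size, frame index, and object id are only illustrative):
import numpy as np
from sam2.build_sam import build_sam2_video_predictor
video_predictor = build_sam2_video_predictor(model_cfg, model_path)
inference_state = video_predictor.init_state(video_path="path/to/video_frames_dir")
# Binary mask for the target object on the prompted frame (placeholder values)
obj_mask = np.zeros((720, 1280), dtype=bool)
obj_mask[300:450, 400:600] = True
# Attach the mask prompt to frame 0 as object id 1
frame_idx, obj_ids, mask_logits = video_predictor.add_new_mask(
    inference_state=inference_state,
    frame_idx=0,
    obj_id=1,
    mask=obj_mask,
)
# Then propagate through the rest of the video as usual
for frame_idx, obj_ids, mask_logits in video_predictor.propagate_in_video(inference_state):
    pass  # post-process mask_logits per frame here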
@ronghanghu Hi, I would like to know: for image segmentation, can the mask_input only be the mask scores output by SAM itself? I need to use my own binary mask as input; how can I do that? Thanks!
I went through the notebook and only found instructions for point and bounding box inputs. However, according to the paper, it seems SAM2 also supports mask input. Will an example for using mask-only input be released in the future?