facebookresearch / segment-anything

The repository provides code for running inference with the Segment Anything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0
47.87k stars, 5.66k forks

Increase segmentation areas? #681

Open Rich2020 opened 9 months ago

Rich2020 commented 9 months ago

I want to automatically generate masks (regions that should be ignored by another model) for my images. Segment-anything works wonderfully, but the masks are too tight-fitting. Is it possible to expand a given mask in all directions by some n pixels, and then clip the expanded pixels that fall outside of the original image?

heyoeyo commented 9 months ago

You can achieve something like this using 'morphological filtering', specifically 'dilation', which is available in OpenCV.

The code would be something like:

import cv2
import numpy as np

# ... Set up SAM model

# Call SAM to get masks
masks, _, _ = predictor.predict(...)

# Grab one mask and convert it to uint8 format (0 or 255) for OpenCV.
# predict returns masks with shape (C, H, W): C = 3 with multimask_output=True
# and C = 1 otherwise, so [0] is needed in both cases to get a single H x W mask
mask_uint8 = np.uint8(masks[0]) * 255

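# Expand (dilate) the mask. A 15x15 kernel reaches (15 - 1) / 2 = 7 pixels from
# its center, so the mask grows by about 7 pixels in every direction. The output
# is the same size as the input, so any growth past the image border is clipped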
ksize = (15,15)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, ksize)
expanded_mask_uint8 = cv2.morphologyEx(mask_uint8, cv2.MORPH_DILATE, kernel)

# If you want the mask back in True/False format
expanded_mask_bool = expanded_mask_uint8 > 0

Rich2020 commented 9 months ago

@heyoeyo Thank you very much! I had actually tried using:

mask = np.random.choice([False, True], size=(640, 512))
mask_uint8 = mask.astype(np.uint8) * 255
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (11, 11)) 
dilated_mask = cv2.dilate(mask_uint8, kernel)

But for some reason it was dilating the entire mask (causing an issue with the edges of the image). It must be something to do with cv2.dilate(mask_uint8, kernel).

Anyway, your approach worked, so thank you again!