SysCV / sam-hq

Segment Anything in High Quality [NeurIPS 2023]
https://arxiv.org/abs/2306.01567
Apache License 2.0

"Automatic mask generation" is possible? #9

Open mattyamonaca opened 1 year ago

mattyamonaca commented 1 year ago

Is there a function to generate masks fully automatically, like in the example notebook below? Also, do you have any plans to add such a function?

https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb

benji2264 commented 1 year ago

Hi @mattyamonaca, you can already do it with exactly the same code as SAM :)

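import cv2
# Assumes the segment_anything package installed from the sam-hq repository,
# which provides the HQ-SAM model registry and mask generator
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# Read image (OpenCV loads BGR; SAM expects RGB)
image = cv2.imread("dog.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Load HQ-SAM
sam_checkpoint = "sam_hq_vit_h.pth"
model_type = "vit_h"
device = "cuda"  # use "cpu" if no GPU is available
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Generate segmentation: the generator object is not callable, use generate()
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)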

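Most problems at this step come from the input image: cv2.imread() returns None instead of raising when the file cannot be read, and SAM's generate() docstring asks for an HWC uint8 image. The sketch below is a hypothetical helper, to_sam_rgb, which is not part of segment_anything or sam-hq; it assumes the image comes from OpenCV in BGR, BGRA or grayscale layout and uses only NumPy.

```python
import numpy as np


def to_sam_rgb(image, assume_bgr=True):
    """Hypothetical helper: normalize an OpenCV-style image to HWC uint8 RGB.

    Not part of segment_anything / sam-hq; it only uses NumPy.
    """
    if image is None:
        # cv2.imread() returns None (no exception) when the file can't be read
        raise ValueError("image is None; check the path passed to cv2.imread()")
    image = np.asarray(image)
    if image.dtype != np.uint8:
        raise TypeError(f"expected uint8 pixels, got {image.dtype}")
    if image.ndim == 2:
        # Grayscale: repeat the single channel three times
        image = np.stack([image] * 3, axis=-1)
    elif image.ndim == 3 and image.shape[2] == 4:
        # Drop the alpha channel of a 4-channel image
        image = image[:, :, :3]
    elif not (image.ndim == 3 and image.shape[2] == 3):
        raise ValueError(f"unsupported image shape {image.shape}")
    if assume_bgr:
        # Reverse the channel axis: BGR -> RGB (same effect as cv2.COLOR_BGR2RGB)
        image = image[:, :, ::-1]
    # Reversing creates a negative-stride view; return a contiguous copy
    return np.ascontiguousarray(image)
```

With this helper, image = to_sam_rgb(cv2.imread("dog.jpg")) would replace the two image-reading lines and fail early with a readable message instead of a confusing error inside the model.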
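mask_generator.generate(image) returns a list of masks. Each mask is a dictionary whose "segmentation" key holds a boolean (H, W) array; the other keys ("area", "bbox", "predicted_iou", "point_coords", "stability_score", "crop_box") carry per-mask metadata. You can visualize the masks with the show_anns() helper from SAM's example notebook: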

import matplotlib.pyplot as plt
import numpy as np

def show_anns(anns):
    # Overlay each mask in a random semi-transparent color on the current axes
    if len(anns) == 0:
        return
    # Largest masks first, so smaller masks are painted on top
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    # Fully transparent RGBA canvas with the same size as the masks
    h, w = sorted_anns[0]['segmentation'].shape
    img = np.ones((h, w, 4))
    img[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)

plt.figure(figsize=(12, 7))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()

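If you need the masks as data rather than a plot, the same dictionaries can be filtered and merged into a single label image. The sketch below is hypothetical (filter_masks and to_label_map are not part of segment_anything or sam-hq); it assumes each entry has "segmentation", "area" and "predicted_iou" keys, as produced by SAM's automatic mask generator in its default binary-mask output mode.

```python
import numpy as np


def filter_masks(masks, min_area=0, min_iou=0.0):
    """Hypothetical helper: keep masks at or above an area and predicted-IoU floor."""
    return [
        m for m in masks
        if m["area"] >= min_area and m["predicted_iou"] >= min_iou
    ]


def to_label_map(masks):
    """Hypothetical helper: merge boolean masks into one int32 label image.

    0 is background; after sorting by area (largest first), mask i gets
    label i + 1. Smaller masks are painted last, so they win where masks overlap.
    """
    if not masks:
        raise ValueError("need at least one mask to infer the image size")
    ordered = sorted(masks, key=lambda m: m["area"], reverse=True)
    labels = np.zeros(ordered[0]["segmentation"].shape, dtype=np.int32)
    for i, m in enumerate(ordered):
        labels[m["segmentation"]] = i + 1
    return labels
```

For example, labels = to_label_map(filter_masks(masks, min_area=100, min_iou=0.9)) keeps only larger, higher-confidence masks; these thresholds are illustrative, not recommended values. Painting from largest to smallest mirrors the order show_anns() uses, so the label image matches what the plot shows.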
mattyamonaca commented 1 year ago

Thanks!

SuroshAhmadZobair commented 1 year ago

Hi, thanks for the script @benji2264. When I run the code in Jupyter, I get the following error:

TypeError: 'SamAutomaticMaskGenerator' object is not callable

Any insight? Cheers!

benji2264 commented 1 year ago

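Hi @SuroshAhmadZobair, thank you for pointing this out; it should work now :) I had forgotten to call the generate() method. The SamAutomaticMaskGenerator object is not callable itself, so use mask_generator.generate(image).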

ymq2017 commented 1 year ago

Thanks! We also provide a notebook now.