facebookresearch / segment-anything

The repository provides code for running inference with the SegmentAnything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0
47.63k stars 5.63k forks

how to save a generated image? #81

Open SheeppLi opened 1 year ago

SheeppLi commented 1 year ago

Thank you for such excellent work! I would like to know how to save the images generated by the demo, and how to train on a custom dataset. By the way, when I ran the code on a test image, I got separate mask images instead of one image with all the masks. Is there any way to obtain a single mask image that corresponds to the original image?

taatuut commented 1 year ago

A quick & dirty way to do this is to use an altered version of amg.py; see the README for the command line.

Add import statements (make sure these modules and dependencies are installed).

import numpy as np
import matplotlib.pyplot as plt

Add function def show_anns (borrowed from one of the Jupyter notebooks).

def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    polygons = []
    color = []
    for ann in sorted_anns:
        m = ann['segmentation']
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        for i in range(3):
            img[:,:,i] = color_mask[i]
        ax.imshow(np.dstack((img, m*0.35)))
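The layering that show_anns performs with repeated ax.imshow calls can also be computed directly on an array, which avoids creating a matplotlib figure at all. Below is a minimal NumPy sketch, assuming each annotation is a dict with a boolean 'segmentation' array and an 'area' value (the default output of SamAutomaticMaskGenerator.generate). overlay_anns and its seed parameter are hypothetical names, not part of SAM.

```python
import numpy as np


def overlay_anns(image, anns, alpha=0.35, seed=None):
    """Blend every SAM mask into one RGB uint8 image (hypothetical helper).

    image: HxWx3 uint8 RGB array.
    anns: list of dicts with a boolean HxW 'segmentation' and an 'area'.
    """
    rng = np.random.default_rng(seed)
    out = image.astype(np.float32) / 255.0
    # Largest masks first, so smaller masks end up on top (as in show_anns).
    for ann in sorted(anns, key=lambda a: a["area"], reverse=True):
        m = ann["segmentation"].astype(bool)
        color = rng.random(3)  # one random RGB colour in [0, 1) per mask
        # Alpha-blend the colour over the current pixels, only inside the mask.
        out[m] = (1.0 - alpha) * out[m] + alpha * color
    return np.clip(out * 255.0 + 0.5, 0, 255).astype(np.uint8)
```

The result can be written at the image's native size with plt.imsave('masks.png', overlay_anns(image, masks)), without figure borders or resizing.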

Add new function def write_masks_to_png (again some borrowed code from the readme).

def write_masks_to_png(masks: List[Dict[str, Any]], image, path: str) -> None:
    plt.figure(figsize=(20,20))
    plt.imshow(image)
    show_anns(masks)
    plt.axis('off')
    #plt.show()  # keep commented out: showing first can leave the saved PNG blank
    filename = f"masks.png"
    plt.savefig(os.path.join(path, filename))
    return

In def main after the line:

write_masks_to_folder(masks, save_base)

Add the following line:

write_masks_to_png(masks, image, save_base)

And run amg.py again.

SheeppLi commented 1 year ago

Thank you very much for replying! I tried it, but I got the original image instead of the mask image. It also seems that polygons = [] and color = [] in def show_anns are not used, since they are grayed out in PyCharm.

srijithrajeev commented 1 year ago

@SheeppLi, the solution provided by @taatuut works. Just make sure the indentation is right for both functions and for the for loops. Also, try the modified demo on a different image.

And yes, the variables polygons and color are not used in the code above.

SheeppLi commented 1 year ago

Yeah, I got it! Thank you so much!

YL-yyy commented 1 year ago

Hello, why is the mask I got all white? Did you get any colors?

SheeppLi commented 1 year ago

Yes, I got a colorful mask image, while the separate masks are white.
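The separate files appear white because the per-mask PNGs written by amg.py are binary single-channel images (black background, white mask) with no colour information; the colourful picture only comes from the overlay step. If you already have those separate masks, here is a minimal NumPy sketch for merging them into one label map and colouring it, assuming the masks are loaded as same-sized HxW arrays where nonzero means inside. masks_to_label_map and colorize_labels are hypothetical helpers.

```python
import numpy as np


def masks_to_label_map(masks):
    """Merge binary HxW masks into one int32 label map (hypothetical helper).

    0 = background, i + 1 = masks[i]. Larger masks are painted first so
    smaller ones stay visible on top.
    """
    if not masks:
        raise ValueError("need at least one mask")
    labels = np.zeros(np.asarray(masks[0]).shape, dtype=np.int32)
    order = sorted(range(len(masks)),
                   key=lambda i: np.count_nonzero(masks[i]), reverse=True)
    for i in order:
        labels[np.asarray(masks[i]) > 0] = i + 1
    return labels


def colorize_labels(labels, seed=0):
    """Map each label id to a random RGB colour; background (0) stays black."""
    rng = np.random.default_rng(seed)
    palette = rng.integers(0, 256, size=(labels.max() + 1, 3), dtype=np.uint8)
    palette[0] = 0
    return palette[labels]  # HxWx3 uint8 via fancy indexing
```

Masks saved as PNGs can be loaded with cv2.imread(path, cv2.IMREAD_GRAYSCALE) before being passed in.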

Mrhard1999 commented 1 year ago

@SheeppLi I got the same result as you, except for the 'mask' after you modified the code. So at present the generated black image block is the mask. How do I get a visualization of the segmentation results like the one on the website?

Mrhard1999 commented 1 year ago

@SheeppLi Could you also share the amg.py code you changed? I want to test it. Thank you.

longyangqi commented 1 year ago

Saving with plt (imshow and savefig) is still slow when there are many test images.
Could you please provide an example of saving with cv2 instead of plt?

YL-yyy commented 1 year ago

Batch image processing: https://blog.csdn.net/java_pythons/article/details/130132554?spm=1001.2014.3001.5501

GewelsJI commented 1 year ago

There is also an issue page focused on speeding up the visualization.

akashAD98 commented 1 year ago

@taatuut @SheeppLi can you share the full script here? Can we get the segmented images and their annotation file?

SheeppLi commented 1 year ago

I don't know exactly how to reproduce the website result, but you can just add the function definitions to amg.py; it does work.

SheeppLi commented 1 year ago

I am still working on the annotation file. The segmentation image can be generated, but it only shows the segmented image without annotations.
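For the annotation file, one option is a COCO-style JSON in which each mask is stored as an uncompressed run-length encoding (RLE): pixels read in column-major order, with counts alternating background/foreground and starting with a background run. The sketch assumes the masks come from SamAutomaticMaskGenerator.generate in its default binary-mask mode, with 'bbox' in XYWH format. Both helper names and the single 'object' category are hypothetical choices, since SAM masks carry no class labels. pycocotools should be able to read this uncompressed form, or pycocotools.mask.encode can produce compressed RLE instead.

```python
import json

import numpy as np


def mask_to_uncompressed_rle(mask):
    """Encode a binary HxW mask as COCO-style uncompressed RLE."""
    h, w = mask.shape
    flat = np.asarray(mask, dtype=bool).flatten(order="F")  # column-major
    # Positions where the value changes, plus the start and end of the array.
    change = np.flatnonzero(flat[1:] != flat[:-1]) + 1
    bounds = np.concatenate(([0], change, [flat.size]))
    counts = np.diff(bounds).tolist()
    if flat.size and flat[0]:
        counts = [0] + counts  # the first run must be background
    return {"size": [h, w], "counts": counts}


def write_coco_annotations(anns, image_id, file_name, height, width, path):
    """Write SAM masks for one image as a COCO-style JSON (hypothetical helper)."""
    coco = {
        "images": [{"id": image_id, "file_name": file_name,
                    "height": height, "width": width}],
        "annotations": [],
        "categories": [{"id": 1, "name": "object"}],  # SAM is class-agnostic
    }
    for i, ann in enumerate(anns, start=1):
        coco["annotations"].append({
            "id": i,
            "image_id": image_id,
            "category_id": 1,
            "segmentation": mask_to_uncompressed_rle(ann["segmentation"]),
            "area": int(ann["area"]),
            "bbox": [float(v) for v in ann["bbox"]],  # XYWH, as SAM returns it
            "iscrowd": 0,  # individual object masks, not crowd regions
        })
    with open(path, "w") as f:
        json.dump(coco, f)
    return coco
```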

akashAD98 commented 1 year ago

@SheeppLi I'm working on it at https://github.com/akashAD98/YOLOV8_SAM. I'm facing some issues, but about half of the code is ready.

akashAD98 commented 1 year ago

@taatuut Thanks a lot. Can you upload the code as a .py file? That would be really great.

taatuut commented 1 year ago

ez_amg.py.txt

Attaching the Python code as a file with a .txt extension.

akashAD98 commented 1 year ago

https://github.com/facebookresearch/segment-anything/issues/215 Thanks. I want to save the mask results in COCO/YOLO/JSON format. I have done some of the work, but I'm getting an error that I'm trying to figure out.

pinksloyd commented 1 year ago

Is there no "def main" in amg.py anymore?

LukeAI commented 1 year ago

Here is complete working code (with indentation and everything!) adapted from what @taatuut provided. It should save somebody some time...

#!/usr/bin/env python
from __future__ import annotations
import os
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import cv2
import numpy as np
import matplotlib.pyplot as plt

def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    for ann in sorted_anns:
        m = ann['segmentation']
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        for i in range(3):
            img[:,:,i] = color_mask[i]
        ax.imshow(np.dstack((img, m*0.35)))

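def write_masks_to_png(masks: List[Dict[str, Any]], image, path: str) -> None:
    plt.figure(figsize=(20,20))
    plt.imshow(image)
    show_anns(masks)
    plt.axis('off')
    #plt.show()
    os.makedirs(path, exist_ok=True)  # savefig fails if the output folder is missing
    filename = f"masks.png"
    plt.savefig(os.path.join(path, filename))
    plt.close()  # free the figure when processing many images
    return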

# sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
sam.to(device="cuda")  # needs a CUDA GPU; use device="cpu" otherwise
mask_generator = SamAutomaticMaskGenerator(sam, points_per_batch=16)

image = cv2.imread('cam_front_top_centre/1686044607746820638.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(image)
write_masks_to_png(masks, image, "segmented")
LukeAI commented 1 year ago

I've created a gist showing how to process a directory of images with SAM and save visualisations of the generated masks. The advantage is that it avoids the random white border and resizing that matplotlib insists on. https://gist.github.com/LukeAI/6af4984c79a7534c9c1330958545367c