CASIA-IVA-Lab / FastSAM

Fast Segment Anything
GNU Affero General Public License v3.0

Working with FastSAM results #208

Open med-tim opened 6 months ago

med-tim commented 6 months ago

Does anyone have advice on how to work with the results to extract the points of the generated mask? I have searched the docs, the comments, and the Python files but cannot find a method that returns the mask points themselves, which I could then use for other things.

In the code below, ann is returned as a list object that I cannot extract mask information from.
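If the masks are already available as a NumPy array (a single (H, W) mask, or one slice of an (N, H, W) stack; a PyTorch tensor can be converted with .cpu().numpy()), the points can be read directly off the array. This is a minimal sketch assuming such an array. mask_pixels and mask_boundary are hypothetical helper names, not part of FastSAM or Ultralytics.

```python
import numpy as np


def mask_pixels(mask: np.ndarray) -> np.ndarray:
    """Return an (N, 2) array of (x, y) coordinates of every foreground pixel."""
    ys, xs = np.nonzero(mask)  # row (y) and column (x) indices of nonzero pixels
    return np.stack([xs, ys], axis=1)


def mask_boundary(mask: np.ndarray) -> np.ndarray:
    """Return (x, y) coordinates of foreground pixels that have at least one
    4-connected neighbour outside the mask (or lie on the image border)."""
    m = mask.astype(bool)
    # Pad with background so pixels on the image edge count as boundary.
    p = np.pad(m, 1, constant_values=False)
    interior = (
        p[1:-1, 1:-1]   # the pixel itself
        & p[:-2, 1:-1]  # neighbour above
        & p[2:, 1:-1]   # neighbour below
        & p[1:-1, :-2]  # neighbour to the left
        & p[1:-1, 2:]   # neighbour to the right
    )
    return mask_pixels(m & ~interior)
```

For example, a 4x4 square of ones gives 16 pixels and 12 boundary points. The boundary points come back in row-major order rather than as an ordered outline; if you need an ordered polygon, OpenCV's cv2.findContours on the uint8 mask is the usual tool.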

Here is my current implementation (not shown: earlier I calculate the x1, y1, x2, y2 coordinates of the image region I want a segment mask for):

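    from ultralytics import FastSAM
    from ultralytics.models.fastsam import FastSAMPrompt  # Ultralytics wrapper; newer releases may have removed it

    model = FastSAM('FastSAM-x.pt')

    # source is the input image as a NumPy array; x1, y1, x2, y2 were computed earlier
    imgsz_frame = source.shape[0]  # image height, used as the inference size
    everything_results = model(source, device='cpu', retina_masks=True, imgsz=imgsz_frame, conf=0.1, iou=0.9)

    prompt_process = FastSAMPrompt(source, everything_results, device='cpu')

    # Keep the segment that best matches the box [x1, y1, x2, y2]
    ann = prompt_process.box_prompt(bbox=[x1, y1, x2, y2])

    prompt_process.plot(annotations=ann, output='./')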
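Conceptually, a box prompt selects, among the masks from the "everything" pass, the one that best matches the box. The NumPy sketch below illustrates that idea with an IoU score (mask pixels inside the box divided by the union of mask and box). It assumes the candidates are an (N, H, W) array and the box is [x1, y1, x2, y2] in pixels with x1 < x2 and y1 < y2. best_mask_for_box is a hypothetical helper, and its scoring is an illustration rather than a copy of FastSAM's box_prompt.

```python
from typing import Sequence

import numpy as np


def best_mask_for_box(masks: np.ndarray, box: Sequence[float]) -> int:
    """Index of the (N, H, W) mask with the highest IoU against box [x1, y1, x2, y2]."""
    masks = np.asarray(masks).astype(bool)
    _, h, w = masks.shape
    # Clip the box to the image so its area matches the region sliced below.
    x1, x2 = np.clip([box[0], box[2]], 0, w).astype(int)
    y1, y2 = np.clip([box[1], box[3]], 0, h).astype(int)
    box_area = (x2 - x1) * (y2 - y1)
    # Intersection: mask pixels that fall inside the box.
    inter = masks[:, y1:y2, x1:x2].sum(axis=(1, 2))
    # Union: box area + full mask area - intersection.
    union = box_area + masks.sum(axis=(1, 2)) - inter
    iou = inter / np.maximum(union, 1)  # guard against dividing by zero
    return int(np.argmax(iou))
```

For example, assuming the Ultralytics Masks.data tensor holds the candidate masks, masks = everything_results[0].masks.data.cpu().numpy() followed by best_mask_for_box(masks, [x1, y1, x2, y2]) returns an index that can also be used with masks.xy, since both list the objects in the same order.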
med-tim commented 6 months ago

Found a solution:

You can extract the masks from everything_results (or from ann) by iterating over them: each item is an ultralytics.engine.results.Results object, and its masks attribute holds the masks of the detected objects. The Results class is documented here: https://docs.ultralytics.com/reference/engine/results/#ultralytics.engine.results.Results

Example code I got working:
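The Results reference describes masks.xy as a list of polygons in pixel coordinates, one (N, 2) array of (x, y) points per object, and those points can be used directly without rasterizing them. As a sketch, the hypothetical polygon_stats helper below computes a polygon's area with the shoelace formula, plus its centroid and bounding box. It assumes a simple (non-self-intersecting) polygon with at least three points.

```python
import numpy as np


def polygon_stats(points: np.ndarray) -> dict:
    """Area, centroid and bounding box of a simple polygon.

    points: (N, 2) array of (x, y) vertices, e.g. one entry of masks.xy.
    """
    pts = np.asarray(points, dtype=np.float64)
    if pts.ndim != 2 or pts.shape[0] < 3:
        raise ValueError("need an (N, 2) array with at least 3 points")
    x, y = pts[:, 0], pts[:, 1]
    # Next vertex for each vertex, wrapping the last one back to the first.
    x_next, y_next = np.roll(x, -1), np.roll(y, -1)
    cross = x * y_next - x_next * y
    signed_area = cross.sum() / 2.0  # shoelace formula
    if signed_area == 0:
        raise ValueError("degenerate polygon (zero area)")
    # Polygon centroid; the sign of signed_area cancels out for either winding.
    cx = ((x + x_next) * cross).sum() / (6.0 * signed_area)
    cy = ((y + y_next) * cross).sum() / (6.0 * signed_area)
    return {
        "area": float(abs(signed_area)),
        "centroid": (float(cx), float(cy)),
        "bbox": (float(x.min()), float(y.min()), float(x.max()), float(y.max())),
    }
```

For example, a 4x2 rectangle with corners (0, 0) and (4, 2) gives an area of 8.0 and a centroid of (2.0, 1.0) regardless of whether its vertices run clockwise or counter-clockwise.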

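    import cv2
    import matplotlib.pyplot as plt
    import numpy as np

    for e in everything_results:  # iterating over ann works the same way
        if e.masks is None:  # nothing was segmented in this result
            continue
        # masks.xy: one (N, 2) array of polygon (x, y) pixel coordinates per object
        masks_FastSAM = e.masks.xy[1]  # [1] picks the second object; [0] is the first
        binary_mask = np.zeros((source.shape[0], source.shape[1]), dtype=np.uint8)
        contours = [np.array(masks_FastSAM, dtype=np.int32)]

        cv2.fillPoly(binary_mask, contours, 255)  # 255 is the fill color in grayscale (white)
        plt.figure(figsize=(10, 10))
        plt.imshow(binary_mask, cmap='gray')  # Display the mask in grayscale
        plt.axis('off')  # Turn off axis numbers and ticks
        plt.show()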
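If OpenCV is not available, a polygon from masks.xy can also be rasterized with NumPy alone. The sketch below is a swapped-in technique, not what cv2.fillPoly does internally: it applies the even-odd rule at each pixel centre, so its result can differ from cv2.fillPoly by about one pixel along the edges. polygon_to_mask is a hypothetical helper name.

```python
import numpy as np


def polygon_to_mask(points: np.ndarray, height: int, width: int) -> np.ndarray:
    """Rasterize a polygon into a uint8 mask (255 inside, 0 outside).

    A pixel is inside if its centre (x + 0.5, y + 0.5) passes the even-odd test.
    """
    pts = np.asarray(points, dtype=np.float64)
    # Pixel-centre coordinates for the whole image grid.
    ys, xs = np.mgrid[0:height, 0:width]
    px, py = xs + 0.5, ys + 0.5
    inside = np.zeros((height, width), dtype=bool)
    ax_all, ay_all = pts[:, 0], pts[:, 1]
    bx_all, by_all = np.roll(ax_all, -1), np.roll(ay_all, -1)  # edge end points
    for ax, ay, bx, by in zip(ax_all, ay_all, bx_all, by_all):
        # Does this edge cross the horizontal line through each pixel centre?
        crosses = (ay > py) != (by > py)
        # x where the edge meets that line; only used where crosses is True,
        # so the division by zero for horizontal edges is harmless.
        with np.errstate(divide="ignore", invalid="ignore"):
            x_cross = ax + (py - ay) * (bx - ax) / (by - ay)
        # Toggle pixels whose rightward ray crosses this edge.
        inside ^= crosses & (px < x_cross)
    return inside.astype(np.uint8) * 255
```

For instance, binary_mask = polygon_to_mask(e.masks.xy[1], source.shape[0], source.shape[1]) could stand in for the np.zeros and cv2.fillPoly pair above. The loop runs once per polygon edge over the full image grid, so it is slower than OpenCV for large images or dense outlines.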
baoyi-A commented 4 months ago

Thanks a lot!