facebookresearch / segment-anything

The repository provides code for running inference with the Segment Anything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0

Improve Segmentation using Bounding Boxes #674

Closed foreignsand closed 7 months ago

foreignsand commented 9 months ago

Is there a way to improve SamPredictor segmentation when using bounding boxes?

I have something I want to segment (fungal colonies that are growing into each other in a petri dish), and the predictive model doesn't do the greatest job of segmenting them:

[image: 20231212_085406]

This is relatively easy to do with a petri dish of separated colonies like this:

[image: 20231031_ERG24_C5_8]

And the following script successfully segments the above petri dish with separated colonies:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 25 20:45:01 2024

@author: lemurbear
"""

# import modules
import cv2
import numpy as np
import torch
import torchvision
print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("CUDA is available:", torch.cuda.is_available())
from segment_anything import sam_model_registry, SamPredictor
import supervision as sv

# use cuda if available
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# import sam model
sam_checkpoint = "/Users/lemurbear/Downloads/sam_vit_h_4b8939.pth"
model_type = "vit_h"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

mask_predictor = SamPredictor(sam)

# import image and set image with mask_predictor
predictive_img_path = "/Users/lemurbear/Downloads/20231031_ERG24_C5_8.jpg"
image_bgr = cv2.imread(predictive_img_path)
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

mask_predictor.set_image(image_rgb)

# define bounding boxes
box_01 = {'x': 370, 'y': 210, 'width': 175, 'height': 160, 'label': ''}
box_02 = {'x': 370, 'y': 317, 'width': 175, 'height': 160, 'label': ''}
box_03 = {'x': 370, 'y': 480, 'width': 175, 'height': 160, 'label': ''}
box_04 = {'x': 370, 'y': 560, 'width': 175, 'height': 160, 'label': ''}
box_05 = {'x': 370, 'y': 670, 'width': 175, 'height': 160, 'label': ''}
box_06 = {'x': 370, 'y': 811, 'width': 175, 'height': 160, 'label': ''}
box_07 = {'x': 370, 'y': 940, 'width': 175, 'height': 160, 'label': ''}
box_08 = {'x': 370, 'y': 1070, 'width': 175, 'height': 160, 'label': ''}
box_dict = {'box_01': box_01, 'box_02': box_02, 'box_03': box_03, 
            'box_04': box_04, 'box_05': box_05, 'box_06': box_06, 
            'box_07': box_07, 'box_08': box_08} 

# assign bounding boxes to 'boxes' array
box_num = 0
boxes = {}
for this_box in box_dict:
    box_array = np.array([
        box_dict[this_box]['x'],
        box_dict[this_box]['y'],
        box_dict[this_box]['x'] + box_dict[this_box]['width'],
        box_dict[this_box]['y'] + box_dict[this_box]['height']
    ])

    boxes[box_num] = box_array
    box_num = box_num + 1
    print('colony_{0:02d}: {1}'.format(box_num, box_dict[this_box]))

print(boxes)

box_annotator = sv.BoxAnnotator(color=sv.Color.red())
mask_annotator = sv.MaskAnnotator(color=sv.Color.red(), color_lookup=sv.ColorLookup.INDEX)

# run prediction, annotation, and plotting for each box
box_num = 1
petri_areas = {}
for this_box in boxes:
    box_name = 'box_{0:02d}'.format(box_num)
    mask_name = 'masks_{0:02d}'.format(box_num)
    score_name = 'scores_{0:02d}'.format(box_num)
    logit_name = 'logit_{0:02d}'.format(box_num)
    image_name = 'segmented_image_{0:02d}.jpg'.format(box_num)
    print(box_name, mask_name, score_name, logit_name, boxes[this_box])
    box_num = box_num + 1
    masks_this, scores_this, logits_this = mask_predictor.predict(
        box=boxes[this_box],
        multimask_output=True
    )

    detections_this = sv.Detections(
        xyxy=sv.mask_to_xyxy(masks=masks_this),
        mask=masks_this
    )
    detections_this = detections_this[detections_this.area == np.max(detections_this.area)]
    area_this = str(round((detections_this.area[0] / (28.346*5)**2), 3))
    print("Colony area: ", area_this)

    source_image = box_annotator.annotate(scene=image_bgr.copy(), 
                                          detections=detections_this, skip_label=True)
    segmented_image = mask_annotator.annotate(scene=image_bgr.copy(), detections=detections_this)

    sv.plot_images_grid(
        images=[source_image, segmented_image],
        grid_size=(1, 2),
        titles=['source image', 'segmented image']
    )

    cv2.imwrite(image_name, segmented_image)  # annotated on the BGR copy, so no color conversion needed

    sv.plot_images_grid(
        images=masks_this,
        grid_size=(1, 4),
        size=(16, 4)
    )

[images: segmented_image_01 through segmented_image_08]

If I change the box locations, it does an okay job with the first petri dish, but it is still patchy.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 25 20:45:01 2024

@author: lemurbear
"""
import cv2
import numpy as np
import torch
import torchvision
print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("CUDA is available:", torch.cuda.is_available())
from segment_anything import sam_model_registry, SamPredictor
import supervision as sv

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

sam_checkpoint = "/Users/lemurbear/Downloads/sam_vit_h_4b8939.pth"
model_type = "vit_h"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictive_img_path = "/Users/lemurbear/Downloads/20231212_085406.jpg"

mask_predictor = SamPredictor(sam)

image_bgr = cv2.imread(predictive_img_path)
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

mask_predictor.set_image(image_rgb)

############ define bounding boxes ##############

box_01 = {'x': 360, 'y': 200, 'width': 230, 'height': 170, 'label': ''}
box_02 = {'x': 370, 'y': 317, 'width': 185, 'height': 160, 'label': ''}
box_03 = {'x': 340, 'y': 480, 'width': 230, 'height': 160, 'label': ''}
box_04 = {'x': 330, 'y': 570, 'width': 220, 'height': 160, 'label': ''}
box_05 = {'x': 330, 'y': 670, 'width': 210, 'height': 160, 'label': ''}
box_06 = {'x': 310, 'y': 811, 'width': 270, 'height': 160, 'label': ''}
box_07 = {'x': 300, 'y': 950, 'width': 280, 'height': 160, 'label': ''}
box_08 = {'x': 325, 'y': 1075, 'width': 220, 'height': 200, 'label': ''}
box_dict = {'box_01': box_01, 'box_02': box_02, 'box_03': box_03, 
            'box_04': box_04, 'box_05': box_05, 'box_06': box_06, 
            'box_07': box_07, 'box_08': box_08} 

# assign bounding boxes to 'boxes' array

box_num = 0
boxes = {}
for this_box in box_dict:
    box_array = np.array([
        box_dict[this_box]['x'],
        box_dict[this_box]['y'],
        box_dict[this_box]['x'] + box_dict[this_box]['width'],
        box_dict[this_box]['y'] + box_dict[this_box]['height']
    ])

    boxes[box_num] = box_array

    box_num = box_num + 1

    print('colony_{0:02d}: {1}'.format(box_num, box_dict[this_box]))

print(boxes)

box_annotator = sv.BoxAnnotator(color=sv.Color.red())
mask_annotator = sv.MaskAnnotator(color=sv.Color.red(), color_lookup=sv.ColorLookup.INDEX)

box_num = 1
petri_areas = {}
for this_box in boxes:
    box_name = 'box_{0:02d}'.format(box_num)
    mask_name = 'masks_{0:02d}'.format(box_num)
    score_name = 'scores_{0:02d}'.format(box_num)
    logit_name = 'logit_{0:02d}'.format(box_num)
    image_name = 'segmented_image_{0:02d}.jpg'.format(box_num)
    print(box_name, mask_name, score_name, logit_name, boxes[this_box])
    box_num = box_num + 1
    masks_this, scores_this, logits_this = mask_predictor.predict(
        box=boxes[this_box],
        multimask_output=True
    )

    detections_this = sv.Detections(
        xyxy=sv.mask_to_xyxy(masks=masks_this),
        mask=masks_this
    )
    detections_this = detections_this[detections_this.area == np.max(detections_this.area)]
    area_this = str(round((detections_this.area[0] / (28.346*5)**2), 3))
    print("Colony area: ", area_this)

    petri_areas[box_name] = area_this  # store each colony's area

    source_image = box_annotator.annotate(scene=image_bgr.copy(), 
                                          detections=detections_this, skip_label=True)
    segmented_image = mask_annotator.annotate(scene=image_bgr.copy(), detections=detections_this)

    sv.plot_images_grid(
        images=[source_image, segmented_image],
        grid_size=(1, 2),
        titles=['source image', 'segmented image']
    )

    cv2.imwrite(image_name, segmented_image)  # annotated on the BGR copy, so no color conversion needed

    sv.plot_images_grid(
        images=masks_this,
        grid_size=(1, 4),
        size=(16, 4)
    )

[images: segmented_image_01 through segmented_image_08]

I'm not sure if there are settings that could improve the quality of the segmentation and would love to hear suggestions.

Many thanks, Emily

heyoeyo commented 9 months ago

There are a couple options that might help.

1. Some simple post-processing might be good enough if you just want cleaner-looking masks. In particular, morphological filtering (specifically 'closing') can help fill in the gaps in the masks. You can do this fairly easily using opencv (cv2); in your case, I think you could do something like the following (just after you create the masks_this result):

    # Set up morphological filter (change ksize to fill in bigger holes)
    ksize = (15, 15)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, ksize)

    # Apply filter to each mask (as uint8) and convert back to boolean
    # to match the original data type
    cleaner_masks_this = []
    for mask in masks_this:
        mask_uint8 = np.uint8(mask) * 255
        new_bool_mask = cv2.morphologyEx(mask_uint8, cv2.MORPH_CLOSE, kernel) > 127
        cleaner_masks_this.append(new_bool_mask)


2. You can try using the mask input to the predictor. This seems very finicky, but there are posts (see #347 or #169) that suggest you can get better results by iteratively feeding the mask predicted by SAM back into it and re-predicting. There's a rough sketch of this after the list below.

3. If possible, you can try using point prompts instead of (or alongside) the box prompt (i.e. when calling `mask_predictor.predict(...)`). Obviously that isn't improving things using boxes alone, but in case it's an option, it might help. From what I've seen, using both positive and negative points tends to help pick up more complex shapes; there's a sketch of this after the list as well.
Here's an example of the result (using the web demo) on what seemed like the toughest segmentation:
![multipoint_example](https://github.com/facebookresearch/segment-anything/assets/32405350/67398006-ac74-473a-95d8-adf9bf5556f1)

4. And lastly, if you're planning to do a lot of this kind of segmentation, then using a variant of SAM that is fine-tuned for these kinds of images might help. I don't know much about this kind of stuff, so I can't be of much help, but a quick search returned [CellSam](https://gist.github.com/sushmanthreddy/618e642d2adfc6b58b6b5df0e9dbd3cd), which seems vaguely related and might be useful. Fine-tuning your own variant could be a lot of work, so it's only worthwhile if you're going to be working with a lot of these images.
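
To expand on option 2, here's a minimal sketch of the mask-feedback loop, reusing `mask_predictor`, `boxes`, and `np` from your script (the iteration count is an arbitrary choice). The important detail is that the `mask_input` argument expects the low-resolution logits returned by a previous `predict` call, not the binary mask:

    # Iterative mask feedback (option 2): re-run prediction with the
    # previous result's logits as a mask prompt. num_iters is arbitrary.
    num_iters = 3
    for this_box in boxes:
        masks_this, scores_this, logits_this = mask_predictor.predict(
            box=boxes[this_box],
            multimask_output=True
        )
        for _ in range(num_iters):
            # mask_input expects a 1x256x256 array of logits, which is
            # exactly what predict returns in logits_this
            best_idx = np.argmax(scores_this)
            masks_this, scores_this, logits_this = mask_predictor.predict(
                box=boxes[this_box],
                mask_input=logits_this[best_idx][None, :, :],
                multimask_output=False
            )

As mentioned, this can be finicky, so it's worth checking the scores and masks at each iteration rather than assuming they improve.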
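
And for option 3, here's a sketch of mixing positive/negative point prompts with one of your boxes. The coordinates below are placeholders; you'd replace them with real pixel positions inside the target colony (label 1) and on neighboring colonies or background (label 0):

    # Point prompts (option 3): coordinates are placeholders, replace
    # with real (x, y) pixel positions for your image
    point_coords = np.array([
        [450, 290],   # positive point: inside the target colony
        [450, 150],   # negative point: neighboring colony
        [560, 290],   # negative point: background / dish
    ])
    point_labels = np.array([1, 0, 0])  # 1 = foreground, 0 = background

    masks_this, scores_this, logits_this = mask_predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        box=boxes[0],
        multimask_output=True
    )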
foreignsand commented 9 months ago

Thank you so much! This is really thorough and helpful!

I've been using positive and negative points to help refine the mask as suggested in option 3. I'll probably need to do something more like option 4 in the long run because I will be doing this quite a bit, but for now this is working better!

Again, thanks so much!

Cheers, Emily