facebookresearch / sam2

The repository provides code for running inference with the Meta Segment Anything Model 2 (SAM 2), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0

Using SAM2 detects few trees #314

Open dnromero opened 2 weeks ago

dnromero commented 2 weeks ago

Hi everyone

I recently started evaluating SAM2 for tree detection. I'd like to clarify that I am new to this whole topic and I'm trying to learn how to use SAM2 to detect trees.

I have tried the following code:

import os

# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image

# select the device for computation
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

if device.type == "cuda":
    # use bfloat16 for the entire notebook
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )

np.random.seed(3)

def show_anns(anns, borders=True):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.5]])
        img[m] = color_mask
        if borders:
            import cv2
            contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
            # Try to smooth contours
            contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
            cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=1)

    ax.imshow(img)

image = Image.open(r"C:\Users\Lenovo\Desktop\Daniel\AIconteo\cerro2_corte.tif")
image = np.array(image.convert("RGB"))
plt.figure(figsize=(20, 20))
plt.imshow(image)
plt.axis('off')
plt.show()

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

sam2_checkpoint = r"C:\Users\Lenovo\segment-anything-2\checkpoints\sam2_hiera_base_plus.pt"
model_cfg = "sam2_hiera_b+.yaml"

sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)

mask_generator = SAM2AutomaticMaskGenerator(sam2)
masks = mask_generator.generate(image)
print(len(masks))
print(masks[0].keys())
# output:
# 121
# dict_keys(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'])

plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()

[images: example 1, example 2]

I have attached some example images. I would be very grateful if anyone could help. :)

heyoeyo commented 2 weeks ago

For this sort of example, SAM might not be a great option. With so many small, similar-looking objects, it might be worth trying simpler methods like correlation (see template matching from scikit-image). Even plain thresholding might be a good start (for this specific image at least).
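As a rough illustration of the thresholding idea, here is a minimal sketch on a synthetic grayscale array. It assumes the trees are darker than the surrounding ground, which may not hold for the images in this thread. The function name, the threshold of 128, and the min_area cutoff are made up for the example, and the flood fill is only there to keep the snippet dependency-light (a library labeling routine would normally do this step).

```python
from collections import deque

import numpy as np


def count_dark_blobs(gray, threshold, min_area=1):
    """Return the pixel areas of connected regions darker than `threshold`.

    Uses 4-connectivity and ignores regions smaller than `min_area`.
    """
    mask = gray < threshold
    visited = np.zeros(mask.shape, dtype=bool)
    height, width = mask.shape
    areas = []
    for y, x in zip(*np.nonzero(mask)):
        if visited[y, x]:
            continue
        # Breadth-first flood fill over one connected region
        queue = deque([(y, x)])
        visited[y, x] = True
        area = 0
        while queue:
            cy, cx = queue.popleft()
            area += 1
            for ny, nx in ((cy - 1, cx), (cy + 1, cx), (cy, cx - 1), (cy, cx + 1)):
                inside = 0 <= ny < height and 0 <= nx < width
                if inside and mask[ny, nx] and not visited[ny, nx]:
                    visited[ny, nx] = True
                    queue.append((ny, nx))
        if area >= min_area:
            areas.append(area)
    return areas


# Synthetic stand-in for an aerial image: bright ground (200), dark "trees" (50)
demo = np.full((20, 20), 200, dtype=np.uint8)
demo[2:5, 2:5] = 50      # 3x3 blob, area 9
demo[10:12, 10:14] = 50  # 2x4 blob, area 8
demo[16, 3] = 50         # single-pixel speck, area 1

blob_areas = count_dark_blobs(demo, threshold=128, min_area=4)
print(len(blob_areas), sorted(blob_areas))  # prints: 2 [8, 9]
```

On a real image the threshold would have to be picked from the data (for example from a histogram of pixel values), and touching tree crowns would merge into a single blob, so treat the count as a rough estimate.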

However, you can probably improve the SAM result by adjusting the default settings. Two that seem likely to help are points_per_side and crop_n_layers. The auto-masking works by generating a grid of single-point prompts, and the points_per_side setting controls how many points are used along each side of that grid. The default setting (32) is probably too low to account for so many objects, so it's worth increasing that (at least 64, maybe even higher?). The crop_n_layers setting is almost like a 'zoom-in' feature: it also runs segmentation on smaller cropped parts of the image, which should help catch some of the smaller trees. This setting slows things down a lot, so see if a setting of 1 helps before setting it any higher.
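To get a feel for why 32 points per side can miss small objects, here is a small sketch of how the prompt spacing changes with points_per_side. It assumes a uniform grid with one point at the centre of each cell, and it uses a hypothetical 2048 x 2048 px image, since the real image size isn't given here. The point_grid helper is written for this example and is not the library's own code.

```python
def point_grid(width, height, points_per_side):
    """Evenly spaced (x, y) prompt locations in pixels, one per grid cell centre."""
    step_x = width / points_per_side
    step_y = height / points_per_side
    return [
        ((col + 0.5) * step_x, (row + 0.5) * step_y)
        for row in range(points_per_side)
        for col in range(points_per_side)
    ]


# Hypothetical image size, chosen only for illustration
image_size = 2048
for n in (32, 64, 128):
    grid = point_grid(image_size, image_size, n)
    spacing = image_size / n
    print(f"points_per_side={n}: {len(grid)} prompts, one every {spacing:.0f} px")
# points_per_side=32: 1024 prompts, one every 64 px
# points_per_side=64: 4096 prompts, one every 32 px
# points_per_side=128: 16384 prompts, one every 16 px
```

A tree crown much smaller than the spacing can easily fall between prompts and never get a mask. The number of prompts grows with the square of points_per_side, so the runtime does too. With the setup from the original post, both settings would be passed as keyword arguments, e.g. SAM2AutomaticMaskGenerator(sam2, points_per_side=64, crop_n_layers=1).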

If you want to see what these options are doing, you can add the following code just after line 266 of the mask generator:

# Visualize point prompts used by the mask generator
import cv2
debug_img = cv2.cvtColor(cropped_im, cv2.COLOR_RGB2BGR)
for xy in points_for_image:
    pt_xy = xy.astype(np.int32).tolist()
    cv2.circle(debug_img, pt_xy, 2, (255,0,255), -1)
cv2.imshow("DebugPoints", debug_img)
cv2.waitKey(0)
cv2.destroyWindow("DebugPoints")

This will create a pop-up image showing the 'crop' that's being used along with the point prompts (it will also pause the masking, but you can press any key with the window open to resume).

One last thing that's probably helpful if you only want the trees is to ignore any overly large masks. You can filter them out using something like:

# Filter out large masks
max_area = 1000
masks = [m for m in masks if m["area"] < max_area]
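To see the filter in isolation, here is a small sketch on made-up mask records. Only the "area" key comes from the generator output; the area values, the helper name, and the extra min_area lower bound (to drop tiny specks) are assumptions added for illustration.

```python
def filter_masks_by_area(masks, max_area, min_area=0):
    """Keep masks whose pixel area satisfies min_area <= area < max_area."""
    return [m for m in masks if min_area <= m["area"] < max_area]


# Made-up records standing in for the mask generator's output
mock_masks = [
    {"area": 12},      # tiny speck, probably noise
    {"area": 340},     # plausible tree crown
    {"area": 780},     # plausible tree crown
    {"area": 250000},  # large background region
]

trees = filter_masks_by_area(mock_masks, max_area=1000, min_area=50)
print([m["area"] for m in trees])  # prints: [340, 780]
```

The right max_area depends on the image resolution, so it can help to print the sorted areas of the real masks first and pick a cutoff between the tree-sized masks and the large background regions.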