facebookresearch / sam2

The repository provides code for running inference with the Meta Segment Anything Model 2 (SAM 2), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0

Promptless segmentation #166

Open 25benjaminli opened 2 months ago

25benjaminli commented 2 months ago

Hi all, I am trying to build a pipeline to train without prompts, using only the default sparse & dense embeddings plus the image embedding. For some reason, the resulting segmentation doesn't do well compared to the demos. Please note that this code is different from the inference code provided in the repository, because my ultimate intention is to train the model. Also, I am using the approach discussed in #138 to input images of flexible size.

import torch

def iteration(predictor, batch):
    # process WITHOUT any prompts

    batched_mode = batch["mask"].ndim == 4  # trying no batch first
    if batched_mode:
        # apply the SAM image encoder to the batch; is it automatically normalized?
        predictor.set_image_batch(batch["image"])
    else:
        # apply the SAM image encoder to the image; is it automatically normalized?
        predictor.set_image(batch["image"])

    # empty prompt: no points, boxes, or masks are passed to the prompt encoder
    sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
        points=None, boxes=None, masks=None
    )

    low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
        image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
        image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=True,
        repeat_image=batched_mode,
        high_res_features=None,
    )

    # TODO: resolve the orig_hw handling
    prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

    print("prd_masks", prd_masks.shape, low_res_masks.shape)

    # threshold the first of the three multimask outputs
    prd_mask = torch.sigmoid(prd_masks[:, 0]) > 0.5
    # prd_mask = torch.sigmoid(prd_masks) > 0.5

    print("final prd masks shape", prd_mask.shape)

    return prd_mask  # (channel, img_size, img_size)
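One detail worth double-checking in the code above: with `multimask_output=True` the decoder returns three candidate masks per image, and always taking index 0 may not select the best one. A minimal sketch (tensor shapes are assumptions for illustration, not taken from the repo) of selecting by the predicted scores instead:

```python
import torch

# hypothetical stand-ins for the decoder outputs:
# prd_masks: (batch, num_masks, H, W), prd_scores: (batch, num_masks)
prd_masks = torch.randn(1, 3, 256, 256)
prd_scores = torch.tensor([[0.2, 0.9, 0.5]])

# pick the highest-scoring mask for each image in the batch
best = prd_scores.argmax(dim=1)                                # (batch,)
best_masks = prd_masks[torch.arange(prd_masks.size(0)), best]  # (batch, H, W)

# threshold as in the original code
binary = torch.sigmoid(best_masks) > 0.5
```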

Please see this screenshot of the resulting segmentation mask overlaid on top of the original image.

[Screenshot 2024-08-06 at 9:26:20 PM: segmentation mask overlaid on the original image]

Thanks for the help!

heyoeyo commented 2 months ago

If this is the result without any additional training, I think that's normal. Without prompts, the models generate very scattered/non-specific outputs.

In case it's of any use, a user on the SAM v1 issues had a blog post explaining how they set up a promptless version of the model. It's a different approach (training a custom decoder), but may be a useful reference.
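For reference, the scattered promptless outputs are also why SAM's official "no user prompt" pipeline (the automatic mask generator) instead tiles the image with a regular grid of point prompts and filters the resulting masks. A minimal sketch of building such a grid in normalized coordinates (the 1024-px scale factor is an assumption about the input resolution):

```python
import torch

def build_point_grid(n_per_side: int) -> torch.Tensor:
    """Evenly spaced (x, y) points in [0, 1], shape (n_per_side**2, 2)."""
    offset = 1.0 / (2 * n_per_side)  # half a cell, so points sit at cell centers
    coords = torch.linspace(offset, 1 - offset, n_per_side)
    ys, xs = torch.meshgrid(coords, coords, indexing="ij")
    return torch.stack([xs.reshape(-1), ys.reshape(-1)], dim=-1)

grid = build_point_grid(4)   # 16 points covering the image
points_px = grid * 1024      # scale to pixel coords for a 1024x1024 input
```

Each point is then run through the prompt encoder as a foreground click, which gives the decoder a concrete target instead of the ambiguous empty prompt.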