bowang-lab / MedSAM

Segment Anything in Medical Images
https://www.nature.com/articles/s41467-024-44824-z
Apache License 2.0

Provide full configuration for DRIVE vessel segmentation dataset #41

Closed ngctnnnn closed 1 year ago

ngctnnnn commented 1 year ago

Hi, first of all, thanks for your work!

However, when trying to reproduce the DRIVE vessel segmentation results, I could not match the numbers in Table 2 of your paper. Could you kindly provide more details on how to reach a DSC of around 66? My best results are only around 60.

JunMa11 commented 1 year ago

Hi @ngctnnnn ,

Thanks for your interest.

  1. How did you run the experiments? I just tested the model, re-computed the metrics, and confirmed that the reported DSC is correct. The segmentation results and the trained model are available on the README page.
  2. We don't have a DRIVE-specific configuration. All of the training datasets were merged together for training, and you can find the training data on the README page as well.

ngctnnnn commented 1 year ago

Please correct me if I'm wrong: when I run this code on the 20 test images in 20 batches (1 image per batch), the output Dice score is 3.5, although it should be at most 1. Does this problem also occur for you, or is it just me? Thanks.

        sam_model = sam_model.to("cuda:3")
        mask_segmentation, iou_predictions = sam_model.mask_decoder(
            image_embeddings=image_embedding.to("cuda:3"), # (B, 256, 64, 64)
            image_pe=sam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
            dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
            multimask_output=False,
        )
        medsam_seg_prob = torch.sigmoid(mask_segmentation)
        # convert soft mask to hard mask
        medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
        medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)
        all_masks_segmentation.append(medsam_seg)
        all_gt2d.append(gt2D)

    all_masks_segmentation = np.stack(all_masks_segmentation, axis=0)
    all_gt2d = np.stack(all_gt2d, axis=0)
    dice_score = compute_dice_coefficient(all_gt2d>0, all_masks_segmentation>0)
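
A Dice coefficient is bounded by 1 by definition, so a value of 3.5 suggests that the two arrays passed to the metric do not line up. One possible cause (an assumption, not confirmed anywhere in this thread) is a shape mismatch between the stacked ground truth and the stacked predictions: if the metric uses an element-wise intersection, NumPy can silently broadcast the arrays and inflate the overlap. The sketch below is a minimal stand-alone check, not the repository's `compute_dice_coefficient`; the names `dice_coefficient` and `mean_dice` are hypothetical. It rejects mismatched shapes and averages the score per image. Whether the paper reports a per-image mean is also an assumption.

```python
import numpy as np


def dice_coefficient(gt, pred):
    """Dice = 2|A∩B| / (|A| + |B|) for two binary masks of identical shape."""
    gt = np.asarray(gt) > 0
    pred = np.asarray(pred) > 0
    # Refuse mismatched shapes: NumPy would otherwise broadcast them silently,
    # which can inflate the intersection and push the score above 1.
    if gt.shape != pred.shape:
        raise ValueError(f"shape mismatch: gt {gt.shape} vs pred {pred.shape}")
    volume_sum = gt.sum() + pred.sum()
    if volume_sum == 0:
        return float("nan")  # both masks empty: Dice is undefined
    return 2.0 * np.logical_and(gt, pred).sum() / volume_sum


def mean_dice(gts, preds):
    """Average the per-image Dice over paired ground-truth/prediction masks."""
    scores = [dice_coefficient(g, p) for g, p in zip(gts, preds)]
    return float(np.nanmean(scores))
```

With the lists from the snippet above, this would be called as `mean_dice(all_gt2d, all_masks_segmentation)` before stacking. Printing `all_gt2d[0].shape` and `all_masks_segmentation[0].shape` first is a quick way to see whether the two differ.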
JunMa11 commented 1 year ago

Hi @ngctnnnn ,

The testing images and code are also available on the README page. I didn't run into this issue.