med-air / 3DSAM-adapter

Holistic Adaptation of SAM from 2D to 3D for Promptable Medical Image Segmentation

Some questions about the visualisation results. #24

Open ghost opened 10 months ago

ghost commented 10 months ago

Dear author,

Thank you for your excellent work.

I was just wondering: do you have any method to generate the visualisation figures in your paper? Is there any code for that?

yzhong22 commented 10 months ago

Hi there,

Here is the core code for overlaying masks onto the images and generating the visualization figures in our paper. Sorry that we didn't spend time organizing it; it should be usable with some minor modifications.

import numpy as np
from PIL import Image, ImageDraw

def load_img_or_seg(mode, case_idx, slice_idx, ROI, img_shape, border_c="yellow"):
    # Relies on the globals `data` (dataset name) and `get_array` (see below).
    org_img = get_array("img", case_idx, slice_idx)

    # Crop away the empty background, keeping a 20-pixel margin around the body.
    idx = np.where(org_img > 0)
    margin = 20
    org_img = org_img[max(idx[0].min() - margin, 0):min(idx[0].max() + margin, org_img.shape[0]),
                      max(idx[1].min() - margin, 0):min(idx[1].max() + margin, org_img.shape[1])]
    array = Image.fromarray((org_img * 255).astype(np.uint8)).convert("RGB").resize(img_shape)

    if mode != "img":
        # Load the mask, crop it identically, and colour it red (zero out G and B).
        seg = get_array(mode, case_idx, slice_idx)
        seg = seg[max(idx[0].min() - margin, 0):min(idx[0].max() + margin, seg.shape[0]),
                  max(idx[1].min() - margin, 0):min(idx[1].max() + margin, seg.shape[1])]
        seg = np.repeat(seg[:, :, None], repeats=3, axis=-1)
        seg[:, :, 1:] = 0
        seg = Image.fromarray((seg * 255).astype(np.uint8)).resize(img_shape)
    else:
        seg = Image.new(mode="RGBA", size=img_shape)

    # Overlay the mask, draw the ROI box, and paste a 2x-zoomed crop of the ROI
    # into a corner (top-left or bottom-left depending on the dataset).
    array = Image.blend(array.convert("RGBA"), seg.convert("RGBA"), 0.4)
    ImageDraw.Draw(array).rectangle([(ROI[0], ROI[1]), (ROI[0] + ROI[2], ROI[1] + ROI[3])],
                                    outline=border_c, width=5)
    array_crop = array.copy().crop((ROI[0], ROI[1], ROI[0] + ROI[2], ROI[1] + ROI[3]))

    if data in ["msd", "colon"]:
        array.paste(array_crop.resize((img_shape[0] // 2, img_shape[1] // 2)), (0, 0))
    else:
        array.paste(array_crop.resize((img_shape[0] // 2, img_shape[1] // 2)),
                    (0, img_shape[1] - img_shape[1] // 2))

    return array
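For readers trying to reproduce the figures, here is a minimal, self-contained sketch of the same overlay steps (blend a red mask, draw the ROI box, paste a 2x-zoomed crop), using synthetic arrays in place of get_array. All parameter values below (slice_idx is omitted, ROI, img_shape, border_c) are hypothetical, not the ones used in the paper:

```python
import numpy as np
from PIL import Image, ImageDraw

# Hypothetical parameter values (not from the paper):
img_shape = (512, 512)       # final figure size in pixels, (width, height)
ROI = (200, 180, 120, 120)   # (x, y, width, height) of the box to highlight
border_c = "yellow"          # colour of the ROI box outline

# Synthetic stand-ins for get_array("img", ...) and get_array("gt", ...).
org_img = np.zeros((256, 256), dtype=np.float32)
org_img[40:220, 40:220] = np.linspace(0, 1, 180)[None, :]
seg = np.zeros((256, 256), dtype=np.uint8)
seg[120:170, 120:170] = 1

# Image slice -> RGB PIL image at the target size.
base = Image.fromarray((org_img * 255).astype(np.uint8)).convert("RGB").resize(img_shape)

# Mask -> red overlay (keep only the R channel), then alpha-blend at 0.4.
seg_rgb = np.repeat(seg[:, :, None], 3, axis=-1)
seg_rgb[:, :, 1:] = 0
overlay = Image.fromarray((seg_rgb * 255).astype(np.uint8)).resize(img_shape)
blended = Image.blend(base.convert("RGBA"), overlay.convert("RGBA"), 0.4)

# Draw the ROI box and paste a 2x-zoomed crop of it into the top-left corner.
ImageDraw.Draw(blended).rectangle(
    [(ROI[0], ROI[1]), (ROI[0] + ROI[2], ROI[1] + ROI[3])],
    outline=border_c, width=5)
crop = blended.crop((ROI[0], ROI[1], ROI[0] + ROI[2], ROI[1] + ROI[3]))
blended.paste(crop.resize((img_shape[0] // 2, img_shape[1] // 2)), (0, 0))
blended.convert("RGB").save("vis_example.png")
```

With real data, the synthetic arrays would be replaced by get_array calls and slice_idx would pick the 2D slice to visualize.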
ghost commented 10 months ago

Thank you for your quick response.

But could you please also provide the code for get_array?

yzhong22 commented 10 months ago

Here you go with the function. It basically loads the .nii.gz files of images and masks and converts them to arrays.

import os

import cv2
import numpy as np
import SimpleITK as sitk

def get_array(mode, case_idx, slice_idx):
    # Relies on globals: data, split_test, clip_vis, and the *_root directories.
    if mode == "img":
        array = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(img_root, f"{split_test[case_idx]}_0000.nii.gz")))
        if data in ["msd", "lits", "colon"]:
            array = array[slice_idx]
        else:
            array = array[:, :, slice_idx]
        # Clip intensities to the visualization window and rescale to [0, 1].
        array = (np.clip(array, a_min=clip_vis[0], a_max=clip_vis[1]) - clip_vis[0]) / (clip_vis[1] - clip_vis[0])
    elif mode == "gt":
        array = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(gt_root, f"{split_test[case_idx]}.nii.gz")))
        if data in ["msd", "lits", "colon"]:
            array = array[slice_idx]
        else:
            array = array[:, :, slice_idx]
    elif mode == "ours":
        array = np.load(os.path.join(ours_root, data, f"{mode}-1point", f"{case_idx}.npy"))[slice_idx].astype(np.int32)
    elif mode == "sam":
        if data in ["msd"]:
            array = np.load(os.path.join(ours_root, data, f"{mode}-1point", f"{case_idx}.npy"))[slice_idx].astype(np.int32)
        else:
            array = np.load(os.path.join(ours_root, data, f"{mode}-1point", f"{case_idx}.npy"))[0, 0, slice_idx].astype(np.int32)
    elif mode == "nnunet":
        array = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(nnunet_root, data, f"{split_test[case_idx]}.nii.gz")))
        if data in ["msd", "lits", "colon"]:
            array = array[slice_idx]
        else:
            array = array[:, :, slice_idx]
    else:
        array = np.load(os.path.join(baseline_root, data, mode, "predictions", f"{case_idx}.npy"))[slice_idx]

    if mode != "img":
        # Replace the filled mask with its 2-pixel-wide contour for overlaying.
        result = np.zeros(array.shape, np.uint8)
        array = array.astype(np.uint8) * 255
        _, im_bw = cv2.threshold(array, 127, 255, cv2.THRESH_BINARY)
        contours, _ = cv2.findContours(im_bw, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
        array = cv2.drawContours(result, contours, -1, 255, 2)
        array = (array / 255).astype(np.int8)

    # Match the display orientation of each dataset.
    if data == "kits":
        if mode in ["img", "gt", "nnunet"]:
            array = array.transpose()
        array = np.flip(array)
    elif data == "lits":
        array = np.flip(array)

    return array
ghost commented 10 months ago

Thank you so much.

However, what are clip_vis and split_test?

yzhong22 commented 10 months ago

Sorry for the confusion. clip_vis is the intensity range we used to clip CTs for visualization; you can refer to https://github.com/med-air/3DSAM-adapter/blob/main/3DSAM-adapter/dataset/datasets.py for this hyper-parameter for the 4 datasets. And split_test is the split list for the test set, which can be found at https://github.com/med-air/3DSAM-adapter/tree/main/datafile.
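For reference, placeholder definitions of these two globals might look like the following. The values here are hypothetical; the real intensity windows are in datasets.py and the real case IDs are in the datafile folder:

```python
# Hypothetical placeholder values -- look up the real ones in the repo.
clip_vis = (-175.0, 250.0)   # (min, max) window used to clip CT intensities
split_test = ["case_0005", "case_0031", "case_0042"]  # test-split case IDs

# get_array("img", case_idx, slice_idx) indexes split_test[case_idx] to find
# the file, then rescales intensities with clip_vis before visualization.
```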

ghost commented 10 months ago

Thank you so much for replying to me. But without a specific example, it is very hard to reproduce similar visualisation figures. Could you please provide an example illustrating how the parameters are defined, such as slice_idx, ROI, img_shape, and border_c?

ghost commented 10 months ago

Sorry for asking more. Could you provide a specific example showing how to generate visualization figures like Figure 5 in the paper? It is very hard for me to determine the values of the parameters.



ghost commented 10 months ago

Can you please tell me how you save the predicted masks to .npy files, so that they can be used later for visualisation?
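The thread ends without an answer. A minimal sketch of one common pattern, assuming the model outputs a foreground-probability volume of shape (D, H, W) and a hypothetical 0.5 threshold (not necessarily the repo's exact code):

```python
import numpy as np

def save_prediction(pred_probs, out_path):
    """Binarize a predicted probability volume and save it for np.load.

    pred_probs: hypothetical float array of shape (D, H, W) with
    per-voxel foreground scores in [0, 1].
    """
    mask = (pred_probs > 0.5).astype(np.uint8)  # assumed 0.5 threshold
    np.save(out_path, mask)  # np.save appends ".npy" if the path lacks it
    return mask

# Usage: mask = save_prediction(probs, "predictions/case_0.npy")
# get_array above would then read it back with np.load(...)[slice_idx].
```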