bowang-lab / MedSAM

Segment Anything in Medical Images
https://www.nature.com/articles/s41467-024-44824-z
Apache License 2.0

MedSAM2 datasplit and inference #295

Open nairouzshehata opened 1 month ago

nairouzshehata commented 1 month ago

Hello, Thank you for providing detailed steps for fine-tuning. I followed these exactly but have two questions in bold below:

For pre_CT_MR.py I used the path to the images folder and their labels folder (200 samples). Two subfolders were then generated, npz_train and npz_test, with 40 and 160 samples respectively. Shouldn't these be the other way around? I just checked the code and the split is hard-coded to 40.

Then, for npz_to_npy I used the generated npz_train and got another folder npy and this npy was input to finetune_sam2_img.py

Now I have an untouched folder of test samples (images with no labels). How do I run inference on these? Should I write a custom script to convert NIfTI to npz and then use that as input to infer_medsam2_flare22.py?

Thanks!

ff98li commented 4 weeks ago

Hi there,

Thank you for your interest in trying out MedSAM2. Regarding your questions:

  1. The pre_CT_MR.py script provided in this codebase is by default set up for the FLARE22 dataset, which contains 50 cases (so 40 cases for training and 10 for testing, hence the hard-coding). Since you are using your custom dataset (don't forget to adjust the window width/level according to anatomy if it's CT) with 200 cases, you can make your splits by replacing lines 182-183 with the following:

    tr_len = int(0.8*len(names))
    tr_names = names[:tr_len]
    ts_names =  names[tr_len:]

    Alternatively, use scikit-learn's train_test_split with names.

  2. Yes, you can create a custom script to convert NIfTI files to npz files (pre_CT_MR.py would be a good reference to start with). The tricky part of your task is that you are trying to run inference on a test set with no annotations. SAM models typically require prompt inputs (in the case of MedSAM/MedSAM2, bounding box coordinates) in addition to images. In our evaluation, we simulated bounding boxes from the annotated voxels by taking [x_min, y_min, x_max, y_max] of the annotated slice with the largest area of the organ/lesion of interest, and performed inference from that slice upwards and downwards. This is also how MedSAM's 3D Slicer plugin works. For your case, I suggest manually annotating at least one slice with the largest organ/lesion area (you can do it with the MedSAM 3D Slicer plugin), obtaining the upper and lower boundaries along the z-axis, and then performing inference with boxes generated from that annotated slice. If you run into any issues in doing this, feel free to post them and we will see how to proceed from there.
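
The box simulation described in point 2 can be sketched roughly as follows. This is a minimal illustration, not the code used in the MedSAM2 evaluation: the function names box_from_largest_slice and z_extent are hypothetical, and the mask is assumed to be a NumPy array ordered (z, y, x) with the target labelled as label_id.

```python
import numpy as np

def box_from_largest_slice(gts, label_id=1):
    """Return (z, [x_min, y_min, x_max, y_max]) for the axial slice in which
    label_id covers the most pixels. gts is assumed to be ordered (z, y, x)."""
    mask = np.asarray(gts) == label_id
    areas = mask.reshape(mask.shape[0], -1).sum(axis=1)  # labelled pixels per slice
    if areas.max() == 0:
        raise ValueError(f"label {label_id} not found in the volume")
    z = int(np.argmax(areas))
    ys, xs = np.nonzero(mask[z])
    return z, [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]

def z_extent(gts, label_id=1):
    """Return the first and last slice index that contain label_id."""
    zs = np.nonzero((np.asarray(gts) == label_id).any(axis=(1, 2)))[0]
    return int(zs.min()), int(zs.max())

# Toy volume: the target is largest on slice 2.
gts = np.zeros((5, 8, 8), dtype=np.uint8)
gts[1, 2:4, 2:4] = 1
gts[2, 1:6, 2:7] = 1
gts[3, 3:5, 3:5] = 1

z, box = box_from_largest_slice(gts)
print(z, box)         # 2 [2, 1, 6, 5]
print(z_extent(gts))  # (1, 3)
```

Inference would then start from slice z with this box and move towards both ends of the z range.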

nairouzshehata commented 4 weeks ago

Great, thanks! So I converted nifti to npz for the test set and made sure the keys "imgs", "gts" and "spacing" were populated (code below). I used the masks predicted by TotalSegmentator for "gts"; I assume those would then be used for the bbox? I just want to improve these predicted masks using MedSAM2. However, I'm getting noisy predicted masks (nothing near an aorta), so I must be doing something wrong.

 python pre_CT_MR.py \
     -img_path /path/to/nnUNet_raw/Dataset005_HVolsMedSam/imagesTr \
     -img_name_suffix _0000.nii.gz \
     -gt_path /path/to/nnUNet_raw/Dataset005_HVolsMedSam/labelsTr \
     -gt_name_suffix .nii.gz \
     -output_path /path/to/nnUNet_preprocessed \
     -num_workers 4 \
     -modality MR \
     -anatomy Aorta \
     --save_nii

Then, for npz_to_npy I used the generated npz_train and npy folders (from above) as input to finetune_sam2_img.py

python npz_to_npy.py \
   -npz_dir /path/to/nnUNet_preprocessed/npz_train/MR_Aorta \
   -npy_dir /path/to/nnUNet_preprocessed/npy \
   -num_workers 4   

Then fine-tuning (the sam2_hiera_tiny.pt checkpoint is the downloaded one):

python finetune_sam2_img.py \
   -i /path/to/nnUNet_preprocessed/npy \
   -task_name MedSAM2-Tiny-Aorta \
   -work_dir ./work_dir \
   -batch_size 16 \
   -pretrain_model_path ./checkpoints/sam2_hiera_tiny.pt \
   -model_cfg sam2_hiera_t.yaml

Then the below to convert nifti to npz as mentioned:

import os
import nibabel as nib
import numpy as np

def convert_nifti_to_npz(nifti_dir, mask_dir, npz_dir):
   if not os.path.exists(npz_dir):
       os.makedirs(npz_dir)

   for filename in os.listdir(nifti_dir):
       if filename.endswith('.nii') or filename.endswith('.nii.gz'):
           # Extract the subject ID from the filename
           subject_id = filename.split('_')[1]
           nifti_path = os.path.join(nifti_dir, filename)
           nifti_img = nib.load(nifti_path)
           nifti_data = nifti_img.get_fdata()
           voxel_spacing = nifti_img.header.get_zooms()  # Get voxel dimensions

           # Construct the path to the mask file
           mask_path = os.path.join(mask_dir, subject_id, 'aorta.nii.gz')
           if not os.path.exists(mask_path):
               print(f"Mask not found for subject {subject_id}, skipping...")
               continue

           mask_img = nib.load(mask_path)
           mask_data = mask_img.get_fdata()

           # Binarize the mask (convert to 0s and 1s)
           mask_data = (mask_data > 0).astype(np.uint8)

           base_filename = os.path.splitext(os.path.splitext(filename)[0])[0]
           npz_path = os.path.join(npz_dir, base_filename + '.npz')
           np.savez_compressed(npz_path, imgs=nifti_data, spacing=voxel_spacing, gts=mask_data)  # Save with 'gts' key

           print(f'Converted {filename} to {npz_path} with spacing {voxel_spacing} and mask')

nifti_directory = '/path/to/nnUNet_raw/Dataset005_HVolsMedSam/imagesTs'
mask_directory = '/path/to/nnUNet_raw/Dataset002_HVols/TOTALSEGMENTATOR_segmentations'
npz_directory = '/path/to/nnUNet_preprocessed/npz_infer/MR_Aorta'
convert_nifti_to_npz(nifti_directory, mask_directory, npz_directory)

These are all my failed attempts at inference (I tried vanilla SAM2, the fine-tuned model, and a script I found under exp_inference). I only kept the label_id for Aorta.

1. Attempt using infer_medsam2_flare22 - fine-tuned:

python infer_medsam2_flare22.py \
     -data_root /path/to/nnUNet_preprocessed/npz_infer/MR_Aorta \
     -pred_save_dir /path/to/nnUNet_raw/Dataset005_HVolsMedSam/segs \
     -sam2_checkpoint checkpoints/sam2_hiera_tiny.pt \
     -medsam2_checkpoint ./work_dir/MedSAM2-Tiny-Aorta-20240812-1722/medsam_model_best.pth \
     -model_cfg sam2_hiera_t.yaml \
     -bbox_shift 5 \
     -num_workers 10 \
     --visualize

2. Attempt using infer_sam2_flare22 - vanilla sam2:

python infer_sam2_flare22.py \
   -data_root /path/to/nnUNet_preprocessed/npz_infer/MR_Aorta \
   -pred_save_dir /path/to/nnUNet_raw/Dataset005_HVolsMedSam/segs_vanilla \
   -sam2_checkpoint checkpoints/sam2_hiera_tiny.pt \
   -model_cfg sam2_hiera_t.yaml \
   -bbox_shift 5 \
   -num_workers 10 \
   --visualize

3. Attempt using exp_inference/infer_SAM2_3D_npz.py:

python /path/to/MedSAM/exp_inference/infer_SAM2_3D_npz.py \
     --imgs_path /path/to/nnUNet_preprocessed/npz_infer/MR_Aorta \
     --nifti_path /path/to/nnUNet_raw/Dataset005_HVolsMedSam/imagesTs \
     --gts /path/to/nnUNet_raw/Dataset002_HVols/TOTALSEGMENTATOR_segmentations \
     --pred_save_dir /path/to/nnUNet_preprocessed/pred/MR_Aorta \
     --checkpoint checkpoints/sam2_hiera_tiny.pt \
     --cfg sam2_hiera_t.yaml \
     --save_nifti 

Thank you for reading that far! 😄

ff98li commented 4 weeks ago

Hi there,

If I understood your code correctly, convert_nifti_to_npz converts your external validation set (TotalSegmentator) from NIfTI to npz files. However, the preprocessing step seems to be missing from your implementation: for MR images, the intensity values first need to be clipped to the range between the 0.5th and 99.5th percentiles and then rescaled to [0, 255]: https://github.com/bowang-lab/MedSAM/blob/6abb0ad78a335cecc3b4f5b2d43c4c4ff33fb436/pre_CT_MR.py#L150-L162 Since the training data has undergone this preprocessing step while the external validation set has not, it is expected that both the vanilla SAM2 and the fine-tuned model may not perform well on the external data.

Meanwhile, have you tried running inference on the internal validation set (the testing split from pre_CT_MR.py)? It could provide a useful reference point to compare against the external validation results. It would also help if you could share the fine-tuning loss curve (which can be found under work_dir) and the data_sanitycheck.png generated in the main directory during fine-tuning.
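
When comparing an internal test case with an external one, a quick look at what each npz file actually contains can help spot a preprocessing mismatch. The sketch below is only an illustration: summarize_npz is a hypothetical helper, and it assumes the files use the imgs, gts and spacing keys mentioned earlier in this thread.

```python
import os
import tempfile

import numpy as np

def summarize_npz(path):
    """Collect basic properties of one preprocessed case (hypothetical helper)."""
    with np.load(path) as data:
        imgs, gts = data["imgs"], data["gts"]
        return {
            "imgs_shape": imgs.shape,
            "imgs_dtype": str(imgs.dtype),
            "imgs_range": (float(imgs.min()), float(imgs.max())),
            "gts_shape": gts.shape,
            "labels": np.unique(gts).tolist(),
            "spacing": tuple(float(s) for s in data["spacing"]),
        }

# Demo on a tiny synthetic case written to a temporary folder.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "case.npz")
    np.savez_compressed(
        path,
        imgs=np.zeros((2, 4, 4), dtype=np.uint8),
        gts=np.ones((2, 4, 4), dtype=np.uint8),
        spacing=(1.0, 1.0, 2.5),
    )
    summary = summarize_npz(path)

print(summary)
```

If the external file reports a float dtype or an intensity range far outside [0, 255] while the internal one reports uint8, the two sets were probably not preprocessed the same way.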

nairouzshehata commented 4 weeks ago

Thank you for your reply! Below is a random sample from the internal validation set. The images are rotated and cropped, but so are the masks, so I guess that's not a problem. The quality of the segmentation is not that great, though: as you can see, there are bits of the spine.

[image]

versus the _gt: [image]

I modified the code as advised, so I will try again with the external validation set.

import os
import nibabel as nib
import numpy as np

def convert_nifti_to_npz(nifti_dir, mask_dir, npz_dir):
    if not os.path.exists(npz_dir):
        os.makedirs(npz_dir)

    for filename in os.listdir(nifti_dir):
        if filename.endswith('.nii') or filename.endswith('.nii.gz'):
            # Extract the subject ID from the filename
            subject_id = filename.split('_')[1]
            nifti_path = os.path.join(nifti_dir, filename)
            nifti_img = nib.load(nifti_path)
            nifti_data = nifti_img.get_fdata()
            voxel_spacing = nifti_img.header.get_zooms()  # Get voxel dimensions

            # Apply intensity clipping and normalization
            non_zero_pixels = nifti_data[nifti_data > 0]  # Exclude background
            lower_bound = np.percentile(non_zero_pixels, 0.5)
            upper_bound = np.percentile(non_zero_pixels, 99.5)
            nifti_data_clipped = np.clip(nifti_data, lower_bound, upper_bound)
            nifti_data_normalized = (nifti_data_clipped - np.min(nifti_data_clipped)) / (np.max(nifti_data_clipped) - np.min(nifti_data_clipped)) * 255.0
            nifti_data_normalized[nifti_data == 0] = 0  # Preserve background
            nifti_data_normalized = np.uint8(nifti_data_normalized)

            # Construct the path to the mask file
            mask_path = os.path.join(mask_dir, subject_id, 'aorta.nii.gz')
            if not os.path.exists(mask_path):
                print(f"Mask not found for subject {subject_id}, skipping...")
                continue

            mask_img = nib.load(mask_path)
            mask_data = mask_img.get_fdata()

            # Binarize the mask (convert to 0s and 1s)
            mask_data = (mask_data > 0).astype(np.uint8)

            base_filename = os.path.splitext(os.path.splitext(filename)[0])[0]
            npz_path = os.path.join(npz_dir, base_filename + '.npz')
            np.savez_compressed(npz_path, imgs=nifti_data_normalized, spacing=voxel_spacing, gts=mask_data)  # Save with 'gts' key

            print(f'Converted {filename} to {npz_path} with spacing {voxel_spacing} and mask')

# Example usage
nifti_directory = '/vol/biomedic3/nsm116/nnUNet/nnUNet_raw/Dataset005_HVolsMedSam/imagesTs'
mask_directory = '/vol/biomedic3/nsm116/nnUNet/nnUNet_raw/Dataset002_HVols/TOTALSEGMENTATORsegmentations'
npz_directory = '/vol/biomedic3/nsm116/nnUNet/nnUNet_preprocessed/npz_infer/MR_Aorta'
convert_nifti_to_npz(nifti_directory, mask_directory, npz_directory)

That's the data_sanitycheck.png: [image]

MEDSAM2_Tiny_Aortatrain_loss.png: [image]

ff98li commented 4 weeks ago

Hi there,

Thanks for sharing the screenshots. They are very informative and have made the root of the issue clearer now. The top-left window in ITK-Snap ideally should display the axial view, but it is showing the sagittal view instead. The same goes for data_sanitycheck.png, which is showing the sagittal view as well. Maybe the dataset had been originally prepared with SimpleITK and then was preprocessed with nibabel. SimpleITK and nibabel indexing access is in opposite order (we used SimpleITK in our data pipeline). For SimpleITK, it goes [axial, coronal, sagittal], while for nibabel it goes [sagittal, coronal, axial]. For CT/MRI scans, we only use the axial view because CT/MRI scans typically rely on the axial view for annotations. It would be a sensible idea to re-orient the images first (i.e. from [sagittal, coronal, axial] to [axial, coronal, sagittal]) before restarting any fine-tuning or inference.
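
As a rough illustration of the axis-order difference, the sketch below transposes an array from nibabel-style (x, y, z) indexing to SimpleITK-style (z, y, x) indexing. It assumes the file stores its axes in the usual left-right, anterior-posterior, inferior-superior order; the helper name nibabel_to_sitk_order and the toy shape are made up for the example.

```python
import numpy as np

def nibabel_to_sitk_order(volume_xyz):
    """Reorder an (x, y, z) array to (z, y, x) so that axis 0 indexes axial slices."""
    return np.transpose(np.asarray(volume_xyz), (2, 1, 0))

# Toy volume: 4 sagittal, 5 coronal and 6 axial positions (hypothetical sizes).
vol_xyz = np.arange(4 * 5 * 6).reshape(4, 5, 6)
vol_zyx = nibabel_to_sitk_order(vol_xyz)

print(vol_zyx.shape)      # (6, 5, 4)
axial_slice = vol_zyx[3]  # one axial slice, shape (5, 4)
```

This only reorders the in-memory array; if the result is saved back to NIfTI, the affine and spacing would need to be permuted to match.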

Once the image orientation is fixed, I recommend setting the tumor_id variable in pre_CT_MR.py to 1, which corresponds to the label value for aorta in your case (it is called tumor_id but it is actually used to convert class labels to instance labels). Also add the same code snippet to your nifti-to-npz script for consistency in label conversion for the external set as well: https://github.com/bowang-lab/MedSAM/blob/c6dcd24ce143d861772740dc6b106dff0b79ad6d/pre_CT_MR.py#L61-L71

The reason for this suggestion becomes apparent when looking at the overlay image on the right-hand side in your data_sanitycheck.png. You'll notice a single bounding box is enclosing two segmentation instances, which can introduce ambiguity in segmenting targets. Ideally, we want one bounding box to correspond to a single instance target during both training and inference.
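
To illustrate the one-box-per-instance idea, here is a minimal sketch that splits a binary 2D mask into connected components and computes a separate box for each. It is not the snippet from pre_CT_MR.py linked above: the helper names are hypothetical, 4-connectivity in 2D is an assumption made for brevity, and in practice a library routine such as scipy.ndimage.label or cc3d would normally do the labelling.

```python
from collections import deque

import numpy as np

def split_instances_2d(mask):
    """Label the 4-connected components of a binary 2D mask as 1, 2, 3, ..."""
    mask = np.asarray(mask).astype(bool)
    labels = np.zeros(mask.shape, dtype=np.int32)
    current = 0
    for y, x in zip(*np.nonzero(mask)):
        if labels[y, x]:
            continue  # pixel already belongs to an earlier component
        current += 1
        labels[y, x] = current
        queue = deque([(y, x)])
        while queue:  # breadth-first flood fill of one component
            cy, cx = queue.popleft()
            for ny, nx in ((cy - 1, cx), (cy + 1, cx), (cy, cx - 1), (cy, cx + 1)):
                inside = 0 <= ny < mask.shape[0] and 0 <= nx < mask.shape[1]
                if inside and mask[ny, nx] and not labels[ny, nx]:
                    labels[ny, nx] = current
                    queue.append((ny, nx))
    return labels

def boxes_per_instance(labels):
    """Return {instance_id: [x_min, y_min, x_max, y_max]} for each instance."""
    boxes = {}
    for inst in range(1, int(labels.max()) + 1):
        ys, xs = np.nonzero(labels == inst)
        boxes[inst] = [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]
    return boxes

# Toy slice with two separate blobs that share one class label.
mask = np.zeros((8, 10), dtype=np.uint8)
mask[1:3, 1:3] = 1
mask[5:7, 6:9] = 1

labels = split_instances_2d(mask)
boxes = boxes_per_instance(labels)
print(boxes)  # {1: [1, 1, 2, 2], 2: [6, 5, 8, 6]}
```

A single box around the whole mask would span both blobs, whereas the per-instance boxes each enclose exactly one target.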

nairouzshehata commented 3 weeks ago

This is what it looks like now, after I updated the code as below and started fine-tuning on the reoriented dataset: [image: data_sanitycheck]

import os
import SimpleITK as sitk
import nibabel as nib
import numpy as np

def reorient_image(image_path):
    # Load image with SimpleITK
    sitk_image = sitk.ReadImage(image_path)

    # Reorient the image to [axial, coronal, sagittal]
    reoriented_image = sitk.DICOMOrient(sitk_image, 'LPS')

    # Convert SimpleITK image to Nibabel image (which uses [sagittal, coronal, axial] ordering)
    reoriented_image_array = sitk.GetArrayFromImage(reoriented_image)
    reoriented_image_nib = nib.Nifti1Image(reoriented_image_array, np.eye(4))

    return reoriented_image_nib

def convert_and_reorient(nifti_dir, mask_dir, nifti_output_dir, mask_output_dir):
    # Create output directories if they don't exist
    if not os.path.exists(nifti_output_dir):
        os.makedirs(nifti_output_dir)
    if not os.path.exists(mask_output_dir):
        os.makedirs(mask_output_dir)

    for filename in os.listdir(nifti_dir):
        if filename.endswith('.nii') or filename.endswith('.nii.gz'):
            # Process and save the reoriented image
            nifti_path = os.path.join(nifti_dir, filename)
            reoriented_nifti = reorient_image(nifti_path)
            nifti_output_path = os.path.join(nifti_output_dir, filename)
            nib.save(reoriented_nifti, nifti_output_path)
            print(f'Reoriented and saved image {filename} to {nifti_output_path}')

            # Construct the corresponding mask filename
            base_name = filename.split('_0000.nii.gz')[0]
            mask_filename = base_name + '.nii.gz'
            mask_path = os.path.join(mask_dir, mask_filename)

            # Case 1: Mask is in the same directory with a matching filename
            if os.path.exists(mask_path):
                reoriented_mask = reorient_image(mask_path)
                mask_output_path = os.path.join(mask_output_dir, mask_filename)
                nib.save(reoriented_mask, mask_output_path)
                print(f'Reoriented and saved mask {mask_filename} to {mask_output_path}')

            # Case 2: Mask is inside a subfolder named after the subject_id
            else:
                subject_id = base_name.split('_')[1]  # Adjust this index based on your naming convention
                mask_subdir_path = os.path.join(mask_dir, subject_id, 'aorta.nii.gz')

                if os.path.exists(mask_subdir_path):
                    reoriented_mask = reorient_image(mask_subdir_path)
                    mask_output_path = os.path.join(mask_output_dir, mask_filename)
                    nib.save(reoriented_mask, mask_output_path)
                    print(f'Reoriented and saved mask {subject_id}/aorta.nii.gz to {mask_output_path}')
                else:
                    print(f'Mask not found for {filename} in both cases, skipping...')

# Example usage (run this twice: once for training pairs and another for validation pairs)
# nifti_directory = '/path/to/nnUNet_raw/Dataset005_HVolsMedSam/imagesTr'
# mask_directory = '/path/to/nnUNet_raw/Dataset005_HVolsMedSam/labelsTr'
# nifti_output_directory = '/path/to/nnUNet_raw/Dataset005_HVolsMedSam_reoriented/imagesTr'
# mask_output_directory = '/path/to/nnUNet_raw/Dataset005_HVolsMedSam_reoriented/labelsTr'
nifti_directory = '/path/to/nnUNet_raw/Dataset005_HVolsMedSam/imagesTs'
mask_directory = '/path/to/nnUNet_raw/Dataset002_HVols/segmentations'
nifti_output_directory = '/path/to/nnUNet_raw/Dataset005_HVolsMedSam_reoriented/imagesTs'
mask_output_directory = '/path/to/nnUNet_raw/Dataset005_HVolsMedSam_reoriented/segmentations_reoriented'

convert_and_reorient(nifti_directory, mask_directory, nifti_output_directory, mask_output_directory)

By the way, when I view the raw data in ITK-SNAP it looks like this: [image]

The reoriented data now looks like this: [image]

Fine-tuning is still not done, but this is the loss curve so far using the reoriented dataset: [image]

nairouzshehata commented 3 weeks ago

That's the _gt by TotalSegmentator: [image]

That's the mask predicted by the fine-tuned MedSAM: [image]

Below are the only changes I made to infer_medsam2_flare22.py:

label_dict = {
    1: 'Aorta'
}

That's what I used for inference (the -data_root folder was produced by the modified pre_CT_MR.py):

python infer_medsam2_flare22.py \
    -data_root /path/to/nnUNet_preprocessed/npz_infer/MR_Aorta \
    -pred_save_dir /path/to/nnUNet_raw/Dataset005_HVolsMedSam_reoriented/medsam_segs_fixed \
    -sam2_checkpoint checkpoints/sam2_hiera_tiny.pt \
    -medsam2_checkpoint ./work_dir/MedSAM2-Tiny-Aorta-20240817-1010/medsam_model_best.pth \
    -model_cfg sam2_hiera_t.yaml \
    -bbox_shift 10 \
    -num_workers 10 \
    --visualize

nairouzshehata commented 2 weeks ago

@ff98li any idea what else can I try? Thanks!

ff98li commented 2 weeks ago

Hi @nairouzshehata, sorry for the late response. I probably misused the term "orientation," which caused some confusion; by orientation, I meant axis ordering. So DICOMOrient won't give you what you want, as shown in the updated data_sanitycheck.png (and in the preprocessed data displayed in the ITK-SNAP screenshots you posted). The preprocessed images are still in the sagittal view along the z-axis but rotated by 90 degrees (because of DICOMOrient). There are two ways to achieve the expected axis ordering if you stick to nibabel in your pipeline for the external validation:

  1. The pure nibabel way, as done in nnUNet's NibabelIO (recommended since you prefer nibabel to SimpleITK): https://github.com/MIC-DKFZ/nnUNet/blob/9cd9d80ab3d9542138422f33154ee58421a92088/nnunetv2/imageio/nibabel_reader_writer.py#L42-L54
  2. Load images with SimpleITK first, then operate with nibabel:

    import numpy as np
    import nibabel as nib
    import SimpleITK as sitk

    def reorient_image(image_path):
        # Load image with SimpleITK
        sitk_image = sitk.ReadImage(image_path)

        # Reorient the image to [axial, coronal, sagittal]
        # reoriented_image = sitk.DICOMOrient(sitk_image, 'LPS')
        pa = sitk.PermuteAxesImageFilter()  # this applies the axis shift to the image's metadata as well
        pa.SetOrder([2, 0, 1])
        reoriented_image = pa.Execute(sitk_image)

        # Convert SimpleITK image to Nibabel image (which uses [sagittal, coronal, axial] ordering)
        spacing = reoriented_image.GetSpacing()
        origin = reoriented_image.GetOrigin()
        direction = reoriented_image.GetDirection()
        # Build the affine from the permuted image's direction, spacing and origin
        affine = np.eye(4)
        affine[:3, :3] = np.reshape(direction, (3, 3)) * np.asarray(spacing)
        affine[:3, 3] = origin
        reoriented_image_array = sitk.GetArrayFromImage(reoriented_image)
        reoriented_image_nib = nib.Nifti1Image(reoriented_image_array, affine)

        return reoriented_image_nib

    BTW, the raw data actually have the expected axis-ordering. It's odd that data_sanitycheck is showing the sagittal view. Were there any additional preprocessing steps done before you ran pre_CT_MR.py? Ideally pre_CT_MR.py would take the raw images and preserve the raw data's axis-ordering for the output training and internal validation set.

nairouzshehata commented 2 weeks ago

OK, I get it now. Yeah, I was thinking the same; the raw data are fine in that case. No, I haven't touched them, they went straight to pre_CT_MR.py. Do you know of any viewers other than ITK-SNAP to confirm the axis ordering? Maybe ITK-SNAP takes care of it? Or is there a way to confirm it via code?
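
One way to check the axis ordering in code, as a rough sketch: read the affine (for example img.affine in nibabel) and see which anatomical direction each voxel axis mostly follows. The helpers below are hypothetical and use a simple largest-component rule on an affine assumed to map voxels to RAS+ world coordinates; nibabel ships aff2axcodes for the same purpose, which is likely the better choice in a real pipeline.

```python
import numpy as np

def axis_codes_from_affine(affine):
    """Name the anatomical direction each voxel axis increases towards,
    assuming the affine maps voxel indices to RAS+ world coordinates."""
    names = (("L", "R"), ("P", "A"), ("I", "S"))
    rot = np.asarray(affine, dtype=float)[:3, :3]
    codes = []
    for col in range(3):
        world_axis = int(np.argmax(np.abs(rot[:, col])))  # dominant world axis
        positive = rot[world_axis, col] > 0
        codes.append(names[world_axis][1 if positive else 0])
    return tuple(codes)

def axial_axis(codes):
    """Index of the voxel axis that runs inferior-superior (the slice axis)."""
    return next(i for i, code in enumerate(codes) if code in ("I", "S"))

# Identity affine: voxel axes already follow (x, y, z).
standard = np.eye(4)

# Toy affine whose first voxel axis runs inferior to superior.
permuted = np.array([[0.0, 0.0, 1.0, 0.0],
                     [0.0, 1.0, 0.0, 0.0],
                     [1.0, 0.0, 0.0, 0.0],
                     [0.0, 0.0, 0.0, 1.0]])

print(axis_codes_from_affine(standard), axial_axis(axis_codes_from_affine(standard)))
# ('R', 'A', 'S') 2
print(axis_codes_from_affine(permuted), axial_axis(axis_codes_from_affine(permuted)))
# ('S', 'A', 'R') 0
```

Comparing the reported slice axis of the raw files with that of the preprocessed files would show whether the ordering changed along the way.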

In the meantime, I'll give the code snippet you've shared a go and will get back to you.

Thank you so much for following through!