Hi there,
Thank you for your interest in trying out MedSAM2. Regarding your questions:
The pre_CT_MR.py script provided in this codebase is by default set up for the FLARE22 dataset, which contains 50 cases (40 cases for training and 10 for testing, hence the hard-coding). Since you are using your custom dataset with 200 cases (don't forget to adjust the window width/level according to the anatomy if it's CT), you can make your splits by replacing lines 182-183 with the following:
```python
tr_len = int(0.8 * len(names))
tr_names = names[:tr_len]
ts_names = names[tr_len:]
```
Alternatively, use scikit-learn's train_test_split with names.
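For reference, a minimal sketch of the scikit-learn alternative (the fixed random_state is my addition, for a reproducible shuffle):

```python
from sklearn.model_selection import train_test_split

# 80/20 split of the case names; random_state pins the shuffle for reproducibility
tr_names, ts_names = train_test_split(names, test_size=0.2, random_state=42)
```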
Yes, you can create a custom script to convert nifti files to npz files (pre_CT_MR.py would be a good reference to start with). The tricky part in your task is that you are trying to infer on a test set with no annotations. SAM models typically require prompt inputs (in the case of MedSAM/MedSAM2, bounding box coordinates) besides images. In our evaluation, we simulated bounding boxes from the annotated voxels by taking [x_min, y_min, x_max, y_max] of the annotated slice with the largest area of the organ/lesion of interest, and performed inference from that slice upwards and downwards. This is also how MedSAM's 3D Slicer plugin works.
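For illustration, a minimal sketch of that box simulation from a binary mask volume (the helper name is mine, not from the repo; it assumes the mask array is ordered [slices, H, W]):

```python
import numpy as np

def simulate_box_from_mask(gts_3d):
    """Return the slice index with the largest mask area and its [x_min, y_min, x_max, y_max] box."""
    areas = gts_3d.reshape(gts_3d.shape[0], -1).sum(axis=1)  # foreground area per slice
    z_best = int(np.argmax(areas))                           # slice with the largest organ/lesion area
    ys, xs = np.where(gts_3d[z_best] > 0)
    return z_best, [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]
```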
For your case, I suggest manually annotating at least one slice with the largest organ/lesion area (you can do it with the MedSAM 3D slicer plugin), obtaining the upper and lower boundaries along the z-axis, and then performing inference with boxes generated from that annotated slice. If you run into any issues in doing this, feel free to post them and we will see how to proceed from there.
Great, thanks! So I converted nifti to npz for the test set and made sure the keys "imgs", "gts" and "spacing" were populated; code below. I used the masks predicted by TotalSegmentator for "gts", so I assume those would then be used for the bbox? I just want to improve these predicted masks using MedSAM2. However, I'm getting noisy predicted masks (nothing near an aorta), so I must be doing something wrong?
```bash
python pre_CT_MR.py \
    -img_path /path/to/nnUNet_raw/Dataset005_HVolsMedSam/imagesTr \
    -img_name_suffix _0000.nii.gz \
    -gt_path /path/to/nnUNet_raw/Dataset005_HVolsMedSam/labelsTr \
    -gt_name_suffix .nii.gz \
    -output_path /path/to/nnUNet_preprocessed \
    -num_workers 4 \
    -modality MR \
    -anatomy Aorta \
    --save_nii
```
Then, for npz_to_npy, I used the generated npz_train folder (from above) as input; the resulting npy folder was then the input to finetune_sam2_img.py:
```bash
python npz_to_npy.py \
    -npz_dir /path/to/nnUNet_preprocessed/npz_train/MR_Aorta \
    -npy_dir /path/to/nnUNet_preprocessed/npy \
    -num_workers 4
```
Then fine-tuning (the pretrained checkpoint is the downloaded sam2_hiera_tiny.pt):

```bash
python finetune_sam2_img.py \
    -i /path/to/nnUNet_preprocessed/npy \
    -task_name MedSAM2-Tiny-Aorta \
    -work_dir ./work_dir \
    -batch_size 16 \
    -pretrain_model_path ./checkpoints/sam2_hiera_tiny.pt \
    -model_cfg sam2_hiera_t.yaml
```
Then I used the script below to convert the test-set nifti files to npz, as mentioned:
```python
import os
import nibabel as nib
import numpy as np

def convert_nifti_to_npz(nifti_dir, mask_dir, npz_dir):
    if not os.path.exists(npz_dir):
        os.makedirs(npz_dir)
    for filename in os.listdir(nifti_dir):
        if filename.endswith('.nii') or filename.endswith('.nii.gz'):
            # Extract the subject ID from the filename
            subject_id = filename.split('_')[1]
            nifti_path = os.path.join(nifti_dir, filename)
            nifti_img = nib.load(nifti_path)
            nifti_data = nifti_img.get_fdata()
            voxel_spacing = nifti_img.header.get_zooms()  # Get voxel dimensions
            # Construct the path to the mask file
            mask_path = os.path.join(mask_dir, subject_id, 'aorta.nii.gz')
            if not os.path.exists(mask_path):
                print(f"Mask not found for subject {subject_id}, skipping...")
                continue
            mask_img = nib.load(mask_path)
            mask_data = mask_img.get_fdata()
            # Binarize the mask (convert to 0s and 1s)
            mask_data = (mask_data > 0).astype(np.uint8)
            base_filename = os.path.splitext(os.path.splitext(filename)[0])[0]
            npz_path = os.path.join(npz_dir, base_filename + '.npz')
            np.savez_compressed(npz_path, imgs=nifti_data, spacing=voxel_spacing, gts=mask_data)  # Save with 'gts' key
            print(f'Converted {filename} to {npz_path} with spacing {voxel_spacing} and mask')

nifti_directory = '/path/to/nnUNet_raw/Dataset005_HVolsMedSam/imagesTs'
mask_directory = '/path/to/nnUNet_raw/Dataset002_HVols/TOTALSEGMENTATOR_segmentations'
npz_directory = '/path/to/nnUNet_preprocessed/npz_infer/MR_Aorta'
convert_nifti_to_npz(nifti_directory, mask_directory, npz_directory)
```
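As a quick sanity check on the output (the filename here is a placeholder), the saved keys can be inspected like this:

```python
import numpy as np

npz = np.load('/path/to/nnUNet_preprocessed/npz_infer/MR_Aorta/case.npz')
print(npz['imgs'].shape, npz['gts'].shape, npz['spacing'])  # image and mask shapes should match
```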
Then these are all my failed attempts at inference (I tried vanilla SAM2, the fine-tuned model, and a random file I found under exp_inference); I only kept the label_id for Aorta.
1. Attempt using infer_medsam2_flare22 - fine-tuned:
```bash
python infer_medsam2_flare22.py \
    -data_root /path/to/nnUNet_preprocessed/npz_infer/MR_Aorta \
    -pred_save_dir /path/to/nnUNet_raw/Dataset005_HVolsMedSam/segs \
    -sam2_checkpoint checkpoints/sam2_hiera_tiny.pt \
    -medsam2_checkpoint ./work_dir/MedSAM2-Tiny-Aorta-20240812-1722/medsam_model_best.pth \
    -model_cfg sam2_hiera_t.yaml \
    -bbox_shift 5 \
    -num_workers 10 \
    --visualize
```
2. Attempt using infer_sam2_flare22 - vanilla SAM2:
```bash
python infer_sam2_flare22.py \
    -data_root /path/to/nnUNet_preprocessed/npz_infer/MR_Aorta \
    -pred_save_dir /path/to/nnUNet_raw/Dataset005_HVolsMedSam/segs_vanilla \
    -sam2_checkpoint checkpoints/sam2_hiera_tiny.pt \
    -model_cfg sam2_hiera_t.yaml \
    -bbox_shift 5 \
    -num_workers 10 \
    --visualize
```
3. Attempt using exp_inference/infer_SAM2_3D_npz.py:
```bash
python /path/to/MedSAM/exp_inference/infer_SAM2_3D_npz.py \
    --imgs_path /path/to/nnUNet_preprocessed/npz_infer/MR_Aorta \
    --nifti_path /path/to/nnUNet_raw/Dataset005_HVolsMedSam/imagesTs \
    --gts /path/to/nnUNet_raw/Dataset002_HVols/TOTALSEGMENTATOR_segmentations \
    --pred_save_dir /path/to/nnUNet_preprocessed/pred/MR_Aorta \
    --checkpoint checkpoints/sam2_hiera_tiny.pt \
    --cfg sam2_hiera_t.yaml \
    --save_nifti
```
Thank you for reading that far! 😄
Hi there,
If I understood your code correctly, convert_nifti_to_npz is going to convert your external validation set (with TotalSegmentator masks) from nifti to npz files. However, it seems that the preprocessing step is missing in your implementation: for MR images, the intensity values need to be first clipped to the range between the 0.5th and 99.5th percentiles and then rescaled to the range of [0, 255]:
https://github.com/bowang-lab/MedSAM/blob/6abb0ad78a335cecc3b4f5b2d43c4c4ff33fb436/pre_CT_MR.py#L150-L162
Since the training data has undergone this preprocessing step while the external validation set has not, it's kind of expected that both the vanilla SAM2 and the fine-tuned model may not perform well on the external data.
Meanwhile, have you tried running inference on the internal validation set (the testing split from pre_CT_MR.py)? It could provide a useful reference point to compare against the external validation results.
It might also be helpful if you could share the fine-tuning loss curve (which can be found under work_dir) and the data_sanitycheck.png generated in the main directory during fine-tuning.
Thank you for your reply! That's a random sample from the internal validation. The images are rotated and cropped, but so are the masks, so I guess that's not a problem. The quality of the segmentation is not that great though; as you can see, there are bits of the spine, versus the _gt. I modified the code as advised, so I will try again with the external validation set.
```python
import os
import nibabel as nib
import numpy as np

def convert_nifti_to_npz(nifti_dir, mask_dir, npz_dir):
    if not os.path.exists(npz_dir):
        os.makedirs(npz_dir)
    for filename in os.listdir(nifti_dir):
        if filename.endswith('.nii') or filename.endswith('.nii.gz'):
            # Extract the subject ID from the filename
            subject_id = filename.split('_')[1]
            nifti_path = os.path.join(nifti_dir, filename)
            nifti_img = nib.load(nifti_path)
            nifti_data = nifti_img.get_fdata()
            voxel_spacing = nifti_img.header.get_zooms()  # Get voxel dimensions
            # Apply intensity clipping and normalization
            non_zero_pixels = nifti_data[nifti_data > 0]  # Exclude background
            lower_bound = np.percentile(non_zero_pixels, 0.5)
            upper_bound = np.percentile(non_zero_pixels, 99.5)
            nifti_data_clipped = np.clip(nifti_data, lower_bound, upper_bound)
            nifti_data_normalized = (nifti_data_clipped - np.min(nifti_data_clipped)) / (np.max(nifti_data_clipped) - np.min(nifti_data_clipped)) * 255.0
            nifti_data_normalized[nifti_data == 0] = 0  # Preserve background
            nifti_data_normalized = np.uint8(nifti_data_normalized)
            # Construct the path to the mask file
            mask_path = os.path.join(mask_dir, subject_id, 'aorta.nii.gz')
            if not os.path.exists(mask_path):
                print(f"Mask not found for subject {subject_id}, skipping...")
                continue
            mask_img = nib.load(mask_path)
            mask_data = mask_img.get_fdata()
            # Binarize the mask (convert to 0s and 1s)
            mask_data = (mask_data > 0).astype(np.uint8)
            base_filename = os.path.splitext(os.path.splitext(filename)[0])[0]
            npz_path = os.path.join(npz_dir, base_filename + '.npz')
            np.savez_compressed(npz_path, imgs=nifti_data_normalized, spacing=voxel_spacing, gts=mask_data)  # Save with 'gts' key
            print(f'Converted {filename} to {npz_path} with spacing {voxel_spacing} and mask')

# Example usage
nifti_directory = '/vol/biomedic3/nsm116/nnUNet/nnUNet_raw/Dataset005_HVolsMedSam/imagesTs'
mask_directory = '/vol/biomedic3/nsm116/nnUNet/nnUNet_raw/Dataset002_HVols/TOTALSEGMENTATORsegmentations'
npz_directory = '/vol/biomedic3/nsm116/nnUNet/nnUNet_preprocessed/npz_infer/MR_Aorta'
convert_nifti_to_npz(nifti_directory, mask_directory, npz_directory)
```
That's the data_sanitycheck.png and the loss curve (MEDSAM2_Tiny_Aortatrain_loss.png).
Hi there,
Thanks for sharing the screenshots. They are very informative and have made the root of the issue clearer. The top-left window in ITK-SNAP should ideally display the axial view, but it is showing the sagittal view instead. The same goes for data_sanitycheck.png, which is showing the sagittal view as well. Maybe the dataset was originally prepared with SimpleITK and then preprocessed with nibabel. SimpleITK and nibabel index axes in opposite order (we used SimpleITK in our data pipeline): for SimpleITK, the order is [axial, coronal, sagittal], while for nibabel it is [sagittal, coronal, axial]. For CT/MRI scans, we only use the axial view, because CT/MRI scans typically rely on the axial view for annotations. It would be a sensible idea to re-orient the images first (i.e. from [sagittal, coronal, axial] to [axial, coronal, sagittal]) before restarting any fine-tuning or inference.
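To make the difference concrete, here is a small check one could run on the same file with both libraries (the filename is a placeholder):

```python
import SimpleITK as sitk
import nibabel as nib

# The same volume loaded with both libraries comes out with transposed axes:
# SimpleITK's GetArrayFromImage yields [axial, coronal, sagittal] index order,
# while nibabel's get_fdata yields [sagittal, coronal, axial].
arr_sitk = sitk.GetArrayFromImage(sitk.ReadImage('case.nii.gz'))
arr_nib = nib.load('case.nii.gz').get_fdata()
print(arr_sitk.shape, arr_nib.shape)  # e.g. (120, 512, 512) vs (512, 512, 120)
```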
Once the image orientation is fixed, I recommend setting the tumor_id variable in pre_CT_MR.py to 1, which corresponds to the label value for aorta in your case (it is called tumor_id, but it is actually used to convert class labels to instance labels). Also add the same code snippet to your nifti-to-npz script so that label conversion is consistent for the external set as well:
https://github.com/bowang-lab/MedSAM/blob/c6dcd24ce143d861772740dc6b106dff0b79ad6d/pre_CT_MR.py#L61-L71
The reason for this suggestion becomes apparent when looking at the overlay image on the right-hand side of your data_sanitycheck.png. You'll notice a single bounding box enclosing two segmentation instances, which can introduce ambiguity in segmenting targets. Ideally, we want one bounding box to correspond to a single instance target during both training and inference.
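For the external script, a minimal sketch of such a class-to-instance conversion (using scipy's connected-components labeling here as a stand-in for the repo's snippet; the helper name is mine):

```python
import numpy as np
from scipy import ndimage

def class_to_instance(mask, class_id=1):
    """Turn one class label into instance labels via 3D connected components."""
    binary = (mask == class_id).astype(np.uint8)
    instances, num_instances = ndimage.label(binary)  # each connected component gets a unique id
    return instances.astype(np.uint8), num_instances
```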
It now looks like that, after I updated the code as below and started fine-tuning on the reoriented dataset:
```python
import os
import SimpleITK as sitk
import nibabel as nib
import numpy as np

def reorient_image(image_path):
    # Load image with SimpleITK
    sitk_image = sitk.ReadImage(image_path)
    # Reorient the image to [axial, coronal, sagittal]
    reoriented_image = sitk.DICOMOrient(sitk_image, 'LPS')
    # Convert SimpleITK image to Nibabel image (which uses [sagittal, coronal, axial] ordering)
    reoriented_image_array = sitk.GetArrayFromImage(reoriented_image)
    reoriented_image_nib = nib.Nifti1Image(reoriented_image_array, np.eye(4))
    return reoriented_image_nib

def convert_and_reorient(nifti_dir, mask_dir, nifti_output_dir, mask_output_dir):
    # Create output directories if they don't exist
    if not os.path.exists(nifti_output_dir):
        os.makedirs(nifti_output_dir)
    if not os.path.exists(mask_output_dir):
        os.makedirs(mask_output_dir)
    for filename in os.listdir(nifti_dir):
        if filename.endswith('.nii') or filename.endswith('.nii.gz'):
            # Process and save the reoriented image
            nifti_path = os.path.join(nifti_dir, filename)
            reoriented_nifti = reorient_image(nifti_path)
            nifti_output_path = os.path.join(nifti_output_dir, filename)
            nib.save(reoriented_nifti, nifti_output_path)
            print(f'Reoriented and saved image {filename} to {nifti_output_path}')
            # Construct the corresponding mask filename
            base_name = filename.split('_0000.nii.gz')[0]
            mask_filename = base_name + '.nii.gz'
            mask_path = os.path.join(mask_dir, mask_filename)
            # Case 1: Mask is in the same directory with a matching filename
            if os.path.exists(mask_path):
                reoriented_mask = reorient_image(mask_path)
                mask_output_path = os.path.join(mask_output_dir, mask_filename)
                nib.save(reoriented_mask, mask_output_path)
                print(f'Reoriented and saved mask {mask_filename} to {mask_output_path}')
            # Case 2: Mask is inside a subfolder named after the subject_id
            else:
                subject_id = base_name.split('_')[1]  # Adjust this index based on your naming convention
                mask_subdir_path = os.path.join(mask_dir, subject_id, 'aorta.nii.gz')
                if os.path.exists(mask_subdir_path):
                    reoriented_mask = reorient_image(mask_subdir_path)
                    mask_output_path = os.path.join(mask_output_dir, mask_filename)
                    nib.save(reoriented_mask, mask_output_path)
                    print(f'Reoriented and saved mask {subject_id}/aorta.nii.gz to {mask_output_path}')
                else:
                    print(f'Mask not found for {filename} in both cases, skipping...')

# Example usage (run this twice: once for training pairs and again for validation pairs)
# nifti_directory = '/path/to/nnUNet_raw/Dataset005_HVolsMedSam/imagesTr'
# mask_directory = '/path/to/nnUNet_raw/Dataset005_HVolsMedSam/labelsTr'
# nifti_output_directory = '/path/to/nnUNet_raw/Dataset005_HVolsMedSam_reoriented/imagesTr'
# mask_output_directory = '/path/to/nnUNet_raw/Dataset005_HVolsMedSam_reoriented/labelsTr'
nifti_directory = '/path/to/nnUNet_raw/Dataset005_HVolsMedSam/imagesTs'
mask_directory = '/path/to/nnUNet_raw/Dataset002_HVols/segmentations'
nifti_output_directory = '/path/to/nnUNet_raw/Dataset005_HVolsMedSam_reoriented/imagesTs'
mask_output_directory = '/path/to/nnUNet_raw/Dataset005_HVolsMedSam_reoriented/segmentations_reoriented'
convert_and_reorient(nifti_directory, mask_directory, nifti_output_directory, mask_output_directory)
```
By the way, when I view the raw data in ITK-SNAP it looks like that, and the reoriented one looks like that now.
It is still not done, but that's the loss curve so far using the reoriented dataset.
That's the _gt by TotalSegmentator, and that's the predicted mask by the fine-tuned MedSAM. Below are the only changes I made to infer_medsam2_flare22.py:
```python
label_dict = {
    1: 'Aorta'
}
```
I was getting an error `RuntimeError: Input type (double) and bias type (float) should be the same` (nibabel's get_fdata returns float64, while the model weights are float32), so I added:

```python
img_1024_tensor = img_1024_tensor.float()  # Convert to float32
```
I also got another error, `TypeError: in method 'Image_SetSpacing', argument 2 of type 'std::vector< double,std::allocator< double > > const &'`, so I added the lines below for seg_sitk, img_sitk, and gts_sitk:

```python
spacing = npz['spacing']
# Convert to a Python list
if isinstance(spacing, np.ndarray):
    spacing = spacing.tolist()
# Ensure it's a list of floats
spacing = [float(x) for x in spacing]
# Set the spacing
gts_sitk.SetSpacing(spacing)
```
That's what I used for inference (data_root was produced by the modified pre_CT_MR.py):

```bash
python infer_medsam2_flare22.py \
    -data_root /path/to/nnUNet_preprocessed/npz_infer/MR_Aorta \
    -pred_save_dir /path/to/nnUNet_raw/Dataset005_HVolsMedSam_reoriented/medsam_segs_fixed \
    -sam2_checkpoint checkpoints/sam2_hiera_tiny.pt \
    -medsam2_checkpoint ./work_dir/MedSAM2-Tiny-Aorta-20240817-1010/medsam_model_best.pth \
    -model_cfg sam2_hiera_t.yaml \
    -bbox_shift 10 \
    -num_workers 10 \
    --visualize
```
@ff98li any idea what else I can try? Thanks!
Hi @nairouzshehata, sorry for the late response. I probably misused the term "orientation," which caused some confusion; by orientation, I meant axis ordering. So DICOMOrient won't give you what you want, as shown in the updated data_sanitycheck.png (and likewise in your preprocessed data displayed in the posted ITK-SNAP screenshots). The preprocessed images are still in the sagittal view along the z-axis, just rotated by 90 degrees (because of DICOMOrient). There are two ways you can achieve the expected axis ordering if you stick to nibabel in your pipeline for the external validation.
Load images with SimpleITK first, then operate with nibabel:
```python
def reorient_image(image_path):
    # Load image with SimpleITK
    sitk_image = sitk.ReadImage(image_path)
    # Reorient the image to [axial, coronal, sagittal]
    # reoriented_image = sitk.DICOMOrient(sitk_image, 'LPS')
    pa = sitk.PermuteAxesImageFilter()  # this applies the axis shift to the image's metadata as well
    pa.SetOrder([2, 0, 1])
    reoriented_image = pa.Execute(sitk_image)
    # Convert SimpleITK image to Nibabel image (which uses [sagittal, coronal, axial] ordering)
    spacing = reoriented_image.GetSpacing()
    origin = reoriented_image.GetOrigin()
    direction = reoriented_image.GetDirection()
    affine = np.eye(4)
    affine[:3, :3] = np.reshape(direction, (3, 3)) * np.asarray(spacing)
    affine[:3, 3] = origin
    reoriented_image_array = sitk.GetArrayFromImage(reoriented_image)
    reoriented_image_nib = nib.Nifti1Image(reoriented_image_array, affine)
    return reoriented_image_nib
```
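A quick way to confirm the permutation changed the axis order as intended (the path is a placeholder) is to compare shapes before and after:

```python
import nibabel as nib

# Compare the raw on-disk shape with the shape after reorient_image above
raw_shape = nib.load('case.nii.gz').get_fdata().shape
new_shape = reorient_image('case.nii.gz').shape
print(raw_shape, '->', new_shape)
```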
BTW, the raw data actually has the expected axis ordering. It's odd that data_sanitycheck is showing the sagittal view. Were there any additional preprocessing steps done before you ran pre_CT_MR.py? Ideally, pre_CT_MR.py would take the raw images and preserve the raw data's axis ordering in the output training and internal validation sets.
OK, I get it now. Yeah, I was thinking the same; the raw data in that case are good. I haven't touched them, no, just went straight to pre_CT_MR.py. Do you know of any viewers other than ITK-SNAP to confirm axis ordering? Maybe ITK-SNAP takes care of it? Or is there a way to confirm via some code?
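For instance, would something like this be a sensible check? (Using nibabel's aff2axcodes; just my guess, and the path is a placeholder.)

```python
import nibabel as nib

img = nib.load('case.nii.gz')
print(img.shape)                    # on-disk array shape
print(nib.aff2axcodes(img.affine))  # e.g. ('R', 'A', 'S'): axes run sagittal, coronal, axial
```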
In the meantime, I'll give the code snippet you've shared a go and will get back to you.
Thank you so much for following through!
Yes, thank you! It finally worked. I used nibabel to reorient to LPS and it did the trick!
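For reference, roughly what I used for the nibabel reorientation (paths are placeholders):

```python
import nibabel as nib

img = nib.load('case.nii.gz')
orig_ornt = nib.orientations.io_orientation(img.affine)     # current axis orientation
targ_ornt = nib.orientations.axcodes2ornt(('L', 'P', 'S'))  # target LPS orientation
transform = nib.orientations.ornt_transform(orig_ornt, targ_ornt)
nib.save(img.as_reoriented(transform), 'case_LPS.nii.gz')
```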
Hello, thank you for providing detailed steps for fine-tuning. I followed these exactly but have two questions in bold below:
1. For pre_CT_MR.py, I used the path to the images folder and their labels folder (200 samples). Then two subfolders were generated, npz_train and npz_test, with 40 and 160 samples respectively. **Shouldn't these be the other way around? I just checked the code and you have this hard-coded to 40.** Then, for npz_to_npy, I used the generated npz_train and got another folder, npy, and this npy was the input to finetune_sam2_img.py.
2. Now I have an untouched folder of test samples (images with no labels). **How do I infer on these? Shall I write custom code to convert nifti to npz and then use that as input to infer_medsam2_flare22.py?**
Thanks!