cwmok / C2FViT

This is the official Pytorch implementation of "Affine Medical Image Registration with Coarse-to-Fine Vision Transformer" (CVPR 2022), written by Tony C. W. Mok and Albert C. S. Chung.
MIT License

About the training process of the model #18

Open RexEmperor opened 4 months ago

RexEmperor commented 4 months ago

Dear @cwmok I sincerely appreciate your excellent work, but I have encountered several issues while reproducing this paper. Firstly, regarding the number of iterations, I found that the iteration number in the code of this paper is 160001. Is this the actual number of iterations during the training process? Secondly, I am currently reproducing the first part of Train-C2FVit_pairwise, but I have found that DiceLoss defined in C2FVit_model and Dice defined in Train-C2FVit_pairwise do not seem to be used during the training process. I would like to know how Dice should be used. Thirdly, after downloading the code, I saw two columns of data in the C2FViT.affineeCOMpairwise.txt file in the log folder. One column was Dice, and the other column was NCC. Is my understanding correct? And why do the final values not match those reported in the table in the paper?

cwmok commented 4 months ago

Hi @RexEmperor,

Thanks for your interest in our work.

Firstly, regarding the number of iterations, I found that the iteration number in the code of this paper is 160001. Is this the actual number of iterations during the training process?

Yes. After training, we select the best model based on the result of the validation set.

Secondly, I am currently reproducing the first part of Train-C2FVit_pairwise, but I have found that DiceLoss defined in C2FVit_model and Dice defined in Train-C2FVit_pairwise do not seem to be used during the training process. I would like to know how Dice should be used.

Train_C2FViT_pairwise_semi.py and Train_C2FViT_template_matching_semi.py are the semi-supervised models, which will use Dice loss in training. The other scripts are fully unsupervised. See Readme.md for more details.

Thirdly, after downloading the code, I saw two columns of data in the C2FViT.affineeCOMpairwise.txt file in the log folder. One column was Dice, and the other column was NCC. Is my understanding correct? And why do the final values not match those reported in the table in the paper?

Both columns are Dice. The first column is the Dice of the "selected" anatomical structures, while the second column is the Dice of the whole-brain (binary) segmentation. See the code here for more details. More importantly, the result reported in the log file is the result on the validation set. In our paper, we report the result on the test set, which is disjoint from the validation set.

RexEmperor commented 4 months ago

Dear @cwmok Thank you for your patient answer. I think I made a mistake: I did not uncomment the validation part of the code. So all I need to do is uncomment the validation part of the unsupervised training code to obtain results similar to those in the log file? But I didn't find any part that calculates Dice scores in the test.py file. Do I need to add it manually? By calculating Dice in the test section, I can obtain results similar to those in the table in the article, right? Thank you again for your patient answer.

cwmok commented 4 months ago

I think I made a mistake: I did not uncomment the validation part of the code. So all I need to do is uncomment the validation part of the unsupervised training code to obtain results similar to those in the log file?

Yes.

But I didn't find any part that calculates Dice scores in the test.py file. Do I need to add it manually?

No, you don't have to. The code for calculating the Dice score is provided.

def dice(im1, atlas):
    unique_class = np.unique(atlas)
    dice = 0
    num_count = 0
    for i in unique_class:
        if (i == 0) or ((im1 == i).sum() == 0) or ((atlas == i).sum() == 0):
            continue

        sub_dice = np.sum(atlas[im1 == i] == i) * 2.0 / (np.sum(im1 == i) + np.sum(atlas == i))
        dice += sub_dice
        num_count += 1
    return dice / num_count
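
As a quick sanity check of what this function returns, here is a minimal sketch on toy 1-D label arrays. The function body follows the one above; the arrays `moving` and `fixed` are made up for illustration, and the guard that returns 0.0 when no label is shared is an assumption added here (the original would divide by zero in that case).

```python
import numpy as np


def dice(im1, atlas):
    """Mean Dice over the non-zero labels present in both volumes."""
    unique_class = np.unique(atlas)
    total = 0.0
    num_count = 0
    for i in unique_class:
        # Skip background and any label missing from either volume.
        if (i == 0) or ((im1 == i).sum() == 0) or ((atlas == i).sum() == 0):
            continue
        sub_dice = np.sum(atlas[im1 == i] == i) * 2.0 / (np.sum(im1 == i) + np.sum(atlas == i))
        total += sub_dice
        num_count += 1
    # Assumption (not in the original): return 0.0 instead of dividing by zero.
    return total / num_count if num_count > 0 else 0.0


# Toy label maps, invented for this example only.
moving = np.array([0, 1, 1, 2, 2, 2])
fixed = np.array([0, 1, 1, 1, 2, 2])

score = dice(moving, fixed)
print(score)  # label 1: 2*2/(2+3) = 0.8, label 2: 2*2/(3+2) = 0.8, mean = 0.8
```

Note that a label present in only one of the two volumes is skipped rather than counted as a Dice of zero, so the mean is taken over shared labels only.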

By calculating Dice in the Test section, I can obtain results similar to those in the Table in the article, right?

Yes, you're correct!

RexEmperor commented 4 months ago

Dear @cwmok Thank you very much for your patience and prompt response. You have been a great help to me. I have another question. If I want to obtain the data in the table in the article, which Python file should I base my modification on: Test-C2FVit_pairwise or Train-C2FVit_pairwise? If it is based on Test-C2FVit_pairwise, do I need to add def dice to Test-C2FVit_pairwise and adopt a structure similar to the validation part of Train-C2FVit_pairwise? If it is based on Train-C2FVit_pairwise, should I follow the structure of the validation section and add a section that calculates the Dice score on the test set? Thank you again for your reply.

cwmok commented 4 months ago

I have another question. If I want to obtain the data in the article Table, which Python file do I need to modify it based on? Is it Test-C2FVit_pairwise or Train-C2FVit_pairwise? If the modification is based on Test-C2FVit_pairwise, then I need to apply def dice to Test-C2FVit_pairwise and adopt a structure similar to the validation part of Train-C2FVit_pairwise?

You will have to create a new python script, similar to the structure of the validation section. Test-C2FVit_pairwise takes only one pair of the images as input and outputs the affine-transformed image. It will not compute the Dice score.

RexEmperor commented 4 months ago

Okay, I roughly understand. Do I need to import the trained model, then use it to register the prediction set and calculate the Dice score? Your timely reply has been of great help to me, thank you very much.

cwmok commented 4 months ago

@RexEmperor Yes, you're correct.

RexEmperor commented 4 months ago

Dear @cwmok I followed the validation section and created a new code file for testing. I would like to ask: is the value I get DSC23? This is the code I wrote:

    dice_total = []
    brain_dice_total = []

    dice_total = np.array(dice_total)
    brain_dice_total = np.array(brain_dice_total)

    print("\nTesting...")

    for batch_idx, data in enumerate(valid_generator):
        X, Y, X_label, Y_label = data[0].to(device), data[1].to(device), data[2].to(
            device), data[3].to(device)

        with torch.no_grad():
            if com_initial:
                X, init_flow = init_center(X, Y)
                X_label = F.grid_sample(X_label, init_flow, mode="nearest", align_corners=True)

            X_down = F.interpolate(X, scale_factor=0.5, mode="trilinear", align_corners=True)
            Y_down = F.interpolate(Y, scale_factor=0.5, mode="trilinear", align_corners=True)

            warpped_x_list, y_list, affine_para_list = model(X_down, Y_down)
            X_Y, affine_matrix = affine_transform(X, affine_para_list[-1])
            F_X_Y = F.affine_grid(affine_matrix, X_label.shape, align_corners=True)

            X_Y_label = F.grid_sample(X_label, F_X_Y, mode="nearest", align_corners=True).cpu().numpy()[0,
                        0, :, :, :]
            X_brain_label = (X_Y > 0).float().cpu().numpy()[0, 0, :, :, :]

            # brain mask
            Y_brain_label = (Y > 0).float().cpu().numpy()[0, 0, :, :, :]
            Y_label = Y_label.data.cpu().numpy()[0, 0, :, :, :]

            dice_score = dice(np.floor(X_Y_label), np.floor(Y_label))
            dice_total = np.append(dice_total,dice_score)

            brain_dice = dice(np.floor(X_brain_label), np.floor(Y_brain_label))
            brain_dice_total = np.append(brain_dice_total,brain_dice)

            dice_total = np.array(dice_total)
            brain_dice_total = np.array(brain_dice_total)

            print("Dice mean: ", dice_total.mean())
            print("Brain Dice mean: ", brain_dice_total.mean())

            with open(log_dir, "a") as log:
                log.write(f"{dice_total.mean()}, {brain_dice_total.mean()} \n")

Could you please give me some advice? I ended up with a result of around 0.59, while in the paper it is around 0.66, which confuses me.

cwmok commented 4 months ago

@RexEmperor

What is your initial DSC23? We selected 23 critical anatomical structures for evaluation, as shown below. image

More importantly, as mentioned in the paper, we randomly select 3 MRI scans as an atlas in the OASIS dataset. It seems that your code is reporting the pairwise registration results but not comparing them to the 3 MRI atlases.

Remember to set com_initial=True and load the correct model ('C2FViT_affine_COM_pairwise_stagelvl3_118000.pth').

RexEmperor commented 4 months ago

Dear @cwmok Sorry, I seem to be getting confused. I want to reproduce the Atlas-Based Registration (OASIS) experiment. I trained the model with Train_C2FViT_pairwise.py, and I want to get the results of this experiment in Table 1 and Table 2. [image] As I read in the paper, "In the atlas-based registration with the OASIS dataset, 23 subcortical structures are included". So I took it for granted that writing a test file modelled on the validation section would be enough, and it seems I was wrong about that. I thought "pairwise image registration" was the Atlas-Based Registration (OASIS) task. In hindsight, I was wrong. If I want to do the Atlas-Based Registration (OASIS) task, which step should I take now? Should the model I trained be the one that is eventually used to perform that task, so that I don't need to train another model? Thanks again for your patience. I know I'm bothering you, but I'm really trying to figure out the experiments in this paper. Thank you so much!

RexEmperor commented 4 months ago

More importantly, as mentioned in the paper, we randomly select 3 MRI scans as an atlas in the OASIS dataset. It seems that your code is reporting the pairwise registration results but not comparing them to the 3 MRI atlases.

So, according to what you are saying, I need to compare one image (moving image) with three randomly selected images (fixed images) and then find their average Dice score?

cwmok commented 4 months ago

@RexEmperor

If I want to do the Atlas-Based Registration (OASIS) task, which step should I take now? Should the model I trained be the one that is eventually used to perform that task, so that I don't need to train another model?

You will have to train two models, one with the centre of mass initialization and one without a centre of mass initialization.

So, according to what you are saying, I need to compare one image (moving image) with three randomly selected images (fixed images) and then find their average Dice score?

You're correct. For each image in the test set, we compare it with three randomly selected images (fixed image) and then compute their average Dice score over 23 selected anatomical structures.
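
The evaluation protocol described above can be sketched roughly as follows. This is only an illustration under stated assumptions: `register_label` is a hypothetical placeholder for "run the trained model with the atlas as the fixed image and warp the moving segmentation with nearest-neighbour interpolation", `mean_label_dice` is a helper written for this example, and the random toy volumes with labels 1 to 3 stand in for real segmentations and the 23 selected structures.

```python
import numpy as np


def mean_label_dice(warped, fixed, labels):
    """Mean Dice over `labels`, skipping labels absent from either volume."""
    scores = []
    for lab in labels:
        a, b = (warped == lab), (fixed == lab)
        if a.sum() == 0 or b.sum() == 0:
            continue
        scores.append(2.0 * np.logical_and(a, b).sum() / (a.sum() + b.sum()))
    return float(np.mean(scores)) if scores else 0.0


def atlas_based_dice(test_labels, atlas_labels, register_label, labels):
    """For each test case, average the Dice obtained against every atlas."""
    per_case = []
    for moving in test_labels:
        per_atlas = [
            mean_label_dice(register_label(moving, atlas), atlas, labels)
            for atlas in atlas_labels
        ]
        per_case.append(np.mean(per_atlas))
    return np.array(per_case)


# Toy data, invented for this example: 3 "atlases" and 2 "test" label volumes.
rng = np.random.default_rng(0)
atlases = [rng.integers(0, 4, size=(8, 8, 8)) for _ in range(3)]
tests = [rng.integers(0, 4, size=(8, 8, 8)) for _ in range(2)]


def identity_registration(moving_label, fixed_label):
    # Hypothetical stand-in: a real script would warp `moving_label`
    # with the affine transform predicted by the trained model.
    return moving_label


per_case = atlas_based_dice(tests, atlases, identity_registration, labels=[1, 2, 3])
print("mean Dice:", per_case.mean(), "std:", per_case.std())
```

In a real test script, `labels` would be the indices of the 23 selected structures, and the per-case array is what the mean and spread are computed from.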

RexEmperor commented 4 months ago

Dear @cwmok Regarding the template-matching task, I preprocessed the dataset as you did:

import os

import nibabel
import numpy as np

target_dir = 'path/to/your/directory'

contents = os.listdir(target_dir)

folder_names = [name for name in contents if os.path.isdir(os.path.join(target_dir, name))]

for name in folder_names:
    # Build the subject folder path from the root directory, then the label path inside it.
    dir_path = os.path.join(target_dir, name)
    file_path = os.path.join(dir_path, "seg4.nii.gz")

    seg4_lable = nibabel.load(file_path)
    seg4_lable_npy = seg4_lable.get_fdata()

    seg4_lable_oasis = np.zeros(seg4_lable_npy.shape,dtype=seg4_lable_npy.dtype)
    seg4_lable_oasis[(seg4_lable_npy == 1)] = 1
    seg4_lable_oasis[(seg4_lable_npy == 2)] = 2
    seg4_lable_oasis[(seg4_lable_npy == 7)] = 3
    seg4_lable_oasis[(seg4_lable_npy == 9)] = 4

    new_nii = nibabel.Nifti1Image(seg4_lable_oasis,seg4_lable.affine,seg4_lable.header)
    nibabel.save(new_nii,os.path.join(dir_path,"seg4_mni.nii.gz"))

After preprocessing the data with the above code, I trained my model with Train_C2FViT_template_matching.py and uncommented the validation part of the code. I then noticed that the validation Dice scores of the model I trained were far from the ones in the log file you provided. Here is my Dice score: [image] And here is the Dice score in your log folder: [image] This is very confusing to me. What do you suggest? Once again, thank you for your tireless replies!

cwmok commented 4 months ago

@RexEmperor

Your preprocessing is for the MNI152 atlas, but not for the MRI scans in the OASIS dataset.

Correct one:

# for label_path in oasis_label_list:
#     print(label_path)
#     label_nib = nib.load(label_path)
#     label_npy = label_nib.get_fdata()
#
#     label_mni = np.zeros(label_npy.shape, dtype=label_npy.dtype)
#     label_mni[(label_npy==8)] = 1
#     label_mni[(label_npy==27)] = 1
#
#     label_mni[(label_npy==5)] = 2
#     label_mni[(label_npy==6)] = 2
#     label_mni[(label_npy==24)] = 2
#     label_mni[(label_npy==25)] = 2
#
#     label_mni[(label_npy==9)] = 3
#     label_mni[(label_npy==28)] = 3
#
#     label_mni[(label_npy==7)] = 4
#     label_mni[(label_npy==26)] = 4
#
#     new_nii = nib.Nifti1Image(label_mni.astype(np.int64), label_nib.affine, label_nib.header)
#     save_path = label_path.replace("seg35.nii.gz", "seg4_mni.nii.gz")
#     nib.save(new_nii, save_path)

You should always check your preprocessed data before you use it for training. I suggest using 3D Slicer or ITK-SNAP for a visual check to make sure you're selecting the correct ROI anatomical structures for both moving and fixed images.
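
Alongside the visual check, a quick numerical check of the remapping can help. The sketch below expresses the same seg35-to-4-structure mapping as a lookup table and prints the resulting label values. The names `SEG35_TO_MNI4` and `remap_labels` and the tiny toy array are made up for this example, and it assumes the seg35 indices lie in the range 0 to 35.

```python
import numpy as np

# OASIS seg35 index -> 4-structure index, copied from the mapping above.
SEG35_TO_MNI4 = {
    8: 1, 27: 1,                 # caudate
    5: 2, 6: 2, 24: 2, 25: 2,    # cerebellum
    9: 3, 28: 3,                 # putamen
    7: 4, 26: 4,                 # thalamus
}


def remap_labels(label_npy, mapping=SEG35_TO_MNI4, max_label=35):
    """Remap integer labels through a lookup table; unmapped labels become 0."""
    lut = np.zeros(max_label + 1, dtype=np.int64)
    for src, dst in mapping.items():
        lut[src] = dst
    return lut[label_npy.astype(np.int64)]


# Toy stand-in for the array returned by label_nib.get_fdata().
seg35 = np.array([[0, 8, 27], [5, 24, 13], [9, 7, 26]], dtype=np.float64)
seg4 = remap_labels(seg35)

print(np.unique(seg4))  # expected label set after remapping: 0 to 4
```

This only confirms that the expected label values are present. It does not replace opening the image and its segmentation in a viewer to confirm that each index covers the intended structure.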

RexEmperor commented 4 months ago

Dear @cwmok Thank you again for your help, you have taught me so much that was previously unknown to me. I saw your Q&A with others in your Issue section, so I just called the preprocessing code directly. Thanks again for your reply.

RexEmperor commented 4 months ago

Dear @cwmok I'm so sorry to bother you. I have two questions. First, in the dataset division the validation set is 10 volumes, but the code only includes 5 volumes: imgs = sorted(glob.glob(datapath + "/OASIS_OAS1_*_MR1/norm.nii.gz"))[255:260]. Second, regarding the testing part of template matching: is the testing procedure consistent with the validation part, i.e. simply align the images in the test set one by one with the MNI template and calculate the Dice score?

cwmok commented 4 months ago

@RexEmperor

One is that in the dataset division, the validation set is 10 volumes, but in the code section, it only includes 5 volumes. imgs = sorted(glob.glob(datapath + "/OASIS_OAS1_*_MR1/norm.nii.gz"))[255:260]

Please strictly follow the results and settings reported in the paper. Do not rely on the commented-out code (it may not be up to date). I confirm that the results and settings written in the paper are accurate and correct.

The second question is about the testing part of template_matching. I'm wondering if the testing part is consistent with the validation part? Simply align the images in the test set one by one with the MNI template and calculate the Dice score.

Yes. It is just aligning the images in the test set one by one with the MNI template and calculating the Dice score.

RexEmperor commented 4 months ago

Thanks for the reply, I understand it. Thank you very much for your patience.

RexEmperor commented 4 months ago

Dear @cwmok I've got a problem: the Dice score I get when validating the model I trained for this phase is off by 0.02 compared to what you got. Is there something wrong somewhere in between? Does the MNI template file need to be preprocessed, or is it OK to use the ones in the Data folder you provided?

cwmok commented 3 months ago

@RexEmperor

I think it is OK to use the one in the Data folder. But you have to make sure the index of each anatomical structure matches that of the moving image. I strongly recommend visualizing both the MNI template and its label using ITK-SNAP before training.
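
One way to catch an index mismatch before opening a viewer is to compare the sets of label values in the fixed (template) and moving segmentations. The following is a minimal sketch. `label_summary`, `check_label_sets`, and the toy arrays are hypothetical names and data for this example, and in practice the arrays would come from the two label files.

```python
import numpy as np


def label_summary(label_npy):
    """Map each label value to its voxel count."""
    values, counts = np.unique(label_npy.astype(np.int64), return_counts=True)
    return dict(zip(values.tolist(), counts.tolist()))


def check_label_sets(fixed_label, moving_label):
    """Return (match, mismatched_ids) for the non-background label values."""
    fixed_ids = set(label_summary(fixed_label)) - {0}
    moving_ids = set(label_summary(moving_label)) - {0}
    return fixed_ids == moving_ids, fixed_ids ^ moving_ids


# Toy arrays, invented for illustration.
template = np.array([0, 1, 2, 3, 4])
moving_raw = np.array([0, 1, 2, 7, 9])       # indices not yet remapped
moving_ok = np.array([0, 4, 3, 2, 1, 1])     # same index set as the template

print(check_label_sets(template, moving_raw))
print(check_label_sets(template, moving_ok))
```

Matching index sets are necessary but not sufficient: the same index could still denote different structures in the two files, which is why the visual check remains important.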

RexEmperor commented 3 months ago

[image] [image] The first image shows the labeling of the template image and the second image shows the labeling of the preprocessed OASIS image. Have I got something wrong? Thank you for your reply.

cwmok commented 3 months ago

The intensity values of the anatomical structures do not seem to match each other. Could you confirm it?

RexEmperor commented 3 months ago

I don't know if this is the right way to see if my intensity values of the anatomical structures are off.

import numpy as np
import nibabel as nib
# Assuming label_npy is the original label array and label_mni is the remapped label array
# Raw strings keep the Windows backslashes literal.
label_npy = nib.load(r'H:\data\seg35.nii.gz')
label_npy = label_npy.get_fdata()
label_mni = nib.load(r'H:\data\seg4_mni_cp.nii.gz')
label_mni = label_mni.get_fdata()
# Finds the index of all elements in label_mni with a value of 1
indices_mapped_to_1 = np.where(label_mni == 1)

# Check that these indexes correspond to values of 8 or 27 in the original label_npy array.
original_values_at_indices = label_npy[indices_mapped_to_1]
has_8_and_27 = (np.any(original_values_at_indices == 8)) & (np.any(original_values_at_indices == 27))

if has_8_and_27:
    print("Labels 8 and 27 are mapped to 1")
else:
    print("Labels 8 and 27 aren't mapped to 1")

I just tested the set that was mapped to 1 and the result was right. [image] Am I checking it the wrong way? Thank you for your patience in replying.

RexEmperor commented 3 months ago

According to you, the MNI templates are not the problem, since I downloaded these files from you as well. In the code, the MNI152 image is MNI152_T1_1mm_brain_pad_RSP.nii.gz and the MNI152 label is MNI-maxprob-thr50-1mm_pad_RSP_oasis.nii.gz. That should be right, right? [image] [image] Above is how they look in ITK-SNAP. If there is no problem with these, then the only possibility is that I have a problem in preprocessing the data? Is my preprocessing code correct? I just preprocess the OASIS images and generate new images as their labels, right?

import os
import nibabel as nib
import numpy as np

target_dir = '/home/guoke/Downloads/C2FViT-main/Data/Train_pairwise/'  

contents = os.listdir(target_dir)

oasis_label_list = [os.path.join(target_dir,name) for name in contents if os.path.isdir(os.path.join(target_dir, name))]

for label_path in oasis_label_list:
    label_path = os.path.join(label_path,"seg35.nii.gz")
    print(label_path)
    label_nib = nib.load(label_path)
    label_npy = label_nib.get_fdata()

    label_mni = np.zeros(label_npy.shape, dtype=label_npy.dtype)
    label_mni[(label_npy==8)] = 1
    label_mni[(label_npy==27)] = 1

    label_mni[(label_npy==5)] = 2
    label_mni[(label_npy==6)] = 2
    label_mni[(label_npy==24)] = 2
    label_mni[(label_npy==25)] = 2

    label_mni[(label_npy==9)] = 3
    label_mni[(label_npy==28)] = 3

    label_mni[(label_npy==7)] = 4
    label_mni[(label_npy==26)] = 4

    new_nii = nib.Nifti1Image(label_mni.astype(np.int64), label_nib.affine, label_nib.header)
    save_path = label_path.replace("seg35.nii.gz", "seg4_mni.nii.gz")
    nib.save(new_nii, save_path)

This is the code I am using, is there a problem with this? I'm really sorry to keep bothering you, and thank you for your reply, it's sincerely appreciated!

cwmok commented 3 months ago

@RexEmperor Could you send me an example of preprocessed fixed and moving image pair along with the corresponding preprocessed segmentation? I can help you to have a look.

RexEmperor commented 3 months ago

MNI152_T1_1mm_brain_pad_RSP.nii.gz MNI-maxprob-thr50-1mm_pad_RSP_oasis.nii.gz norm.nii.gz seg4_mni.nii.gz

imgs = sorted(glob.glob("/OASIS_OAS1_*_MR1/norm.nii.gz"))
labels = sorted(glob.glob("/OASIS_OAS1_*_MR1/seg4_mni.nii.gz"))
MNI152_img = "MNI152_T1_1mm_brain_pad_RSP.nii.gz"
MNI152_label = "MNI-maxprob-thr50-1mm_pad_RSP_oasis.nii.gz"

Thank you very much indeed.

cwmok commented 3 months ago

@RexEmperor Your preprocessed data is correct. It is the same as my training data. image

cwmok commented 3 months ago

@RexEmperor I believe I have spotted the issue. In my code, I use one-hot encoded label.

imgs = sorted(glob.glob(datapath + "/OASIS_OAS1_*_MR1/norm.nii.gz"))
labels = sorted(glob.glob(datapath + "/OASIS_OAS1_*_MR1/seg4_mni_onehot.nii.gz"))
MNI152_img = "../Data/MNI152_T1_1mm_brain_pad_RSP.nii.gz"
MNI152_label = "../Data/MNI-maxprob-thr50-1mm_pad_RSP_oasis_onehot.nii.gz"

Your label is not one-hot encoded. See https://www.geeksforgeeks.org/ml-one-hot-encoding/ for more detail about one-hot encoding.

RexEmperor commented 3 months ago

I read this page about one-hot encoding and I learned a few things. However, I still don't understand how one-hot encoding should be done for .nii.gz files, because the entire .nii.gz file should be labeled, right? So I would like to ask if you can send me your one-hot file or show me the preprocessing code. Thank you very much.

cwmok commented 3 months ago

Hi @RexEmperor,

It is in https://github.com/cwmok/C2FViT/blob/main/Data/image_A_seg4_mni_onehot.nii.gz

The segmentation for fixed image will be in similar format.

RexEmperor commented 3 months ago

Sorry, I don't quite understand how to go about this. Can you send me the code that does the one-hot preprocessing of the data? Or the processed data (but it seems to be too big)?

RexEmperor commented 2 months ago

Or how can this one-hot encoded moving image you provided help? What do I do with it? Thank you.

cwmok commented 2 months ago

Hi @RexEmperor,

There are many ways to achieve one-hot encoding. For example, https://pytorch.org/docs/stable/generated/torch.nn.functional.one_hot.html.

You just need to inspect the provided one-hot encoded image (np.shape, intensity, etc.) and preprocess your images in the same way.

RexEmperor commented 2 months ago

Dear @cwmok Sorry, I'm having a problem. It seems to me that the image_A_seg4_mni_onehot.nii.gz file is the image_A_seg4_mni.nii.gz file after one-hot encoding, right? But I found through the code that the num_classes of image_A_seg4_mni.nii.gz is 5, while the size of the resulting image_A_seg4_mni_onehot.nii.gz is 256x256x256x4. This is really confusing to me. Also, after converting the .nii.gz file to a tensor, it needs to be converted to the long data type before it can be passed to the torch.nn.functional.one_hot function. Do you do the same? Thank you.

cwmok commented 2 months ago

@RexEmperor

4 anatomical structures + 1 background = 5 classes.

We then take the 4 classes from the one-hot encoded data.

Yes, it accepts long data type only.
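
The two steps described above (encode 5 classes, then keep the 4 structure channels) can be sketched in NumPy as follows. The function name `one_hot_foreground` and the toy array are made up for illustration, and the sketch assumes the label volume contains only the integer values 0 to 4. With PyTorch, `torch.nn.functional.one_hot(tensor.long(), 5)[..., 1:]` should give the same layout.

```python
import numpy as np


def one_hot_foreground(label_npy, num_classes=5):
    """One-hot encode integer labels and drop the background channel (class 0)."""
    labels = label_npy.astype(np.int64)
    # Indexing an identity matrix with the labels appends a class axis:
    # the result has shape label_npy.shape + (num_classes,).
    one_hot = np.eye(num_classes, dtype=np.int64)[labels]
    return one_hot[..., 1:]


# Toy 2x2 "volume" standing in for a seg4_mni label array.
seg4 = np.array([[0, 1], [4, 2]], dtype=np.float64)
encoded = one_hot_foreground(seg4)

print(encoded.shape)   # (2, 2, 4)
print(encoded[1, 0])   # label 4 -> last channel set
```

A 256x256x256 label volume processed this way ends up as 256x256x256x4, and background voxels are all zeros across the four channels.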

RexEmperor commented 2 months ago

You mean you only encoded the categories and not the background, right? When I process the tensor, the system always reports an error:

import nibabel as nib
import torch

img = nib.load('image_A.nii.gz')
data = img.get_fdata()
tensor = torch.from_numpy(data)
tensor = tensor.to(torch.long)
tensor_onehot = torch.nn.functional.one_hot(tensor,4)

Am I handling this wrong? Thank you.

cwmok commented 2 months ago

@RexEmperor I used this script below to process the one-hot encoded data.

import glob

import nibabel as nib
import numpy as np

'''
MNI152 label - OASIS 35 label
0 - 
1 Caudate - 8, 27
2 Cerebellum - 5, 6, 24, 25
3 Frontal Lobe - None
4 Insula - None
5 Occipital Lobe - None
6 Parietal Lobe - None
7 Putamen - 9, 28
8 Temporal Lobe - None
9 Thalamus - 7, 26
'''

'''
OASIS seg35 list (selected 25 structures)

13    Brain-Stem                            119 159 176 0

7     Left-Thalamus                         0   118 14  0
26    Right-Thalamus                        0   118 14  0

6     Left-Cerebellum-Cortex                230 148 34  0
25    Right-Cerebellum-Cortex               230 148 34  0

3     Left-Lateral-Ventricle                120 18  134 0
22    Right-Lateral-Ventricle               120 18  134 0

5     Left-Cerebellum-White-Matter          220 248 164 0
24    Right-Cerebellum-White-Matter         220 248 164 0

9     Left-Putamen                          236 13  176 0
28    Right-Putamen                         236 13  176 0

8     Left-Caudate                          122 186 220 0
27    Right-Caudate                         122 186 220 0

10    Left-Pallidum                         12  48  255 0
29    Right-Pallidum                        12  48  255 0

14    Left-Hippocampus                      220 216 20  0
30    Right-Hippocampus                     220 216 20  0

11    3rd-Ventricle                         204 182 142 0

12    4th-Ventricle                         42  204 164 0

15    Left-Amygdala                         103 255 255 0
31    Right-Amygdala                        103 255 255 0

2     Left-Cerebral-Cortex                  205 62  78  0
21    Right-Cerebral-Cortex                 205 62  78  0

(too small) 19    Left-Choroid-Plexus                   0   200 200 0
(too small) 35    Right-Choroid-Plexus                  0   200 200 0

'''

print("Process MNI label...")

mni_label_path = "../Data/MNI-maxprob-thr50-1mm_pad_RSP_oasis.nii.gz"
mni_label = nib.load(mni_label_path)
mni_label_npy = mni_label.get_fdata()

print(f"Unique labels: {np.unique(mni_label_npy)}")

one_hot_img = np.zeros(mni_label_npy.shape + (4,), dtype=mni_label_npy.dtype)

for index, i in enumerate([1, 2, 3, 4]):
    one_hot_img[..., index] = (mni_label_npy == i)

new_img = nib.nifti1.Nifti1Image(one_hot_img.astype(np.int64), affine=mni_label.affine, header=mni_label.header)
save_path = mni_label_path.replace('MNI-maxprob-thr50-1mm_pad_RSP_oasis.nii.gz', 'MNI-maxprob-thr50-1mm_pad_RSP_oasis_onehot.nii.gz')
print("Saving... ", save_path)
nib.save(new_img, save_path)

OASIS_mni_path = sorted(glob.glob("../Data/OASIS/OASIS_OAS1_*_MR1/seg4_mni.nii.gz"))

for path in OASIS_mni_path:
    print(f"Processing {path}")
    label = nib.load(path)
    label_npy = label.get_fdata()

    print(f"Unique labels: {np.unique(label_npy)}")
    one_hot_img = np.zeros(label_npy.shape + (4,), dtype=label_npy.dtype)

    for index, i in enumerate([1, 2, 3, 4]):
        one_hot_img[..., index] = (label_npy == i)

    new_img = nib.nifti1.Nifti1Image(one_hot_img.astype(np.int64), affine=label.affine, header=label.header)
    save_path = path.replace('seg4_mni.nii.gz', 'seg4_mni_onehot.nii.gz')
    print("Saving... ", save_path)
    nib.save(new_img, save_path)

# print("Preprocessing OASIS label")
# left = [7, 6, 3, 5, 9, 8, 10, 14, 15, 2]
# right = [26, 25, 22, 24, 28, 27, 29, 30, 31, 21]
# left_right_pair = [list(a) for a in zip(left, right)]
# other = [11, 12, 13]
#
# OASIS_label_path = sorted(glob.glob("../Data/OASIS/OASIS_OAS1_*_MR1/seg35.nii.gz"))
#
# for path in OASIS_label_path:
#     print(f"Processing {path}")
#     label = nib.load(path)
#     label_npy = label.get_fdata()
#
#     print(f"Unique labels: {np.unique(label_npy)}")
#     one_hot_img = np.zeros(label_npy.shape + (13,), dtype=label_npy.dtype)
#
#     for index, pair in enumerate(left_right_pair):
#         left_image = (label_npy == pair[0])
#         right_image = (label_npy == pair[1])
#         one_hot_img[:, :, :, index] = left_image | right_image
#
#     for index, o in enumerate(other):
#         one_hot_img[:, :, :, index+10] = (label_npy == o)
#
#     new_img = nib.nifti1.Nifti1Image(one_hot_img.astype(np.int64), affine=label.affine, header=label.header)
#     save_path = path.replace('seg35.nii.gz', 'seg35_onehot.nii.gz')
#     print("Saving... ", save_path)
#     nib.save(new_img, save_path)

RexEmperor commented 1 month ago

[image] Sorry to bother you. I would like to ask about this table: does 0.72±0.06 mean that if I reproduce this experiment myself, a result as high as 0.78 would still be reasonable?