Amshaker / unetr_plus_plus

[IEEE TMI-2024] UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation
Apache License 2.0

The accuracy of the reproduced Lung dataset is very different from the original paper #35

Closed: smanman closed this issue 1 year ago

smanman commented 1 year ago

Hello, the accuracy I get when reproducing the Lung dataset results is only 73 points, which is very different from the original paper. Can you provide the hyperparameter values used for training on the Lung dataset?

Amshaker commented 1 year ago

Hi @6018203135, could you let me know how you did the evaluation? Using the provided data split, model architecture, and checkpoint, I verified that the evaluation accuracy is the same as in the paper (80.68%). It seems that something in your evaluation is incorrect.

Amshaker commented 1 year ago

I uploaded the checkpoint without compression here:

Could you please use this checkpoint for evaluation?

Here are the evaluation results using the uploaded checkpoint:

smanman commented 1 year ago

This is my file:

import argparse
import glob
import os

import numpy as np
import SimpleITK as sitk
from medpy.metric import binary

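def read_nii(path):
    itk_img = sitk.ReadImage(path)
    # SimpleITK reports spacing as (x, y, z); GetArrayFromImage returns a (z, y, x) array
    spacing = np.array(itk_img.GetSpacing())
    return sitk.GetArrayFromImage(itk_img), spacing


def dice(pred, label):
    # Dice = 2 * |A & B| / (|A| + |B|); two empty masks count as a perfect match
    if (pred.sum() + label.sum()) == 0:
        return 1
    else:
        return 2. * np.logical_and(pred, label).sum() / (pred.sum() + label.sum())


def process_label(label):
    # Task06_Lung has a single foreground class (label 1 = cancer)
    cancer = label == 1
    return cancer


# Earlier version of hd() that returned a (dice, hd95) tuple, kept commented out:
# def hd(pred, gt):
#     pred[pred > 0] = 1
#     gt[gt > 0] = 1
#     if pred.sum() > 0 and gt.sum() > 0:
#         dice = binary.dc(pred, gt)
#         hd95 = binary.hd95(pred, gt)
#         return dice, hd95
#     elif pred.sum() > 0 and gt.sum() == 0:
#         return 1, 0
#     else:
#         return 0, 0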

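def hd(pred, gt):
    # Alternative based on SimpleITK (average Hausdorff distance), left commented out:
    # labelPred = sitk.GetImageFromArray(lP.astype(np.float32), isVector=False)
    # labelTrue = sitk.GetImageFromArray(lT.astype(np.float32), isVector=False)
    # hausdorffcomputer = sitk.HausdorffDistanceImageFilter()
    # hausdorffcomputer.Execute(labelTrue > 0.5, labelPred > 0.5)
    # return hausdorffcomputer.GetAverageHausdorffDistance()
    if pred.sum() > 0 and gt.sum() > 0:
        # No voxelspacing is passed, so HD95 is measured in voxels, not mm
        hd95 = binary.hd95(pred, gt)
        return hd95
    else:
        # HD95 is undefined when either mask is empty; this script reports 0 in that case
        return 0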

def test(fold):
    path = '/home/###################/data/unetr_pp_Data/DATASET/unetr_pp_raw/unetr_pp_raw_data/Task06_Lung'
    label_path = '/home/###################/data/unetr_pp_Data/DATASET/unetr_pp_raw/unetr_pp_raw_data/Task06_Lung/labelsTs'
    pred_path = '/home/###################/data/unetr_pp_Data/DATASET/unetr_pp_raw/unetr_pp_raw_data/Task06_Lung/infersTs'

    # Predictions and labels are paired by sorted filename, so both folders must hold the same cases
    label_list = sorted(glob.glob(os.path.join(label_path, '*nii.gz')))
    infer_list = sorted(glob.glob(os.path.join(pred_path, '*nii.gz')))
    print("loading success...")

    Dice_cancer = []
    hd_cancer = []

    # Results go to <path>/inferTs/<fold>/dice_pre.txt
    # (os.path.join adds the separator that path + 'inferTs/' was missing)
    out_dir = os.path.join(path, 'inferTs', fold)
    os.makedirs(out_dir, exist_ok=True)

    with open(os.path.join(out_dir, 'dice_pre.txt'), 'w') as fw:
        for label_file, infer_file in zip(label_list, infer_list):
            label, spacing = read_nii(label_file)
            infer, spacing = read_nii(infer_file)
            label_cancer = process_label(label)
            infer_cancer = process_label(infer)

            Dice_cancer.append(dice(infer_cancer, label_cancer))
            hd_cancer.append(hd(infer_cancer, label_cancer))

            # Per-case results
            case_name = os.path.basename(infer_file)
            fw.write('*' * 20 + '\n')
            fw.write(case_name + '\n')
            fw.write('hd_cancer: {:.4f}\n'.format(hd_cancer[-1]))
            fw.write('Dice_cancer: {:.4f}\n'.format(Dice_cancer[-1]))
            fw.write('*' * 20 + '\n')

        # Mean over all cases
        fw.write('*' * 20 + '\n')
        fw.write('Dice_cancer: ' + str(np.mean(Dice_cancer)) + '\n')
        fw.write('HD_cancer: ' + str(np.mean(hd_cancer)) + '\n')
        fw.write('*' * 20 + '\n')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("fold", help="fold name")
    args = parser.parse_args()
    fold = args.fold
    test(fold)
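The hd() helper above calls medpy's binary.hd95 without a voxelspacing argument, so the HD95 it reports is in voxels rather than millimetres. Below is a minimal sketch, not the repository's code, of the same kind of symmetric surface-distance computation with explicit spacing. The names surface_voxels and hd95_mm are hypothetical, and both masks are assumed to be non-empty. Mind the axis order: SimpleITK's GetSpacing() returns (x, y, z), while GetArrayFromImage() gives a (z, y, x) array, so the spacing must be reversed before it is passed in.

```python
import numpy as np
from scipy import ndimage


def surface_voxels(mask):
    """Boolean mask of the object's border voxels (the mask minus its erosion)."""
    mask = np.asarray(mask, dtype=bool)
    structure = ndimage.generate_binary_structure(mask.ndim, 1)
    eroded = ndimage.binary_erosion(mask, structure=structure, iterations=1)
    return mask & ~eroded


def hd95_mm(pred, gt, spacing_zyx):
    """95th-percentile symmetric surface distance in the units of spacing_zyx.

    Hypothetical helper. spacing_zyx must follow the NumPy axis order (z, y, x),
    i.e. the reverse of SimpleITK's GetSpacing(). Both masks must be non-empty.
    """
    pred_border = surface_voxels(pred)
    gt_border = surface_voxels(gt)
    # Distance from every voxel to the nearest border voxel of the other mask
    dist_to_gt = ndimage.distance_transform_edt(~gt_border, sampling=spacing_zyx)
    dist_to_pred = ndimage.distance_transform_edt(~pred_border, sampling=spacing_zyx)
    distances = np.concatenate([dist_to_gt[pred_border], dist_to_pred[gt_border]])
    return float(np.percentile(distances, 95))


if __name__ == "__main__":
    pred = np.zeros((1, 1, 5), dtype=bool)
    gt = np.zeros((1, 1, 5), dtype=bool)
    pred[0, 0, 0] = True
    gt[0, 0, 3] = True  # 3 voxels apart along the last (x) axis
    print(hd95_mm(pred, gt, (1.0, 1.0, 1.0)))  # -> 3.0
    print(hd95_mm(pred, gt, (1.0, 1.0, 2.0)))  # -> 6.0 (2 mm voxels along x)
```

With real data, the spacing already returned by read_nii() could be used as hd95_mm(infer_cancer, label_cancer, spacing[::-1]).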

smanman commented 1 year ago

After training, I used the files in the generated validation_raw folder as the predictions (placed under infersTs) and the real labels under labelsTs to compute Dice (i.e., with the file above).
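Because the script pairs predictions and labels with zip() over two sorted file lists, copying a folder such as validation_raw into infersTs only gives meaningful numbers if both folders contain exactly the same case files; otherwise cases are silently mismatched or dropped. A minimal sketch of pairing by case name instead, so that a mismatch fails loudly. The helpers case_id and pair_by_case are hypothetical, and the .nii.gz naming is an assumption.

```python
import glob
import os


def case_id(path):
    """Hypothetical helper: '/x/lung_001.nii.gz' -> 'lung_001'."""
    name = os.path.basename(path)
    return name[: -len(".nii.gz")] if name.endswith(".nii.gz") else name


def pair_by_case(label_dir, pred_dir):
    """Return [(label_path, pred_path), ...] matched by case ID.

    Raises ValueError if any label lacks a prediction or vice versa,
    instead of silently misaligning the two sorted lists.
    """
    labels = {case_id(p): p for p in glob.glob(os.path.join(label_dir, "*.nii.gz"))}
    preds = {case_id(p): p for p in glob.glob(os.path.join(pred_dir, "*.nii.gz"))}
    missing = sorted(set(labels) - set(preds))
    extra = sorted(set(preds) - set(labels))
    if missing or extra:
        raise ValueError(
            f"missing predictions: {missing}; predictions without labels: {extra}"
        )
    return [(labels[c], preds[c]) for c in sorted(labels)]
```

In the evaluation loop, `for label_file, infer_file in pair_by_case(label_path, pred_path):` would replace the zip() over the two sorted lists.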

smanman commented 1 year ago

I did not run the evaluation directly after training.

Amshaker commented 1 year ago

@6018203135 You should use the generated files under validation_raw_postprocessed, not validation_raw.
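The validation_raw / validation_raw_postprocessed folder names come from the nnU-Net-style training pipeline. In nnU-Net v1, the postprocessed folder holds the validation predictions after a connected-component step that, when it improves validation Dice, keeps only the largest connected component of a class. Assuming the Lung postprocessing is of that kind, the sketch below shows why removing small spurious components can raise Dice noticeably. keep_largest_component is a hypothetical helper, not the repository's implementation, and the masks are synthetic.

```python
import numpy as np
from scipy import ndimage


def keep_largest_component(mask):
    """Hypothetical helper: keep only the largest connected foreground component."""
    mask = np.asarray(mask, dtype=bool)
    labeled, num = ndimage.label(mask)
    if num <= 1:
        return mask
    # Voxel count per component; index 0 of bincount is the background, so drop it
    sizes = np.bincount(labeled.ravel())[1:]
    largest = int(np.argmax(sizes)) + 1
    return labeled == largest


def dice(pred, label):
    """Same formula as the dice() helper in the evaluation script above."""
    denom = pred.sum() + label.sum()
    return 1.0 if denom == 0 else 2.0 * np.logical_and(pred, label).sum() / denom


if __name__ == "__main__":
    gt = np.zeros((10, 10, 10), dtype=bool)
    gt[2:6, 2:6, 2:6] = True         # 64-voxel synthetic "tumor"
    pred = gt.copy()
    pred[8:10, 8:10, 8:10] = True    # separate 8-voxel false-positive blob
    print(round(dice(pred, gt), 4))                          # -> 0.9412
    print(round(dice(keep_largest_component(pred), gt), 4))  # -> 1.0
```

Real postprocessing decisions are made per class on the validation folds, so the size of the gain on actual Lung cases depends on how many false-positive components the raw predictions contain.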

smanman commented 1 year ago

So that's it. Thank you very much!

Amshaker commented 1 year ago

@6018203135 You are most welcome!