Project-MONAI / research-contributions

Implementations of recent research prototypes/demonstrations using MONAI.
https://monai.io/
Apache License 2.0

Pretrained model on BTCV doesn't reproduce the mean dice score #79

Open zackchen-lb opened 2 years ago

zackchen-lb commented 2 years ago

Hi there,

The provided model weights for BTCV (Swin UNETR/Base) don't reproduce the reported mean Dice score on the validation set. I only get a mean Dice of around 0.16~0.2, which is far below the reported ~0.8.

Basically, I used the Google Colab code as follows:

```python
import numpy as np
import torch

from monai.inferers import sliding_window_inference
from utils.utils import dice, resample_3d  # helper functions from the BTCV repo (as in test.py)

model.eval()  # ensure dropout/norm layers are in inference mode
with torch.no_grad():
    dice_list_case = []
    for i, batch in enumerate(val_loader):
        val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda())
        # original_affine = batch['label_meta_dict']['affine'][0].numpy()
        _, _, h, w, d = val_labels.shape
        target_shape = (h, w, d)
        # img_name = batch['image_meta_dict']['filename_or_obj'][0].split('/')[-1]
        # print("Inference on case {}".format(img_name))
        val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model, overlap=0.5, mode="gaussian")
        val_outputs = torch.softmax(val_outputs, 1).cpu().numpy()
        val_outputs = np.argmax(val_outputs, axis=1).astype(np.uint8)[0]
        val_labels = val_labels.cpu().numpy()[0, 0, :, :, :]
        # Resample the prediction back to the label's original spatial shape.
        val_outputs = resample_3d(val_outputs, target_shape)
        dice_list_sub = []
        # Per-organ Dice over the 13 BTCV foreground classes
        # (inner loop variable renamed to avoid shadowing the outer `i`).
        for organ_idx in range(1, 14):
            organ_dice = dice(val_outputs == organ_idx, val_labels == organ_idx)
            dice_list_sub.append(organ_dice)
        mean_dice = np.mean(dice_list_sub)
        print("Mean Organ Dice: {}".format(mean_dice))
        dice_list_case.append(mean_dice)
        # nib.save(nib.Nifti1Image(val_outputs.astype(np.uint8), original_affine),
        #             os.path.join(output_directory, img_name))

    print("Overall Mean Dice: {}".format(np.mean(dice_list_case)))
```

The model was loaded from the pretrained weights you provided (see the table below), and the data transforms and data loader are set up exactly as provided:

| Name | Dice (overlap=0.7) | Dice (overlap=0.5) | Feature Size | # params (M) | Self-Supervised Pre-trained | Download |
| --- | --- | --- | --- | --- | --- | --- |
| Swin UNETR/Base | 82.25 | 81.86 | 48 | 62.1 | Yes | model |
| Swin UNETR/Small | 79.79 | 79.34 | 24 | 15.7 | No | model |
| Swin UNETR/Tiny | 72.05 | 70.35 | 12 | 4.0 | No | model |
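For completeness, a minimal sketch of how the Base checkpoint can be loaded (the constructor arguments mirror the Base row above; the checkpoint path and the `state_dict` key are assumptions about the file layout, so adjust them to your setup):

```python
import torch
from monai.networks.nets import SwinUNETR

# Base configuration: feature_size=48, 14 output classes (13 organs + background).
model = SwinUNETR(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=14,
    feature_size=48,
    use_checkpoint=True,
)

# "path/to/swin_unetr_base.pt" is a placeholder; the "state_dict" key is an
# assumption about the checkpoint layout -- inspect your file if loading fails.
ckpt = torch.load("path/to/swin_unetr_base.pt", map_location="cpu")
model.load_state_dict(ckpt["state_dict"])
model.cuda().eval()
```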

I wonder if I actually missed anything here. I'd appreciate your feedback! Thanks.

tangy5 commented 2 years ago

Hi @ZEKAICHEN, thanks for raising the issue. We've double-checked and re-run test.py. If the code from https://github.com/Project-MONAI/research-contributions/tree/main/SwinUNETR/BTCV is used with the downloaded Swin UNETR/Base model, it should give the Dice scores below at overlap=0.5:

```
Inference on case img0035.nii.gz
Mean Organ Dice: 0.7715836852979835
Inference on case img0036.nii.gz
Mean Organ Dice: 0.8377579306350628
Inference on case img0037.nii.gz
Mean Organ Dice: 0.8386162560902106
Inference on case img0038.nii.gz
Mean Organ Dice: 0.7809781930534572
Inference on case img0039.nii.gz
Mean Organ Dice: 0.8375578949580794
Inference on case img0040.nii.gz
Mean Organ Dice: 0.8275152177091785
Overall Mean Dice: 0.815668196290662
```
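Note that the overall score is simply the unweighted mean of the per-case mean Dice values above, which can be verified directly:

```python
import numpy as np

# Per-case mean Dice values from the log above (rounded).
per_case = [0.77158369, 0.83775793, 0.83861626, 0.78097819, 0.83755789, 0.82751522]
print(np.mean(per_case))  # ~0.8157, matching the reported Overall Mean Dice
```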

Could you provide more details of your testing implementation? We can help dig deeper into the problem. Thanks!
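As a starting point, a score as low as 0.16~0.2 often means the weights silently failed to load (leaving parts of the network random). A minimal sanity-check sketch, with a placeholder checkpoint path:

```python
import torch

# Placeholder path -- substitute the downloaded checkpoint file.
ckpt = torch.load("path/to/swin_unetr_base.pt", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)  # some checkpoints nest under "state_dict"

# strict=False returns the mismatches instead of raising, so they can be inspected.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("missing keys:", missing)        # non-empty => parts of the model stayed random
print("unexpected keys:", unexpected)  # non-empty => key-name mismatch (e.g. a "module." prefix)
```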