med-air / 3DSAM-adapter

Holistic Adaptation of SAM from 2D to 3D for Promptable Medical Image Segmentation

Segmentation results and "train Vs train_auto" #27

Open carlotita22 opened 1 year ago

carlotita22 commented 1 year ago

Hi! How can I view the segmentation results? I am adding the following code inside the model_predict function (in test.py):

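```python
import nibabel as nib
import numpy as np

for i, prediction in enumerate(seg_pred):
    # If prediction is a CUDA torch tensor, convert it first: prediction.cpu().numpy()
    mask = np.asarray(prediction, dtype=np.float32)
    # Build the image directly from the array (get_data() is removed in recent nibabel)
    nifti_img = nib.Nifti1Image(mask, affine=None)
    output_filename = f'prediction_{i}.nii.gz'
    nib.save(nifti_img, output_filename)  # one file per prediction
```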

However, I am not sure whether seg_pred corresponds to the final mask predicted by the network, and I suspect I am doing something wrong, because I get the following error. How should I do this?

```
[22:12:25.424] Namespace(data='myocardium', snapshot_path='path/to/snapshot/myocardium', data_prefix='path/to/data folder/', rand_crop_size=(128, 128, 128), device='cuda:0', num_prompts=1, batch_size=1, num_classes=2, num_worker=6, checkpoint='last', tolerance=5)
Traceback (most recent call last):
  File "/mnt/workspace/cgrivera/3DSAM-adapter/3DSAM-adapter/3DSAM-adapter/test.py", line 301, in <module>
    main()
  File "/mnt/workspace/cgrivera/3DSAM-adapter/3DSAM-adapter/3DSAM-adapter/test.py", line 118, in main
    torch.load(os.path.join(args.snapshot_path, file), map_location='cpu')["feature_dict"][i], strict=True)
KeyError: 'feature_dict'
```

Thanks in advance, Regards!

carlotita22 commented 1 year ago

up U.u

peterant330 commented 1 year ago

Hi,

Judging from your error message, the error has nothing to do with the code you added. It occurs because the checkpoint you saved during training has a different format from the one expected at test time. It looks like you trained with train_auto.py but are testing with test.py, so your checkpoint does not contain the prompt encoder, which inference with test.py requires.
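One way to confirm this diagnosis is to inspect the top-level keys of the saved checkpoint before test.py indexes into it. The sketch below is a minimal, hypothetical helper (check_checkpoint_keys is not part of the repository). It assumes the checkpoint loads as a plain dict; the only key name taken from the traceback is "feature_dict", and "model_state" in the demo is a placeholder.

```python
from typing import Any, Iterable, Mapping


def check_checkpoint_keys(
    checkpoint: Mapping[str, Any], required: Iterable[str], path: str = "<checkpoint>"
) -> None:
    """Raise a descriptive KeyError if any required top-level key is missing.

    Hypothetical helper: call it on the dict returned by torch.load(...)
    before indexing into it, e.g. before ckpt["feature_dict"] in test.py.
    """
    missing = [key for key in required if key not in checkpoint]
    if missing:
        raise KeyError(
            f"{path} is missing {missing}; available keys: {sorted(checkpoint)}. "
            "Was it saved by a different training script (train_auto.py vs train.py)?"
        )


if __name__ == "__main__":
    # Hypothetical usage in test.py:
    #   ckpt = torch.load(os.path.join(args.snapshot_path, file), map_location='cpu')
    #   check_checkpoint_keys(ckpt, ["feature_dict"], path=file)
    fake_ckpt = {"model_state": {}}  # placeholder checkpoint lacking "feature_dict"
    try:
        check_checkpoint_keys(fake_ckpt, ["feature_dict"], path="last checkpoint")
    except KeyError as err:
        print(err)
```

Listing the available keys in the error message shows directly which parts of the model the training script saved, which is usually enough to tell whether the checkpoint came from train.py or train_auto.py.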