med-air / 3DSAM-adapter

Holistic Adaptation of SAM from 2D to 3D for Promptable Medical Image Segmentation

Segmentation results and "train Vs train_auto" #27

Open carlotita22 opened 1 year ago

carlotita22 commented 1 year ago

Hi! How can I view the segmentation results? I am adding the following code inside the model_predict function (in test.py):

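import numpy as np  # at the top of test.py
import nibabel as nib

for i, prediction in enumerate(seg_pred):
    if hasattr(prediction, "detach"):  # torch tensor -> NumPy array on the CPU
        prediction = prediction.detach().cpu().numpy()
    # Build the image from the array (get_data() is gone in nibabel 5); identity affine = no scan geometry
    nifti_img = nib.Nifti1Image(prediction.astype(np.float32), affine=np.eye(4))
    output_filename = f'prediction_{i}.nii.gz'
    nib.save(nifti_img, output_filename)  # one file per prediction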

However, I am not sure whether seg_pred corresponds to the final mask predicted by the network. I suspect I'm doing something wrong, because I get the following error. How should I do this?

[22:12:25.424] Namespace(data='myocardium', snapshot_path='path/to/snapshot/myocardium', data_prefix='path/to/data folder/', rand_crop_size=(128, 128, 128), device='cuda:0', num_prompts=1, batch_size=1, num_classes=2, num_worker=6, checkpoint='last', tolerance=5)
Traceback (most recent call last):
  File "/mnt/workspace/cgrivera/3DSAM-adapter/3DSAM-adapter/3DSAM-adapter/test.py", line 301, in <module>
    main()
  File "/mnt/workspace/cgrivera/3DSAM-adapter/3DSAM-adapter/3DSAM-adapter/test.py", line 118, in main
    torch.load(os.path.join(args.snapshot_path, file), map_location='cpu')["feature_dict"][i], strict=True)
KeyError: 'feature_dict'

Thanks in advance, Regards!
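This thread does not confirm whether seg_pred is already a binary mask. If it turns out to hold per-class scores (the log above shows num_classes=2) or a foreground probability map, it would need to be reduced to integer labels before saving it as a mask. A minimal sketch, assuming one of those two layouts; the helper name `to_label_mask` and the 0.5 threshold are hypothetical and not part of 3DSAM-adapter:

```python
import numpy as np


def to_label_mask(pred, num_classes=2):
    """Hypothetical helper: turn a network output into an integer label volume.

    Assumes either a (num_classes, D, H, W) score array, reduced by argmax over
    the class axis, or a (D, H, W) foreground probability map, thresholded at
    0.5. Which layout seg_pred actually has must be checked in test.py.
    """
    pred = np.asarray(pred)
    if pred.ndim == 4 and pred.shape[0] == num_classes:
        # Pick the highest-scoring class per voxel (ties go to the lower index).
        return np.argmax(pred, axis=0).astype(np.uint8)
    if pred.ndim == 3:
        # Voxels strictly above the threshold become foreground (label 1).
        return (pred > 0.5).astype(np.uint8)
    raise ValueError(f"unexpected prediction shape {pred.shape}")
```

The resulting uint8 array can be passed to nib.Nifti1Image in place of the raw prediction.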

carlotita22 commented 12 months ago

up U.u

peterant330 commented 12 months ago

Hi,

Judging from your error message, the error is unrelated to the code you added. It occurs because the checkpoint you saved during training has a different format from the one test.py expects. It looks like you trained with train_auto.py but are testing with test.py: the checkpoint produced by train_auto.py does not contain the prompt encoder, while inference in test.py needs it.
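To confirm this diagnosis before re-training, one can check which top-level keys a saved checkpoint contains. The sketch below assumes the checkpoint loads as a plain dict, as the `torch.load(...)["feature_dict"]` call in the traceback suggests; the helper names are hypothetical, and `"feature_dict"` is the only key known from the traceback to be required by test.py:

```python
from collections.abc import Mapping

# Top-level key that test.py reads, per the traceback above (extend if needed).
REQUIRED_KEYS = ("feature_dict",)


def missing_checkpoint_keys(checkpoint, required=REQUIRED_KEYS):
    """Return the required top-level keys absent from a loaded checkpoint."""
    if not isinstance(checkpoint, Mapping):
        raise TypeError("expected the checkpoint to be a dict-like object")
    return [key for key in required if key not in checkpoint]


def describe_checkpoint(checkpoint):
    """Summarize whether a checkpoint looks usable by test.py."""
    missing = missing_checkpoint_keys(checkpoint)
    if missing:
        return f"missing {missing}; available keys: {sorted(checkpoint)}"
    return "ok"
```

Usage, with the file name left as a placeholder: `describe_checkpoint(torch.load(os.path.join(args.snapshot_path, file), map_location='cpu'))`. A "missing ['feature_dict']" result would be consistent with a checkpoint written by train_auto.py rather than train.py.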