SuperMedIntel / Medical-SAM-Adapter

Adapting Segment Anything Model for Medical Image Segmentation
GNU General Public License v3.0

Clarity on brain tumor results #109

Open 25benjaminli opened 6 months ago

25benjaminli commented 6 months ago

Hello, thanks for the pre-print and for releasing the code. I am working on brain tumor segmentation with the BraTS dataset, which, as you know, is a semantic (multi-class) task. I have two questions I'd appreciate some clarity on.

  1. What is actually being evaluated in the pre-print? The objectives appear to be binary: dataset/brat.py seems to perform binary segmentation (the segmentation mask is binarized). I know Segment Anything is natively a binary segmentation algorithm, but then why are there class-wise results for the BTCV organ dataset and not for BraTS? @LJQCN101 supposedly integrated multimask output for semantic segmentation, but the results in the pre-print don't seem to reflect this. Code segment from dataset/brat.py attached below.

    def __getitem__(self, index):
        # if self.mode == 'Training':
        #     point_label = random.randint(0, 1)
        #     inout = random.randint(0, 1)
        # else:
        #     inout = 1
        #     point_label = 1
        point_label = 1
        label = 4   # the class to be segmented
    
        """Get the images"""
        name = self.name_list[index]
        img,mask = self.load_all_levels(name)
    
        mask[mask!=label] = 0
        mask[mask==label] = 1
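For context on what I mean by class-wise evaluation: BraTS masks carry labels 1 (necrotic/non-enhancing core), 2 (edema), and 4 (enhancing tumor), so a semantic evaluation would binarize the mask once per label rather than fixing `label = 4`. A minimal sketch of what I'd expect (my own code, not from the repo):

```python
import numpy as np

def per_class_masks(mask, labels=(1, 2, 4)):
    """Return {label: binary mask} for each BraTS tumor sub-region.

    Hypothetical helper: binarizes the mask per label so Dice (or any
    binary metric) can be reported class-wise, instead of evaluating
    only the single hard-coded label as brat.py does.
    """
    return {lbl: (mask == lbl).astype(np.uint8) for lbl in labels}
```

So instead of one Dice score for label 4, you would get one score per tumor sub-region.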
  2. Does Medical SAM by default use channel-wise segmentation with multiple modalities (e.g., for brain tumors: FLAIR, T1, T1ce, T2), or does it repeat the same modality across multiple channels? I ask because the dataset appears to use only the first modality. Code segment from dataset/brat.py attached below.

    def load_all_levels(self, path):
        import nibabel as nib
        data_dir = os.path.join(self.data_path)
        levels = ['t1', 'flair', 't2', 't1ce']
        raw_image = [nib.load(os.path.join(data_dir, path, path + '_' + level + '.nii.gz')).get_fdata() for level in levels]
        raw_seg = nib.load(os.path.join(data_dir, path, path + '_seg.nii.gz')).get_fdata()

        return raw_image[0], raw_seg
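For illustration, if all four modalities were intended to be used as input channels, I would have expected them to be stacked rather than returning only `raw_image[0]` (the t1 volume). A sketch of my assumption (not from the repo):

```python
import numpy as np

def stack_modalities(volumes):
    """Stack a list of (H, W, D) modality volumes into a (C, H, W, D) array.

    Hypothetical alternative to `return raw_image[0]`: keeps all four
    BraTS modalities as separate channels instead of discarding three.
    """
    return np.stack(volumes, axis=0)
```

This is what I mean by "channel-wise" use of the modalities, as opposed to repeating one modality.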

I am tagging the authors of the pre-print and the person who implemented multimask output below. Thanks again! @WuJunde @LJQCN101

EDIT: when trying to run with the default brat.py dataset configuration, I ran into the following error: `Given groups=1, weight of size [768, 3, 16, 16], expected input[1, 1024, 1024, 155] to have 3 channels, but got 1024 channels instead`
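For reference on where this error comes from: SAM's ViT patch embedding is a `Conv2d` with weight `[768, 3, 16, 16]`, i.e. it expects `(N, 3, H, W)` input, so passing the raw `(H, W, D)` BraTS volume as `(1, 1024, 1024, 155)` makes PyTorch read 1024 as the channel count. A minimal sketch of the mismatch and a possible workaround (my own code; the slice index is arbitrary):

```python
import torch

# Same shape as SAM ViT-B's patch embedding: in_channels=3, 16x16 patches.
patch_embed = torch.nn.Conv2d(3, 768, kernel_size=16, stride=16)

volume = torch.zeros(1024, 1024, 155)      # (H, W, D) BraTS-style volume
# Passing volume.unsqueeze(0) directly -> (1, 1024, 1024, 155) raises the
# "expected 3 channels, got 1024" error above.

slice_2d = volume[..., 77]                 # take one axial slice (H, W)
x = slice_2d.unsqueeze(0).repeat(3, 1, 1)  # repeat the slice to 3 channels
x = x.unsqueeze(0)                         # (1, 3, 1024, 1024)
out = patch_embed(x)                       # (1, 768, 64, 64)
```

So it seems the dataloader needs to hand SAM 2D slices with 3 channels (or the patch embedding needs to be adapted), not the whole volume.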