junyuchen245 / TransMorph_Transformer_for_Medical_Image_Registration

TransMorph: Transformer for Unsupervised Medical Image Registration (PyTorch)
MIT License
437 stars 73 forks source link

about nii convert to pkl #31

Closed DongChunL closed 10 months ago

DongChunL commented 2 years ago

I converted two NII medical images into a PKL file, but the loss reached tens of thousands. I suspect something went wrong during the conversion. Does the author have experience with NII-to-PKL conversion? Could you tell me how to convert the images for use with this code? Thanks.

junyuchen245 commented 1 year ago

Hi @DongChunL ,

A PyTorch dataset script that works directly on nii images will be provided at a later time.

Thanks, Junyu

TahaRazzaq commented 1 year ago

Hi @junyuchen245, Any update regarding the PyTorch dataset script for nii images?

Thanks.

junyuchen245 commented 11 months ago

Hi @TahaRazzaq ,

We have a script for loading the IXI NII dataset (details below). However, we're in the midst of testing our methods on multi-modal images and haven't decided to release the preprocessed dataset at this time. We anticipate making the preprocessed dataset available after the publication of our upcoming paper.

-Junyu

class IXIDataset(Dataset):
    """Dataset pairing each masked subject image with a fixed, masked atlas.

    Expected on-disk layout (all paths relative to ``data_path``)::

        <phase>/<modality>/<name>_T1.nii.gz          subject image
        <phase>/masks/<name>_T1_mask.nii.gz          subject brain mask
        <phase>/labels/<name>_T1_seg.nii.gz          subject segmentation
        atlas/atlas_img.nii.gz                       atlas image
        atlas/atlas_mask.nii.gz                      atlas brain mask
        atlas/atlas_seg.nii.gz                       atlas segmentation

    Each item is a dict with the keys ``f_img``, ``f_mod``, ``m_img``,
    ``m_mod``, ``f_seg`` and ``m_seg``; outside the ``'train'`` phase it also
    carries ``f_basename`` and ``m_basename``.

    Args:
        phase: Sub-directory of ``data_path`` to read from. When it equals
            ``'train'`` the segmentations are replaced by all-zero dummies.
        modality: Name of the image sub-directory inside the phase directory.
        transform: Optional callable applied to the sample dict before it is
            returned.
        data_path: Root directory of the dataset.
    """

    _NII_SUFFIX = '.nii.gz'

    def __init__(self, phase, modality, transform=None,
                 data_path='path/to/dataset/'):
        self.modality = modality
        self.transform = transform
        self.data_path = data_path
        self.phase = phase

        phase_dir = os.path.join(self.data_path, self.phase)
        img_paths = sorted(
            glob(os.path.join(phase_dir, self.modality, '*_T1.nii.gz')))
        # Each entry is an (image, mask, segmentation) path triple.
        self.imgs = [self._companion_paths(p, phase_dir) for p in img_paths]

        # Atlas file names are spelled out instead of being derived with
        # str.replace() on the full path, so an 'img' substring in the
        # dataset root cannot be rewritten by accident.
        atlas_dir = os.path.join(self.data_path, 'atlas')
        atlas_img = nib.load(
            os.path.join(atlas_dir, 'atlas_img.nii.gz')).get_fdata()
        self.atlas_mask = nib.load(
            os.path.join(atlas_dir, 'atlas_mask.nii.gz')).get_fdata()
        self.atlas = atlas_img * self.atlas_mask  # apply the atlas brain mask
        self.atlas_seg = nib.load(
            os.path.join(atlas_dir, 'atlas_seg.nii.gz')).get_fdata()

    @staticmethod
    def _companion_paths(img_path, phase_dir):
        """Return ``(image, mask, segmentation)`` paths for one subject image.

        The mask and segmentation paths are assembled from ``phase_dir`` and
        the image's file name only, so a modality string that also occurs in
        the file name (e.g. ``T1`` in ``*_T1.nii.gz``) cannot corrupt them.
        """
        suffix = IXIDataset._NII_SUFFIX
        name = os.path.basename(img_path)
        stem = name[:-len(suffix)] if name.endswith(suffix) else name
        return (
            img_path,
            os.path.join(phase_dir, 'masks', stem + '_mask' + suffix),
            os.path.join(phase_dir, 'labels', stem + '_seg' + suffix),
        )

    def __len__(self):
        """Return the number of subject images found for this phase."""
        return len(self.imgs)

    def __getitem__(self, j):
        """Load subject ``j`` and pair it with a copy of the atlas.

        Returns the sample dict, passed through ``self.transform`` when one
        was given.
        """
        img_path, mask_path, seg_path = self.imgs[j]
        sample = {}

        fixed = nib.load(img_path).get_fdata()
        fixed *= nib.load(mask_path).get_fdata()  # apply the brain mask
        sample['f_img'] = fixed
        sample['f_mod'] = 'MR'

        # Copy so that a transform mutating the sample cannot alter the
        # atlas shared by every item.
        sample['m_img'] = np.copy(self.atlas)
        sample['m_mod'] = 'MR'

        if self.phase == 'train':
            # The random fixed/moving swap that used to sit here is disabled.
            # Segmentations are not loaded in training; provide zero dummies.
            sample['f_seg'] = np.zeros_like(sample['f_img'])
            sample['m_seg'] = np.zeros_like(sample['m_img'])
        else:
            sample['f_seg'] = nib.load(seg_path).get_fdata()
            sample['m_seg'] = np.copy(self.atlas_seg)
            sample['f_basename'] = os.path.basename(img_path)
            # NOTE(review): the moving image is the atlas, yet 'm_basename'
            # carries the subject's file name. Kept as-is because callers may
            # rely on it (e.g. for naming outputs) - confirm the intent.
            sample['m_basename'] = os.path.basename(img_path)

        return self.apply_transform(sample)

    def apply_transform(self, sample):
        """Return ``self.transform(sample)``, or ``sample`` if no transform."""
        if self.transform:
            return self.transform(sample)
        return sample