MrGiovanni / ModelsGenesis

[MICCAI 2019 Young Scientist Award] [MEDIA 2020 Best Paper Award] Models Genesis

Major refactoring and bugfixes to allow data preprocessed by nnUNet #49

Closed · joeranbosma closed this 2 years ago

joeranbosma commented 3 years ago

Refactoring:

Bugfixes for data preprocessed by nnUNet:

Cleanup:

joeranbosma commented 3 years ago

Hi @MrGiovanni,

Using the changes proposed in this pull request, it is possible to set up a PyTorch DataLoader that reads data from nnUNet's preprocessed folder. This lets you re-use the same preprocessing scripts as for your regular nnUNet runs.
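For context, nnUNet's preprocessed `.npz` files store a single `data` array with the image channel(s) first and the segmentation map as the last channel, which is why the loader below strips the final channel. A small standalone illustration (using an in-memory buffer with dummy data, not a real nnUNet case):

```python
import io

import numpy as np

# Mimic an nnUNet preprocessed case: 1 image channel + 1 segmentation channel
buf = io.BytesIO()
np.savez(buf, data=np.zeros((2, 20, 160, 160), dtype=np.float32))
buf.seek(0)

loaded = np.load(buf)["data"]
image, seg = loaded[:-1], loaded[-1]  # same split as `scans[:-1]` below
print(image.shape, seg.shape)
```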

DataLoader:

import os

import numpy as np
from torch.utils.data import Dataset

class CustomDatasetLoader(Dataset):
    def __init__(self, subject_list, in_dir_scans, config, target_shape=(3, 20, 160, 160)):
        self.subject_list = subject_list
        self.in_dir_scans = in_dir_scans
        self.config       = config
        self.target_shape = target_shape

    def __len__(self):
        return len(self.subject_list)

    def __getitem__(self, idx):
        subject_id = self.subject_list[idx]
        path = os.path.join(self.in_dir_scans, f"{subject_id}.npz")
        scans = np.load(path)['data']
        # drop the segmentation map, which nnUNet stores as the last channel
        scans = scans[:-1]
        # center-crop/pad to a fixed shape so samples can be collated into batches
        scans = compatibility_nnUNet(scans, target_size=self.target_shape)

        # rearrange axes for the Models Genesis data augmentation, then restore them
        scans = np.moveaxis(scans, -3, -1)
        x, y = generate_single_pair(scans, self.config)
        x = np.moveaxis(x, -1, -3)
        y = np.moveaxis(y, -1, -3)
        # copy to drop the negative/odd strides left by moveaxis, which PyTorch rejects
        return x.copy(), y.copy()

import SimpleITK as sitk

def compatibility_nnUNet(lbl, target_size=(20, 160, 160)):
    """Center-crop or pad `lbl` to `target_size` (accepts a SimpleITK image or numpy array)."""
    if isinstance(lbl, sitk.Image):
        lbl = sitk.GetArrayFromImage(lbl)

    if lbl.shape != target_size:
        lbl = resize_image_with_crop_or_pad(lbl, target_size)
    return lbl

# Resize/Cropping Image with Padding [Ref:DLTK]
def resize_image_with_crop_or_pad(image, img_size=(64, 64, 64), **kwargs):
    assert isinstance(image, (np.ndarray, np.generic))
    assert (image.ndim - 1 == len(img_size) or image.ndim == len(img_size)), \
        "Example size doesn't fit image size"

    # Get the image dimensionality
    rank = len(img_size)

    # Create placeholders for the new shape
    from_indices = [[0, image.shape[dim]] for dim in range(rank)]
    to_padding = [[0, 0] for dim in range(rank)]

    slicer = [slice(None)] * rank

    # For each dimension, determine whether it should be cropped or padded
    for i in range(rank):
        if image.shape[i] < img_size[i]:
            to_padding[i][0] = (img_size[i] - image.shape[i]) // 2
            to_padding[i][1] = img_size[i] - image.shape[i] - to_padding[i][0]
        else:
            from_indices[i][0] = int(np.floor((image.shape[i] - img_size[i]) / 2.))
            from_indices[i][1] = from_indices[i][0] + img_size[i]

        # Create slicer object to crop or leave each dimension
        slicer[i] = slice(from_indices[i][0], from_indices[i][1])

    # Pad the cropped image to extend the missing dimension
    return np.pad(image[tuple(slicer)], to_padding, **kwargs)
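A quick standalone sanity check of the same center-crop / symmetric-pad behavior (`crop_or_pad` here is a compact restatement for illustration, not part of the PR):

```python
import numpy as np

def crop_or_pad(image, target):
    # Same logic as resize_image_with_crop_or_pad: center-crop dimensions that
    # are too large, symmetrically zero-pad dimensions that are too small.
    slicer, padding = [], []
    for have, want in zip(image.shape, target):
        if have < want:
            before = (want - have) // 2
            padding.append([before, want - have - before])
            slicer.append(slice(None))
        else:
            start = (have - want) // 2
            padding.append([0, 0])
            slicer.append(slice(start, start + want))
    return np.pad(image[tuple(slicer)], padding)

up   = crop_or_pad(np.ones((2, 3, 3)), (2, 5, 5))  # padded up
down = crop_or_pad(np.ones((2, 5, 5)), (2, 3, 3))  # cropped down
print(up.shape, down.shape)
```

Note that the padded border is filled with zeros while the original values stay centered.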

from torch.utils.data import DataLoader

train_data = CustomDatasetLoader(subject_list_train, in_dir_scans=args.in_dir_scans, config=config)
valid_data = CustomDatasetLoader(subject_list_valid, in_dir_scans=args.in_dir_scans, config=config)

train_dataloader = DataLoader(train_data, batch_size=config.batch_size, shuffle=True, num_workers=6)
valid_dataloader = DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=6)

MrGiovanni commented 2 years ago

Hi @joeranbosma

Thank you so much for your revision! It looks great and I've merged the commit.

Best,

Zongwei