FirasGit / medicaldiffusion

Medical Diffusion: This repository contains the code for our paper "Medical Diffusion: Denoising Diffusion Probabilistic Models for 3D Medical Image Synthesis"

Persistent "White Dot" Artifacts in VQGAN Reconstruction #20

Status: Open · junzhin opened this issue 1 month ago

junzhin commented 1 month ago

Problem Description

During the training of the VQGAN model on CT data, persistent "white dot" artifacts appear in the reconstruction results, even after extensive training. The white dots are present in both the training and validation sets, and although their frequency decreases over time, they are still clearly visible after 300K iterations. We have tried adjusting several training parameters, but the issue persists.

Request for Input

Could you take a look at this issue? We're wondering if this could be related to the data preprocessing pipeline, particularly the resampling step, or if there might be another factor we're overlooking.

[Screenshots: VQGAN reconstruction samples from the training and validation sets, showing the white-dot artifacts]

import torchio as tio
from torch.utils.data import Dataset


class DEFAULTDataset(Dataset):
    def __init__(self, root_dir, size=None, mode='train', **others):
        super().__init__()
        print('size: ', size)
        self.size = size
        self.target_shape = size

        if others["resize"] is False:
            self.preprocessing = tio.Compose([
                tio.RescaleIntensity(out_min_max=(-1, 1)),
                tio.CropOrPad(target_shape=size),
            ])
        else:
            self.preprocessing = tio.Compose([
                tio.Lambda(self.resample_to_target_shape),
                tio.RescaleIntensity(out_min_max=(-1, 1)),
                tio.CropOrPad(target_shape=size),
            ])

        self.mode = mode
        # TRAIN_TRANSFORMS is defined elsewhere in our code
        self.transforms = TRAIN_TRANSFORMS if self.mode == 'train' else None

        root_dir = list(root_dir)
        print('root_dir: ', root_dir)
        print('root_dir type: ', type(root_dir))
        # get_data_files is defined elsewhere in the class
        self.file_paths = self.get_data_files(root_dir)

    def resample_to_target_shape(self, image):
        # tio.Lambda passes the raw 4D tensor, so wrap it in a ScalarImage
        # first; the default identity affine implies 1 mm isotropic spacing.
        if not isinstance(image, tio.ScalarImage):
            image = tio.ScalarImage(tensor=image)
        original_shape = image.spatial_shape
        original_voxel_size = image.spacing

        # Compute the voxel size that maps the original shape onto the target shape
        target_voxel_size = [
            (orig_size * voxel_size) / target_size
            for orig_size, voxel_size, target_size
            in zip(original_shape, original_voxel_size, self.target_shape)
        ]

        # Resample to the computed voxel size and return the data tensor,
        # since tio.Lambda expects a tensor back
        resample_transform = tio.Resample(tuple(target_voxel_size))
        return resample_transform(image).data

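For what it's worth, here is a minimal sanity check of that spacing computation on a synthetic volume (the shape and names below are made up for illustration). Because tio.Resample rounds the output dimensions, the result can be off by a voxel, which the trailing CropOrPad absorbs:

import torch
import torchio as tio

target_shape = (64, 64, 64)
# Synthetic 4D volume (channel first); default affine means 1 mm spacing
vol = tio.ScalarImage(tensor=torch.rand(1, 130, 170, 97))

# Same arithmetic as resample_to_target_shape above
spacing = tuple(s * v / t for s, v, t
                in zip(vol.spatial_shape, vol.spacing, target_shape))
resampled = tio.Resample(spacing)(vol)
print(resampled.spatial_shape)  # close to (64, 64, 64); rounding may shift it by one voxel
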
We are using the same config described in the paper. [Screenshot: training configuration]

Any guidance or suggestions would be appreciated. Thanks!

lukas-folle-snkeos commented 3 days ago

Looks to me like a clipping error in the visualization code. You could check the min/max of the image and rescale it into a range the plotting function can handle.
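A minimal sketch of that check (assuming the reconstructions are roughly in [-1, 1], matching the RescaleIntensity above; model_output is a placeholder for your reconstruction tensor):

import matplotlib.pyplot as plt

recon = model_output[0, 0].detach()  # hypothetical reconstruction, shape (D, H, W)
print('min/max:', recon.min().item(), recon.max().item())

# Map [-1, 1] -> [0, 1] and clamp outliers so imshow does not clip or wrap them
mid_slice = ((recon[recon.shape[0] // 2] + 1) / 2).clamp(0, 1)
plt.imshow(mid_slice.cpu().numpy(), cmap='gray', vmin=0, vmax=1)
plt.show()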