ciroraggio / AugmentedDataLoader

Medical image augmentation tool that can be integrated with Pytorch & MONAI.
MIT License
1 star, 0 forks

ImageToImageDataset #32

Closed jafarpof closed 9 months ago

jafarpof commented 9 months ago

Hi,

I am encountering an issue with the ImageToImageDataset class in the MONAI library, specifically regarding the shape of loaded images.

  1. Modification of the __len__ method: First, I noticed that I had to modify the __len__ method to:

         def __len__(self) -> int:
             return len(self.first_type_image_files)

  2. Issue with image loading: The main concern is the loading of images in the __getitem__ method. The class uses the LoadImage transform to load the images, and I added print statements to check the shape of the loaded images. However, instead of the 3D tensor shape I expected for medical images (such as MRI or CT scans), the loaded images have a 2D shape. Here is the output I observed:

    Loaded MR Image Shape: torch.Size([244, 204])
    Loaded CT Image Shape: torch.Size([244, 204])
    Loaded Mask Image Shape: torch.Size([244, 204])

I am trying to load and process 3D medical images, but the dataset seems to be returning 2D slices or incorrectly shaped tensors.

Could you please advise on how to resolve this issue? Is there a specific parameter or method I should use to ensure the images are correctly loaded as 3D tensors?

Thank you for your assistance.
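One way to narrow this down is to check the shape stored in the file itself, independently of any loader. The sketch below is not part of MONAI or AugmentedDataLoader; `nifti_shape` is a hypothetical helper that reads only the NIfTI-1 header with the Python standard library, assuming the files are NIfTI-1 (`.nii` or `.nii.gz`). If it reports three dimensions while the dataset prints two, the reduction happens somewhere in the loading code rather than in the data.

```python
import gzip
import struct


def nifti_shape(path):
    """Return the array shape recorded in a NIfTI-1 header (.nii or .nii.gz).

    Hypothetical helper for debugging; it does not read the voxel data.
    """
    opener = gzip.open if str(path).endswith(".gz") else open
    with opener(path, "rb") as fh:
        header = fh.read(348)  # a NIfTI-1 header is 348 bytes long
    if len(header) < 348:
        raise ValueError("file is too short to contain a NIfTI-1 header")

    # The first field, sizeof_hdr, must equal 348; use it to detect byte order.
    for order in ("<", ">"):
        if struct.unpack(order + "i", header[:4])[0] == 348:
            break
    else:
        raise ValueError("not a NIfTI-1 header (sizeof_hdr != 348)")

    # dim[8] is stored as eight int16 values at byte offset 40:
    # dim[0] is the number of dimensions, dim[1:] are the extents.
    dim = struct.unpack(order + "8h", header[40:56])
    ndim = dim[0]
    if not 1 <= ndim <= 7:
        raise ValueError(f"invalid number of dimensions: {ndim}")
    return tuple(dim[1:1 + ndim])
```

Calling `nifti_shape` on each of `mr.nii.gz`, `ct.nii.gz` and `mask.nii.gz` should return a 3-tuple for a 3D volume. NIfTI-2 files (540-byte header) are not handled by this sketch.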

ciroraggio commented 9 months ago

Hi, could you specify the transformations you are applying to the images? Please attach a code snippet showing how you declare the Dataset and DataLoader so that I can reproduce the issue.

Thank you!

jafarpof commented 9 months ago

Sure, here is the snippet. At this stage I'm just trying to resize the images and apply a rotation to "SAC" with a predefined function. My data structure also differs from the one in the readme, so I changed the way paths are loaded.

    import os
    from monai.transforms import Compose, Resize
    from torch.utils.data import DataLoader
    from datasets.ImageToImageDataset import ImageToImageDataset

    # Define simplified preprocessing transformations
    each_image_trans = Compose([
        Resize(spatial_size=(256, 256, 256), mode='trilinear')  # Use a smaller size for testing
    ])

    root_path = "Data/extracted/brain"
    mr_paths, ct_paths, mask_paths = [], [], []

    # Use only 1 or 2 samples for testing
    num_samples = 1
    patient_folders = [f for f in sorted(os.listdir(root_path)) if not f.startswith('.')][:num_samples]

    for patient_id in patient_folders:
        patient_folder = os.path.join(root_path, patient_id)
        if os.path.isdir(patient_folder):
            mr_image = os.path.join(patient_folder, "mr.nii.gz")
            ct_image = os.path.join(patient_folder, "ct.nii.gz")
            mask_image = os.path.join(patient_folder, "mask.nii.gz")

            if os.path.exists(mr_image) and os.path.exists(ct_image) and os.path.exists(mask_image):
                mr_paths.append(mr_image)
                ct_paths.append(ct_image)
                mask_paths.append(mask_image)

    # Initialize the dataset with simplified transformations
    dataset = ImageToImageDataset(
        first_type_image_files=mr_paths,
        second_type_image_files=ct_paths,
        seg_files=mask_paths,
        first_type_image_transforms=each_image_trans,
        second_type_image_transforms=each_image_trans,
        seg_transform=each_image_trans,
        reader="nibabelreader"
    )

    # DataLoader with single worker for debugging
    augmented_data_loader = DataLoader(
        dataset=dataset,
        batch_size=1,  # Use a smaller batch size
        shuffle=False,
        num_workers=0,  # Set to 0 for debugging
        pin_memory=True
    )

    # Test loading a batch
    for batch in augmented_data_loader:
        augm_mr_image, augm_ct_image, augm_mask = batch
        print(f"MR Image Shape: {augm_mr_image.shape}")
        print(f"CT Image Shape: {augm_ct_image.shape}")
        print(f"Mask Shape: {augm_mask.shape}")
        break

The output of this code is:

    Loaded MR Image Shape: torch.Size([244, 204])
    Loaded CT Image Shape: torch.Size([244, 204])
    Loaded Mask Image Shape: torch.Size([244, 204])
    MR Image Shape: torch.Size([1, 231, 256, 256, 256])
    CT Image Shape: torch.Size([1, 231, 256, 256, 256])
    Mask Shape: torch.Size([1, 231, 256, 256, 256])

Thanks.
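A possible reading of the batch shapes above, offered as an assumption rather than a confirmed diagnosis: MONAI spatial transforms such as Resize expect channel-first input and resize only the axes after the first one, and the PyTorch DataLoader then prepends a batch axis. Under that convention, a sample that reaches Resize without a channel axis would have its first axis (231 here) treated as channels. The sketch below only does the shape bookkeeping in plain Python; the function names and the input shape `(231, 244, 204)` are hypothetical and no MONAI call is made.

```python
def resized_batch_shape(sample_shape, spatial_size, batch_size=1):
    """Shape after a channel-first spatial resize followed by batching.

    sample_shape is read as (channels, *spatial); spatial_size is the
    target size of the resize step. Both names are illustrative only.
    """
    if len(sample_shape) < 2:
        raise ValueError("expected at least (channels, one spatial axis)")
    channels = sample_shape[0]            # first axis is read as channels
    resized = (channels, *spatial_size)   # only the spatial axes change
    return (batch_size, *resized)         # the loader prepends the batch axis


def add_channel_axis(volume_shape):
    """Shape of a volume after one channel axis is added in front."""
    return (1, *volume_shape)


# Hypothetical volume with no channel axis: 231 is taken as the channel count.
print(resized_batch_shape((231, 244, 204), (256, 256, 256)))
# Same volume with an explicit channel axis added first.
print(resized_batch_shape(add_channel_axis((231, 244, 204)), (256, 256, 256)))
```

If this reading applies, adding a channel axis before Resize (MONAI provides the `EnsureChannelFirst` transform for this) would be expected to give a batch of shape `[1, 1, 256, 256, 256]`. Whether ImageToImageDataset already handles the channel axis internally is not established in this thread.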

ciroraggio commented 9 months ago

Hi, the dataset you are using (ImageToImageDataset) is meant to be used with "AugmentedImageToImageDataLoader" (which you can find in the "loaders" folder), a loader built to apply the AugmentedDataLoader logic when two images are associated with one mask.

From the code you provided, it appears that you are using the standard DataLoader provided by PyTorch. Please use AugmentedImageToImageDataLoader instead and check again.

See the "AugmentedImageToImageDataLoader" section in the readme for a usage example. Thank you!

jafarpof commented 9 months ago

Thanks for your reply. I switched to AugmentedImageToImageDataLoader.

However, my original question was about the step before the transformations, as described here: The main concern is the loading of images in the __getitem__ method. The class uses the LoadImage transform to load the images, and I added print statements to check the shape of the loaded images. However, instead of the 3D tensor shape I expected for medical images (such as MRI or CT scans), the loaded images have a 2D shape. Here is the output I observed:

    Loaded MR Image Shape: torch.Size([244, 204])
    Loaded CT Image Shape: torch.Size([244, 204])
    Loaded Mask Image Shape: torch.Size([244, 204])

I am trying to load and process 3D medical images, but the dataset seems to be returning 2D slices or incorrectly shaped tensors.

ciroraggio commented 9 months ago

Hi, I was not able to replicate your issue unfortunately. After loading images I still have 3D tensors. Try removing the transformation applied to all images and debug the code.

I attach the output of my debugger here:

    ImageToImageDataset.__getitem__ line 70:
        first_img.size(): torch.Size([512, 512, 3])
        second_img.size(): torch.Size([512, 512, 3])
    ImageToImageDataset.__getitem__ line 89 (no transformations applied):
        data[0].size(): torch.Size([512, 512, 3])
        data[1].size(): torch.Size([512, 512, 3])

Please check your process further, thank you.

jafarpof commented 9 months ago

Attachments: ct.nii.gz, mask.nii.gz, mr.nii.gz

I appreciate your clarification. Would you mind also trying it with a 3D medical image sample rather than an RGB image, to see whether you get a 3D tensor shape? I have attached the 3D data (MR, CT, mask) here in case you don't have access to any. Thanks.

ciroraggio commented 9 months ago

Hi, thanks for the data.

[Image: issue-debug]

As you can see from the debugger, the images are loaded correctly. Try changing your reader to 'pilreader'.

Thank you.