StanfordMIMI / DDM2

[ICLR2023] Official repository of DDM2: Self-Supervised Diffusion MRI Denoising with Generative Diffusion Models

Question regarding dataset #22

Open doro041 opened 3 months ago

doro041 commented 3 months ago

Has this method been applied to other types of MRI data?

The MRI data I am working with has dimensions 90, 90, 3, 5 (magnetic field x evolution time?). I tried to combine the last two dimensions, as I did for Patch2Self, where that seemed to work.

I get this error:

export CUDA_VISIBLE_DEVICES=0
24-03-10 19:49:07.449 - INFO: [Phase 1] Training noise model!
Loaded data of size: (90, 90, 1, 15)
Traceback (most recent call last):
  File "/Users/dolorious/Desktop/MLMethods/DDM2/train_noise_model.py", line 42, in <module>
    train_set = Data.create_dataset(dataset_opt, phase)
  File "/Users/dolorious/Desktop/MLMethods/DDM2/data/__init__.py", line 30, in create_dataset
    dataset = MRIDataset(dataroot=dataset_opt['dataroot'],
  File "/Users/dolorious/Desktop/MLMethods/DDM2/data/mri_dataset.py", line 35, in __init__
    self.raw_data = np.pad(raw_data.astype(np.float32), ((0,0), (0,0), (in_channel//2, in_channel//2), (self.padding, self.padding)), mode='wrap').astype(np.float32)
  File "/Users/dolorious/.pyenv/versions/3.12.0/lib/python3.12/site-packages/numpy/lib/arraypad.py", line 819, in pad
    raise ValueError(
ValueError: can't extend empty axis 3 using modes other than 'constant' or 'empty'

tiangexiang commented 3 months ago

Hi, thanks for your interest in our work! I think the issue comes from data loading. For some reason, your data has been loaded as 90 x 90 x 1 x 15, not 90 x 90 x 3 x 5 as you specified. The padding fails when there is only one slice in each volume. Once the data is loaded in the correct shape, I think the code should work. However, I do have concerns about the final denoising quality, given that the number of slices is very small (only 3). The examples used in the paper usually have more than 30 slices per volume.
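For reference, the ValueError itself is raised by NumPy: `np.pad` refuses to extend an axis of length 0 in any mode other than `'constant'` or `'empty'`, and the message in the log names axis 3 (the last axis) as the empty one. A minimal NumPy-only reproduction follows, assuming the last axis was cut down by a `[start, end)` range lying beyond its length; the 65 to 129 window is only an illustrative value.

```python
import numpy as np

# Illustrative array with the loaded shape from the log: (X, Y, slices, volumes)
data = np.zeros((90, 90, 1, 15), dtype=np.float32)

# A [start, end) range past the end of the last axis leaves that axis empty
masked = data[:, :, :, 65:129]
print(masked.shape)  # (90, 90, 1, 0)

pad_width = ((0, 0), (0, 0), (0, 0), (1, 1))
try:
    np.pad(masked, pad_width, mode='wrap')
    message = None
except ValueError as err:
    message = str(err)
print(message)

# The same 'wrap' padding succeeds when axis 3 is non-empty,
# even though axis 2 holds a single slice
padded = np.pad(data, pad_width, mode='wrap')
print(padded.shape)  # (90, 90, 1, 17)
```

So it may be worth checking the data shape after any volume or slice selection, not only right after loading.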

doro041 commented 3 months ago

I appreciate the prompt response. I think my original data is 90x90x5x3 (evolution time x magnetic field). I understand that my data is not the same as yours; I was wondering whether you had any suggestions for tailoring it to this implementation. Should I preprocess my data into a layout that torch accepts, for instance (batch, channel, height, width), or is there anything else I could look into? I know there are only 5 slices and that this might hinder the results, but I am trying to adapt this ML method and compare it to Patch2Self, which worked very well after some initial preprocessing! :)
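On the layout question: `transforms.ToTensor()` (used in the dataset class posted below) converts an (H, W) or (H, W, C) NumPy array into a (C, H, W) tensor, with C = 1 for a 2D slice, and `DataLoader` stacks samples along a new batch dimension, so reordering to (batch, channel, height, width) by hand is probably unnecessary. What the loader does depend on is the axis order of the 4D array itself: the `np.pad` call in the traceback treats axis 2 as slices and axis 3 as volumes. A minimal sketch that moves whichever two axes you choose into those positions; `to_loader_layout` is a hypothetical helper, and deciding which of your axes counts as "slices" versus "volumes" is up to you.

```python
import numpy as np


def to_loader_layout(data, slice_axis, volume_axis):
    """Hypothetical helper: return a view with slices on axis 2 and volumes on axis 3."""
    if data.ndim != 4:
        raise ValueError(f"expected a 4D array, got shape {data.shape}")
    if slice_axis % 4 in (0, 1) or volume_axis % 4 in (0, 1):
        raise ValueError("axes 0 and 1 are assumed to be the in-plane image axes")
    # np.moveaxis keeps the remaining axes (0 and 1) in their original order
    return np.moveaxis(data, (slice_axis, volume_axis), (2, 3))


# Illustrative array with the reported shape 90 x 90 x 5 x 3
data = np.zeros((90, 90, 5, 3), dtype=np.float32)

# Treat the 3-long axis as slices and the 5-long axis as volumes
reordered = to_loader_layout(data, slice_axis=3, volume_axis=2)
n_slices, n_volumes = reordered.shape[2], reordered.shape[3]
print(reordered.shape)       # (90, 90, 3, 5)
print(n_slices * n_volumes)  # 15 two-dimensional samples
```

To feed the reordered array to the existing loader, it could be written back with DIPY's `save_nifti(fname, data, affine)`, reusing the affine returned by `load_nifti`.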

One question I would like to ask: how many images are there in the HARDI sample? In my case it is a single image, but I want to go through multiple images with different Gaussian noise levels, all in .nii.gz format.
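Since `MRIDataset` takes a single `dataroot` (one NIfTI file), one straightforward way to cover several noise levels is to run the pipeline once per file. A minimal sketch that clones a base config per file, assuming the config has a `'datasets'` section whose entries each carry a `'dataroot'` key; the traceback shows `dataset_opt['dataroot']` being read, but the exact layout of your JSON config is an assumption here, and `configs_per_file` and `write_configs` are hypothetical helpers.

```python
import copy
import json
import os


def configs_per_file(base_config, nifti_paths):
    """Hypothetical helper: one deep-copied config per NIfTI file, every dataroot replaced."""
    configs = []
    for path in nifti_paths:
        cfg = copy.deepcopy(base_config)  # never mutate the shared base config
        for dataset_opt in cfg.get('datasets', {}).values():
            dataset_opt['dataroot'] = path
        configs.append(cfg)
    return configs


def write_configs(configs, nifti_paths, out_dir):
    """Write each config as <file stem>.json in out_dir and return the written paths."""
    os.makedirs(out_dir, exist_ok=True)
    written = []
    for cfg, path in zip(configs, nifti_paths):
        stem = os.path.basename(path)
        if stem.endswith('.nii.gz'):
            stem = stem[:-len('.nii.gz')]
        out_path = os.path.join(out_dir, f'{stem}.json')
        with open(out_path, 'w') as f:
            json.dump(cfg, f, indent=2)
        written.append(out_path)
    return written
```

The file list could come from something like `sorted(glob.glob('*.nii.gz'))`; each written JSON would then replace your current config in a separate training run, which keeps results from different noise levels apart.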

import numpy as np
import torch
from dipy.io.image import load_nifti
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
from torchvision import transforms

class MRIDataset(Dataset):
    def __init__(self, dataroot, valid_mask, phase='train', image_size=128, in_channel=1,
                 val_volume_idx=50, val_slice_idx=40, padding=1, lr_flip=0.5, stage2_file=None):
        self.dataroot = dataroot
        self.valid_mask = valid_mask
        self.phase = phase
        self.image_size = image_size
        self.in_channel = in_channel
        self.val_volume_idx = val_volume_idx
        self.val_slice_idx = val_slice_idx
        self.padding = padding
        self.lr_flip = lr_flip
        self.stage2_file = stage2_file

        # Expected layout: (X, Y, slices, volumes)
        self.raw_data, _ = load_nifti(dataroot)
        print(f'Loaded data of size: {self.raw_data.shape}')

        # Scale intensities to [0, 1]
        self.raw_data = self.raw_data.astype(np.float32)
        self.raw_data /= np.max(self.raw_data)
        if isinstance(valid_mask, (list, tuple)) and len(valid_mask) == 2:
            self.raw_data = self.raw_data[:, :, valid_mask[0]:valid_mask[1]]
            print(f"Using valid_mask slices from {valid_mask[0]} to {valid_mask[1]}")
        else:
            print("valid_mask is not a slice range; using all slices")
        # A window outside the slice axis leaves it empty, e.g. (90, 90, 0, 3)
        if 0 in self.raw_data.shape:
            raise ValueError(f'No data left after applying valid_mask: shape {self.raw_data.shape}')

        self.data_size_before_padding = self.raw_data.shape
        # Append `padding` zero-filled volumes to the end of axis 3
        self.raw_data = np.pad(self.raw_data, ((0, 0), (0, 0), (0, 0), (0, self.padding)),
                               mode='constant', constant_values=0)
        print(f'Data after padding: {self.raw_data.shape}')

        self.transforms = self.build_transforms()

        if self.stage2_file:
            self.matched_state = self.parse_stage2_file(self.stage2_file)
        else:
            self.matched_state = None

    def build_transforms(self):
        # ToTensor turns an (H, W) float array into a (1, H, W) tensor; then map [0, 1] to [-1, 1]
        if self.phase == 'train':
            return transforms.Compose([
                transforms.ToTensor(),
                transforms.RandomVerticalFlip(p=self.lr_flip),
                transforms.RandomHorizontalFlip(p=self.lr_flip),
                transforms.Lambda(lambda t: (t * 2) - 1)
            ])
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda t: (t * 2) - 1)
        ])

    def __len__(self):
        # One item per (volume, slice) pair in the unpadded data
        n_slices, n_volumes = self.data_size_before_padding[2], self.data_size_before_padding[3]
        return n_slices * n_volumes

    def __getitem__(self, idx):
        n_slices = self.data_size_before_padding[2]
        volume_idx = idx // n_slices
        slice_idx = idx % n_slices

        # A single 2D slice of shape (X, Y) from one volume
        raw_input = self.raw_data[:, :, slice_idx, volume_idx]
        raw_input = self.transforms(raw_input)

        print(f'Fetched item at volume {volume_idx}, slice {slice_idx} with shape {tuple(raw_input.shape)}')

        if self.matched_state is not None:
            matched_state = self.matched_state.get((volume_idx, slice_idx), 0)
            return {'X': raw_input, 'matched_state': matched_state}
        return {'X': raw_input}

    def parse_stage2_file(self, file_path):
        # Each non-empty line is expected to read "<volume>_<slice>_<state>"
        results = {}
        with open(file_path, 'r') as f:
            for line in f:
                if not line.strip():
                    continue
                volume_idx, slice_idx, state = map(int, line.strip().split('_'))
                results[(volume_idx, slice_idx)] = state
        return results

if __name__ == "__main__":
    # Hypothetical [start, end) slice window; it must lie inside the data's slice axis
    valid_mask = [0, 5]

    print(f"Creating dataset with slices {valid_mask[0]} to {valid_mask[1]} included.")
    dataset = MRIDataset('gaussian_snr_original_1.nii.gz', valid_mask,
                         phase='train', val_volume_idx=40, padding=3)

    # Print the length of the dataset to verify correct loading
    print("Length of the dataset:", len(dataset))

    trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

    # Visualize only the items with indices 95 to 108
    for i, data in enumerate(trainloader):
        if i > 108:
            break
        if i < 95:
            continue

        img = data['X'].numpy()  # (batch, channel, H, W)
        plt.imshow(img[0, 0], cmap='gray')
        plt.show()

        # 'condition' is only present if the dataset returns such a key
        if 'condition' in data:
            condition = data['condition'].numpy()
            vis = np.hstack((img[0, 0], condition[0, 0], condition[0, 1]))
            plt.imshow(condition[0, 0], cmap='gray')
            plt.show()
            plt.imshow(condition[0, 1], cmap='gray')
            plt.show()
            plt.imshow(vis, cmap='gray')
            plt.show()
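For the dataset class above, `__len__` and `__getitem__` have to agree on how a flat index maps to a (volume, slice) pair, and the `valid_mask` window has to fit inside the axis it cuts. A minimal, framework-free sketch of that bookkeeping, assuming one item per 2D slice per volume with the slice index varying fastest; `check_window` and `index_to_volume_slice` are hypothetical helpers.

```python
def check_window(start, end, axis_length):
    """Return (start, end) if [start, end) fits inside the axis, else raise.

    Stricter than NumPy slicing, which silently clamps and can return an empty axis.
    """
    if not 0 <= start < end <= axis_length:
        raise ValueError(f"window [{start}, {end}) does not fit an axis of length {axis_length}")
    return start, end


def index_to_volume_slice(idx, n_slices, n_volumes):
    """Map a flat dataset index to (volume_idx, slice_idx), slice index varying fastest."""
    if n_slices <= 0 or n_volumes <= 0:
        raise ValueError("need at least one slice and one volume")
    if not 0 <= idx < n_slices * n_volumes:
        raise IndexError(idx)
    return divmod(idx, n_slices)


# Illustrative sizes matching the reported data: 5 slices, 3 volumes -> 15 items
n_slices, n_volumes = 5, 3
pairs = [index_to_volume_slice(i, n_slices, n_volumes) for i in range(n_slices * n_volumes)]
print(pairs[7])  # (1, 2): volume 1, slice 2
```

With 5 slices, `check_window(65, 129, 5)` raises, which is the situation behind the `Data after padding: (90, 90, 0, 6)` line in the log.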

2.2.1 None
export CUDA_VISIBLE_DEVICES=0
24-03-11 22:45:16.720 - INFO: [Phase 1] Training noise model!
Loaded data of size: (90, 90, 5, 3)
Using valid_mask slices from 65 to 129
Data after padding: (90, 90, 0, 6)
24-03-11 22:45:17.817 - INFO: MRI dataset [simulation] is created.
Loaded data of size: (90, 90, 5, 3)
Using valid_mask slices from 65 to 129
Data after padding: (90, 90, 0, 6)
24-03-11 22:45:17.827 - INFO: MRI dataset [s3sh] is created.
24-03-11 22:45:17.827 - INFO: Initial Dataset Finished
Traceback (most recent call last):
  File "/Users/dolorious/Desktop/MLMethods/DDM2/train_noise_model.py", line 53, in <module>
    trainer = Model.create_noise_model(opt)
  File "/Users/dolorious/Desktop/MLMethods/DDM2/model/__init__.py", line 13, in create_noise_model
    m = M(opt)
  File "/Users/dolorious/Desktop/MLMethods/DDM2/model/model_stage1.py", line 34, in __init__
    image_size=opt['model']['diffusion']['image_size'],
TypeError: 'NoneType' object is not subscriptable
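The final TypeError means that `opt['model']` or `opt['model']['diffusion']` evaluated to None when `model_stage1.py` read `image_size`, which suggests a missing or null section in the options rather than a model bug; it is separate from the empty axis in `Data after padding: (90, 90, 0, 6)`. A small sketch for pinpointing the offending key, assuming `opt` behaves like a plain nested dict; `find_missing` is a hypothetical helper.

```python
def find_missing(opt, *keys):
    """Return the dotted path of the first key that is missing, None, or not subscriptable.

    Returns None when the full key path resolves to a non-None value.
    """
    node = opt
    path = []
    for key in keys:
        path.append(key)
        if not isinstance(node, dict) or node.get(key) is None:
            return '.'.join(path)
        node = node[key]
    return None


# Illustrative options in which the 'diffusion' section is absent
opt = {'model': {}}
print(find_missing(opt, 'model', 'diffusion', 'image_size'))  # model.diffusion
```

Calling it on the parsed options just before `Model.create_noise_model(opt)` would show which section of the config needs to be filled in.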