anar-rzayev opened this issue 5 months ago (status: open).
Hi, thank you for your interest!
Let me clarify the data dimensions first. Consider 4D MRI data of shape [W, H, Z, T]. [W, H, Z] describe the MRI scan in 3D space, i.e. its extent along the x, y and z axes. For a 4D scan with multiple acquisitions, the 4th dimension T is the number of observations of the same 3D structure. Because each acquisition is random, each 3D observation is noisy in a different way. Denoising methods (including DDM2) are designed to learn the consistency across these differently noisy observations in order to recover a clean 3D structure.
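As a minimal NumPy sketch of this layout (a zero array stands in for a real scan; the sizes are borrowed from the (118, 118, 25, 56) data discussed in this thread), indexing the last axis selects one 3D observation, and fixing Z as well gives a 2D slice:

```python
import numpy as np

# Stand-in for a 4D scan [W, H, Z, T]; no real data is loaded here.
W, H, Z, T = 118, 118, 25, 56
data = np.zeros((W, H, Z, T), dtype=np.float32)

volume = data[..., 7]          # one 3D observation: shape (W, H, Z)
slice_2d = data[:, :, 12, 7]   # one 2D slice of that observation: shape (W, H)
all_obs = data[:, :, 12, :]    # the same slice across all T observations: shape (W, H, T)

print(volume.shape, slice_2d.shape, all_obs.shape)
# (118, 118, 25) (118, 118) (118, 118, 56)
```

The last view, one slice seen across many noisy observations, is the kind of data a denoiser can exploit for consistency.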
For the choice of val_volume_idx = 40: this is an arbitrary choice. You may validate on any volume without affecting the training process. As for valid_mask = [10, 160]: it indicates the interval of observations whose b-value is NOT 0. In the hardi150 dataset, the first 10 observations have b-value = 0, so we exclude them from training by masking with valid_mask = [10, 160] (i.e. using 160 - 10 = 150 observations along the T dimension).
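The interval can also be derived from the b-values instead of being counted by hand. The sketch below assumes a hypothetical `bvals` list (one b-value per volume, e.g. read from a .bval file) in which all b = 0 volumes come first, as described above for hardi150. The helper name `valid_interval_from_bvals`, the threshold and the placeholder b-value of 1000 are illustrative assumptions, not part of DDM2.

```python
import numpy as np

def valid_interval_from_bvals(bvals, b0_threshold=50.0):
    """Hypothetical helper: return [start, end] covering the non-b0 volumes.

    Assumes every b0 volume precedes every diffusion-weighted volume,
    because valid_mask only supports one contiguous interval.
    """
    bvals = np.asarray(bvals, dtype=float)
    is_dwi = bvals > b0_threshold
    if not is_dwi.any():
        raise ValueError("no diffusion-weighted volumes found")
    start = int(np.argmax(is_dwi))  # index of the first diffusion-weighted volume
    if not is_dwi[start:].all():
        raise ValueError("b0 volumes are interleaved; a [start, end] mask cannot exclude them")
    return [start, len(bvals)]

# hardi150-like layout: 10 b0 volumes, then 150 weighted ones (1000 is a placeholder b-value)
bvals = [0] * 10 + [1000] * 150
valid_mask = valid_interval_from_bvals(bvals)
print(valid_mask, valid_mask[1] - valid_mask[0])  # [10, 160] 150
```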
Back to your problem: I think it still makes sense to set val_volume_idx and valid_mask in the same way. The error indicates that the raw data you loaded has only T = 8 observations instead of T = 56. I suggest double-checking the data first and making sure the loaded raw data has the expected size.
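A check along these lines, run right after loading, would catch a T mismatch before training starts. It is a sketch: `check_loaded_shape` and `expected_t` are hypothetical names, and a zero array stands in for the output of `load_nifti` in mri_dataset.py.

```python
import numpy as np

def check_loaded_shape(raw_data, expected_t, valid_mask):
    """Hypothetical sanity check for a 4D [W, H, Z, T] array before masking."""
    if raw_data.ndim != 4:
        raise ValueError(f"expected a 4D [W, H, Z, T] array, got shape {raw_data.shape}")
    t = raw_data.shape[-1]
    if t != expected_t:
        raise ValueError(f"expected T={expected_t} observations, got T={t}")
    start, end = valid_mask
    if not 0 <= start < end <= t:
        raise ValueError(f"valid_mask {valid_mask} does not fit inside T={t}")
    return end - start  # number of volumes kept after masking

# In mri_dataset.py this would follow `raw_data, _ = load_nifti(dataroot)`.
raw_data = np.zeros((118, 118, 25, 56), dtype=np.float32)
print(check_loaded_shape(raw_data, expected_t=56, valid_mask=[20, 56]))  # 36
```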
Thanks so much, @tiangexiang, for your fast response and the detailed explanations of valid_mask and val_volume_idx. As you instructed, I changed the .json file and mri_dataset.py to use valid_mask = [20, 56], which captures the volumes with non-zero b-values in my dataset.
Next, I ran into a minor bug, shown below:
24-01-14 20:28:08.899 - INFO: [Phase 1] Training noise model!
24-01-14 20:28:10.231 - INFO: MRI dataset [hardi] is created.
24-01-14 20:28:13.083 - INFO: MRI dataset [hardi] is created.
24-01-14 20:28:13.083 - INFO: Initial Dataset Finished
24-01-14 20:28:13.463 - INFO: Noise Model is created.
24-01-14 20:28:13.463 - INFO: Initial Model Finished
2.1.2+cu121 12.1
export CUDA_VISIBLE_DEVICES=0
Loaded data of size: (118, 118, 25, 56)
Loaded data of size: (118, 118, 25, 56)
dropout 0.0 encoder dropout 0.0
Traceback (most recent call last):
File "/home/anar/DDM2/train_noise_model.py", line 72, in <module>
trainer.optimize_parameters()
File "/home/anar/DDM2/model/model_stage1.py", line 62, in optimize_parameters
outputs = self.netG(self.data)
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/anar/DDM2/model/mri_modules/noise_model.py", line 44, in forward
return self.p_losses(x, *args, **kwargs)
File "/home/anar/DDM2/model/mri_modules/noise_model.py", line 36, in p_losses
x_recon = self.denoise_fn(x_in['condition'])
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/anar/DDM2/model/mri_modules/unet.py", line 286, in forward
x = layer(x)
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (double) and bias type (float) should be the same
To fix this double/float mismatch, I added the line raw_data = raw_data.astype(np.float32) right after loading the NIfTI data. Stage 1 training then ran, but I hit another issue:
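One likely reason the data arrived as double, hedged because it depends on how the file is stored and how `load_nifti` reads it: without the added cast, the normalisation line in mri_dataset.py (`raw_data.astype(np.float32) / np.max(raw_data, ...)`) divides a float32 array by the max of the original array. If that original is float64, NumPy promotes the result back to float64, and the float32 conv weights then reject it. A small random array stands in for the scan:

```python
import numpy as np

# float64 stand-in (nibabel's get_fdata() returns float64 by default)
raw_data = np.random.default_rng(0).random((4, 4, 3, 5))

# Normalisation as written: float32 array / float64 array -> float64
normalised = raw_data.astype(np.float32) / np.max(raw_data, axis=(0, 1, 2), keepdims=True)
print(normalised.dtype)  # float64

# Casting first keeps both operands, and therefore the result, float32
raw32 = raw_data.astype(np.float32)
normalised32 = raw32 / np.max(raw32, axis=(0, 1, 2), keepdims=True)
print(normalised32.dtype)  # float32
```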
24-01-14 20:29:55.122 - INFO: [Phase 1] Training noise model!
24-01-14 20:29:56.211 - INFO: MRI dataset [hardi] is created.
24-01-14 20:29:56.892 - INFO: MRI dataset [hardi] is created.
24-01-14 20:29:56.892 - INFO: Initial Dataset Finished
24-01-14 20:29:57.252 - INFO: Noise Model is created.
24-01-14 20:29:57.252 - INFO: Initial Model Finished
24-01-14 20:31:55.243 - INFO: <epoch: 35, iter: 1,000> l_pix: 3.8268e-03
24-01-14 20:33:51.252 - INFO: <epoch: 69, iter: 2,000> l_pix: 3.1790e-03
2.1.2+cu121 12.1
export CUDA_VISIBLE_DEVICES=0
Loaded data of size: (118, 118, 25, 56)
Loaded data of size: (118, 118, 25, 56)
dropout 0.0 encoder dropout 0.0
Validation
Traceback (most recent call last):
File "/home/anar/DDM2/train_noise_model.py", line 92, in <module>
for _, val_data in enumerate(val_loader):
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
data = self._next_data()
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1345, in _next_data
return self._process_data(data)
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
data.reraise()
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/_utils.py", line 694, in reraise
raise exception
IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
data = fetcher.fetch(index)
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/anar/DDM2/data/mri_dataset.py", line 128, in __getitem__
raw_input = raw_input[:,:,0]
IndexError: index 0 is out of bounds for axis 2 with size 0
Do you have any idea how to solve this? I have tried several modifications in mri_dataset.py, but none of them fixes this IndexError.
Hi, it seems the error comes from the shape of raw_input. Can you make sure the tensor raw_input has the shape [W, H, -1] or [W, H, 1, -1] before that line of code?
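A hypothetical guard (not part of DDM2) that could sit just before `raw_input = raw_input[:,:,0]` and report the offending shape, including the empty-axis case seen in the traceback:

```python
import numpy as np

def check_raw_input(raw_input):
    """Hypothetical guard: accept [W, H, -1] or [W, H, 1, -1] with no empty axis."""
    shape = raw_input.shape
    if raw_input.ndim not in (3, 4):
        raise ValueError(f"raw_input should be [W, H, -1] or [W, H, 1, -1], got {shape}")
    if 0 in shape:
        # e.g. (118, 118, 0, 3): the slice range selected nothing on axis 2
        raise ValueError(f"raw_input has an empty axis: {shape}")
    return shape

print(check_raw_input(np.zeros((118, 118, 1, 3))))  # (118, 118, 1, 3)
```

Failing early with the full shape makes it clear which axis is empty, instead of the later "index 0 is out of bounds" message.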
To double-check the shapes, I added a few lines to mri_dataset.py:
```python
# w, h, c, d = raw_input.shape
# raw_input = np.reshape(raw_input, (w, h, -1))
print("raw_input shape before slicing:", raw_input.shape)
if len(raw_input.shape) == 4:
    raw_input = raw_input[:,:,0]
print("raw_input shape after slicing:", raw_input.shape)
raw_input = self.transforms(raw_input)  # only support the first channel for now
# raw_input = raw_input.view(c, d, w, h)
```
This is the output I get:
raw_input shape before slicing: (118, 118, 1, 3)
raw_input shape after slicing: (118, 118, 3)
Validation
Traceback (most recent call last):
File "/home/anar/DDM2/train_noise_model.py", line 92, in <module>
for _, val_data in enumerate(val_loader):
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
data = self._next_data()
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1345, in _next_data
return self._process_data(data)
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
data.reraise()
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/_utils.py", line 694, in reraise
raise exception
IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
data = fetcher.fetch(index)
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/anar/DDM2/data/mri_dataset.py", line 129, in __getitem__
raw_input = raw_input[:,:,0]
IndexError: index 0 is out of bounds for axis 2 with size 0
It seems that the data loaded for training is fine, but the data loaded for validation may not have the right shape. Either way, the problem must come from how the data is loaded. I can't give more specific suggestions without more information, but I still suggest inspecting the tensor shape at every possible location.
BTW, for the previous error, I also added raw_input = np.reshape(raw_input, (raw_input.shape[0], raw_input.shape[1], 1, -1)) before raw_input = raw_input[:,:,0], but I still get the following error:
raw_input shape before slicing: (118, 118, 1, 3)
raw_input shape after slicing: (118, 118, 3)
Validation
Traceback (most recent call last):
File "/home/anar/DDM2/train_noise_model.py", line 92, in <module>
for _, val_data in enumerate(val_loader):
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
data = self._next_data()
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1345, in _next_data
return self._process_data(data)
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
data.reraise()
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/_utils.py", line 694, in reraise
raise exception
IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
data = fetcher.fetch(index)
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/anar/DDM2/data/mri_dataset.py", line 135, in __getitem__
ret = dict(X=raw_input[[-1], :, :], condition=raw_input[:-1, :, :])
IndexError: index is out of bounds for dimension with size 0
This is so weird: how come loading succeeds for training but fails for validation?! My dataset is simply dwi_combined.nii.gz, whose path I specify in the JSON file. I did not change anything in mri_dataset.py except valid_mask, to restrict the interval to volumes with non-zero b-values:
```python
from io import BytesIO
from PIL import Image
from torch.utils.data import Dataset
import random
import os
import numpy as np
import torch
from dipy.io.image import save_nifti, load_nifti
from matplotlib import pyplot as plt
from torchvision import transforms, utils


class MRIDataset(Dataset):
    def __init__(self, dataroot, valid_mask, phase='train', image_size=128, in_channel=1, val_volume_idx=50, val_slice_idx=40,
                 padding=1, lr_flip=0.5, stage2_file=None):
        self.padding = padding // 2
        self.lr_flip = lr_flip
        self.phase = phase
        self.in_channel = in_channel

        # read data
        raw_data, _ = load_nifti(dataroot)  # width, height, slices, gradients
        raw_data = raw_data.astype(np.float32)  # float32 so the input matches the float32 conv weights
        print('Loaded data of size:', raw_data.shape)
        # normalize data (per volume, over the three spatial axes)
        raw_data = raw_data.astype(np.float32) / np.max(raw_data, axis=(0,1,2), keepdims=True)

        # parse mask: valid_mask must be a [start, end] pair given as a list or tuple
        assert isinstance(valid_mask, (list, tuple)) and len(valid_mask) == 2

        # mask data: keep volumes valid_mask[0] .. valid_mask[1] - 1
        raw_data = raw_data[:,:,:,valid_mask[0]:valid_mask[1]]
        self.data_size_before_padding = raw_data.shape

        self.raw_data = np.pad(raw_data, ((0,0), (0,0), (in_channel//2, in_channel//2), (self.padding, self.padding)), mode='wrap')

        # running for Stage3?
        if stage2_file is not None:
            print('Parsing Stage2 matched states from the stage2 file...')
            self.matched_state = self.parse_stage2_file(stage2_file)
        else:
            self.matched_state = None

        # transform
        if phase == 'train':
            self.transforms = transforms.Compose([
                transforms.ToTensor(),
                #transforms.Resize(image_size),
                transforms.RandomVerticalFlip(lr_flip),
                transforms.RandomHorizontalFlip(lr_flip),
                transforms.Lambda(lambda t: (t * 2) - 1)
            ])
        else:
            self.transforms = transforms.Compose([
                transforms.ToTensor(),
                #transforms.Resize(image_size),
                transforms.Lambda(lambda t: (t * 2) - 1)
            ])

        # prepare validation data
        if val_volume_idx == 'all':
            self.val_volume_idx = range(raw_data.shape[-1])
        elif type(val_volume_idx) is int:
            self.val_volume_idx = [val_volume_idx]
        elif type(val_volume_idx) is list:
            self.val_volume_idx = val_volume_idx
        else:
            self.val_volume_idx = [int(val_volume_idx)]

        if val_slice_idx == 'all':
            self.val_slice_idx = range(0, raw_data.shape[-2])
        elif type(val_slice_idx) is int:
            self.val_slice_idx = [val_slice_idx]
        elif type(val_slice_idx) is list:
            self.val_slice_idx = val_slice_idx
        else:
            self.val_slice_idx = [int(val_slice_idx)]

    def parse_stage2_file(self, file_path):
        results = dict()
        with open(file_path, 'r') as f:
            lines = f.readlines()
        for line in lines:
            info = line.strip().split('_')
            volume_idx, slice_idx, t = int(info[0]), int(info[1]), int(info[2])
            if volume_idx not in results:
                results[volume_idx] = {}
            results[volume_idx][slice_idx] = t
        return results

    def __len__(self):
        if self.phase == 'train' or self.phase == 'test':
            return self.data_size_before_padding[-2] * self.data_size_before_padding[-1]  # num of slices * num of volumes
        elif self.phase == 'val':
            return len(self.val_volume_idx) * len(self.val_slice_idx)

    def __getitem__(self, index):
        if self.phase == 'train' or self.phase == 'test':
            # decode index to get slice idx and volume idx
            volume_idx = index // self.data_size_before_padding[-2]
            slice_idx = index % self.data_size_before_padding[-2]
        elif self.phase == 'val':
            s_index = index % len(self.val_slice_idx)
            index = index // len(self.val_slice_idx)
            slice_idx = self.val_slice_idx[s_index]
            volume_idx = self.val_volume_idx[index]

        raw_input = self.raw_data
        if self.padding > 0:
            raw_input = np.concatenate((
                raw_input[:,:,slice_idx:slice_idx+2*(self.in_channel//2)+1,volume_idx:volume_idx+self.padding],
                raw_input[:,:,slice_idx:slice_idx+2*(self.in_channel//2)+1,volume_idx+self.padding+1:volume_idx+2*self.padding+1],
                raw_input[:,:,slice_idx:slice_idx+2*(self.in_channel//2)+1,[volume_idx+self.padding]]), axis=-1)
        elif self.padding == 0:
            raw_input = np.concatenate((
                raw_input[:,:,slice_idx:slice_idx+2*(self.in_channel//2)+1,[volume_idx+self.padding-1]],
                raw_input[:,:,slice_idx:slice_idx+2*(self.in_channel//2)+1,[volume_idx+self.padding]]), axis=-1)

        # w, h, c, d = raw_input.shape
        # raw_input = np.reshape(raw_input, (w, h, -1))
        print("raw_input shape before slicing:", raw_input.shape)
        if len(raw_input.shape) == 4:
            raw_input = np.reshape(raw_input, (raw_input.shape[0], raw_input.shape[1], 1, -1))
            raw_input = raw_input[:,:,0]
        print("raw_input shape after slicing:", raw_input.shape)
        raw_input = self.transforms(raw_input)  # only support the first channel for now
        # raw_input = raw_input.view(c, d, w, h)

        # last volume is the target X, the others are the condition
        ret = dict(X=raw_input[[-1], :, :], condition=raw_input[:-1, :, :])

        if self.matched_state is not None:
            ret['matched_state'] = torch.zeros(1,) + self.matched_state[volume_idx][slice_idx]

        return ret


if __name__ == "__main__":
    # hardi
    valid_mask = np.zeros(56,)
    valid_mask[20:] += 1
    valid_mask = valid_mask.astype(bool)  # plain bool: the np.bool8 alias no longer exists in NumPy 2.0
    dataset = MRIDataset('/home/anar/DDM2/data/combined_dwi.nii.gz', valid_mask=[20, 56],
                         phase='train', val_volume_idx=40, padding=3)

    trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
    for i, data in enumerate(trainloader):
        # only visualise samples 95 to 108
        if i < 95:
            continue
        if i > 108:
            break
        img = data['X']
        condition = data['condition']
        img = img.numpy()
        condition = condition.numpy()

        vis = np.hstack((img[0].transpose(1,2,0), condition[0,[0]].transpose(1,2,0), condition[0,[1]].transpose(1,2,0)))
        # plt.imshow(img[0].transpose(1,2,0), cmap='gray')
        # plt.show()
        # plt.imshow(condition[0,[0]].transpose(1,2,0), cmap='gray')
        # plt.show()
        # plt.imshow(condition[0,[1]].transpose(1,2,0), cmap='gray')
        # plt.show()

        plt.imshow(vis, cmap='gray')
        plt.show()
        #break
```
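One plausible explanation worth checking (not confirmed in this thread): the constructor above defaults to `val_slice_idx=40` unless the JSON overrides it, while this dataset has only Z = 25 slices and 56 - 20 = 36 volumes after masking. Training indices are always decoded from the data size, so only validation can go out of range, and NumPy slicing past the end of an axis returns an empty axis instead of raising. The sketch reproduces both tracebacks' empty axis with a stand-in array and adds a hypothetical range check, `check_val_indices`, that is not part of DDM2.

```python
import numpy as np

W, H, Z, T = 118, 118, 25, 36            # T = volumes left after valid_mask = [20, 56]
in_channel, padding = 1, 3               # padding=3 as in the __main__ block above
pad = padding // 2
data = np.zeros((W, H, Z, T), dtype=np.float32)
padded = np.pad(data, ((0, 0), (0, 0), (in_channel // 2,) * 2, (pad, pad)), mode='wrap')

slice_idx, volume_idx = 40, 10           # slice 40 does not exist when Z = 25
part = padded[:, :, slice_idx:slice_idx + 2 * (in_channel // 2) + 1, volume_idx:volume_idx + pad]
print(part.shape)                        # (118, 118, 0, 1): axis 2 is silently empty

# Reshaping cannot recover the missing slice; it only moves the empty axis.
reshaped = np.reshape(part, (W, H, 1, -1))
print(reshaped.shape)                    # (118, 118, 1, 0)

def check_val_indices(val_slice_idx, val_volume_idx, num_slices, num_volumes):
    """Hypothetical check: every validation index must exist in the masked data."""
    bad_s = [s for s in val_slice_idx if not 0 <= s < num_slices]
    bad_v = [v for v in val_volume_idx if not 0 <= v < num_volumes]
    if bad_s or bad_v:
        raise ValueError(f"out-of-range slices {bad_s} (Z={num_slices}), volumes {bad_v} (T={num_volumes})")

check_val_indices([12], [10], num_slices=Z, num_volumes=T)  # passes silently
```

Calling such a check at the end of `__init__` with `self.data_size_before_padding` would turn the late IndexError into a clear message about which validation index is out of range.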
Hi, thank you for this amazing paper. I have a few questions and would appreciate a detailed explanation.
I have seen in multiple places (e.g. mri_dataset.py) that you define valid_mask = [10, 160]. Given that your data size is (81, 106, 76, 160), is there any particular reason you chose val_volume_idx = 40 and valid_mask = [10, 160] in hardi150.json? I ask because I am working with (118, 118, 25, 56) 4D diffusion data, and I run into some issues when setting up mri_dataset.py as follows:
Do you have any idea where this could originate from?