Hello Zack,
Thanks for your interest in our work.
As you will have noticed, `cc.RandomSlicewiseAffineTransform()` is used in the synthesis of simulated slice stacks, and `spacing` specifies the spacing of the slices in the MRI acquisition forward model. In a true MRI forward model there would also be an acquisition PSF, but we assume a box-car pre-filter is used -- unweighted integration across the width of each slice.
While we could have worked directly with a stack of slices (each of which is 1 voxel thin), we found it useful to represent each slice as an MR "slab" of thickness `spacing` voxels, so `cc.RandomSlicewiseAffineTransform()` essentially samples a stack of slices every `spacing` voxels and also replicates these slices `spacing` times, in a back-projection-esque manner. The `subsample` factor just subsamples all voxel data by the specified factor; it is there because we do not actually need full-resolution slice stacks to predict good underlying motion fields.
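To make that concrete, here is a rough, library-free sketch of the slab idea (an illustration only, not the actual cornucopia implementation; `replicate_slabs` is just a made-up helper name):

```python
import torch

def replicate_slabs(vol: torch.Tensor, spacing: int, subsample: int = 1) -> torch.Tensor:
    # vol: [C, D, H, W], with D the slice axis.
    # Take one slice per `spacing` voxels, then replicate it `spacing` times so the
    # slab covers the same extent as the original volume (back-projection style).
    sampled = vol[:, ::spacing]
    slabs = sampled.repeat_interleave(spacing, dim=1)[:, :vol.shape[1]]
    # A box-car forward model would average each group of `spacing` voxels instead
    # of taking a single slice; the shape bookkeeping is identical.
    return slabs[:, ::subsample, ::subsample, ::subsample]

vol = torch.rand(1, 256, 256, 256)
print(replicate_slabs(vol, spacing=4, subsample=2).shape)  # torch.Size([1, 128, 128, 128])
```

The real transform of course also applies a random per-slice rigid motion before forming the slabs; the sketch only shows where `spacing` and `subsample` enter the tensor shapes.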
I'm not sure what may be happening with your data, as once you have correctly resampled the data to 1:4 it should technically look like the FeTA training data (you could check this). I would make sure that your slice stack has some padding around the data and that the padding is in multiples of 4 voxels. You could have a look at the inference script again to see how the padding should look.
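For example, one way to read that (padding each spatial dimension up to a multiple of 4) is the following minimal sketch with plain torch; this is not a helper from the repo:

```python
import torch
import torch.nn.functional as F

def pad_to_multiple(x: torch.Tensor, multiple: int = 4) -> torch.Tensor:
    # x: [C, D, H, W]. F.pad takes pad amounts for the last dimensions first:
    # (W_before, W_after, H_before, H_after, D_before, D_after).
    pads = []
    for size in reversed(x.shape[1:]):
        extra = (-size) % multiple
        pads += [extra // 2, extra - extra // 2]
    return F.pad(x, tuple(pads))

stack = torch.rand(1, 18, 261, 274)   # e.g. an 18-slice stack
print(pad_to_multiple(stack).shape)   # torch.Size([1, 20, 264, 276])
```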
Working with anisotropic data directly is a possibility, but I think that would significantly complicate the splatting and slicing logic. As a related example, a typical 4x image super-resolution network takes a 4x bilinearly or trilinearly upsampled image as input and tries to recover the missing details, rather than taking the original-resolution image and spitting out an output with a 4x larger shape. The slice replication is a similar idea.
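In the 1:4 case, that input-side preparation is essentially the following (a sketch using plain trilinear upsampling; the slice replication above is the nearest-neighbor analogue of this along the slice axis):

```python
import torch
import torch.nn.functional as F

aniso = torch.rand(1, 1, 64, 256, 256)  # [B, C, D, H, W], thick slices along D
iso = F.interpolate(aniso, scale_factor=(4, 1, 1), mode='trilinear', align_corners=False)
print(iso.shape)                        # torch.Size([1, 1, 256, 256, 256])
```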
Now that I'm back from my travels I'll try to add some pre-processing logic to my code so that things like the spacing ratio are handled automatically. Happy to answer any other questions you may have and also to just chat.
Best,
Greetings.
Many thanks for your detailed reply, and sorry for my late response!
I've been studying your model further with our dataset, and I have a few more questions.
To start with, continuing our last discussion: working with isotropic data is great and I love your insight about it, but what if we only have anisotropic data? Is it possible to train the model with it? Based on how I understand the training process, we need to slice the data ourselves, for example from a spacing ratio of 1:1 to 1:5, so that we know the ground truth of the motion. However, I wonder if we can still do that when working with anisotropic data. It seems we can't really do the slicing process in this case, as the spacing ratio would increase if we slice again (e.g., re-slicing data that is already 1:5 with spacing 5 would effectively give a 1:25 ratio), which to my understanding would ruin the spacing ratio for our desired input. For example, if my data is already 1:5 and I still want a 1:5 ratio as input at inference, I shouldn't slice it further, because the training data would then no longer be 1:5 and hence would be unusable for inference on 1:5 data.
Of course, there are tons of isotropic datasets out there that we could use for training, but it's very common to have only anisotropic data when working with fetal MRI. So maybe the question is: does using similar data in training and inference matter much for performance? If not, we could train the model on another isotropic dataset and then use it on our own data.
Additionally, is registering the data to the same template necessary both in training and in inference? I think it's necessary in training because the model has no loss for global rigid motion. But in inference we only use one stack as input, so do we actually need to remove the global rigid motion with registration? Another thing about registration: when registering anisotropic data to the CRL data, which is isotropic, should I downsample the CRL data so that its spacing matches our anisotropic input? I've tried both downsampling it and not downsampling it; neither outcome was satisfying, which confused me.
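For context, my registration step is roughly of this shape (a sketch using SimpleITK rigid registration purely as an illustration; the file names are placeholders and this is not necessarily the exact tool or settings I used):

```python
import SimpleITK as sitk

fixed = sitk.ReadImage('crl_atlas.nii.gz', sitk.sitkFloat32)        # isotropic atlas (placeholder name)
moving = sitk.ReadImage('stack_1mm_5mm.nii.gz', sitk.sitkFloat32)   # anisotropic stack (placeholder name)

initial = sitk.CenteredTransformInitializer(
    fixed, moving, sitk.Euler3DTransform(),
    sitk.CenteredTransformInitializerFilter.GEOMETRY)

reg = sitk.ImageRegistrationMethod()
reg.SetMetricAsMattesMutualInformation(numberOfHistogramBins=32)
reg.SetOptimizerAsRegularStepGradientDescent(learningRate=1.0, minStep=1e-4, numberOfIterations=200)
reg.SetOptimizerScalesFromPhysicalShift()
reg.SetInterpolator(sitk.sitkLinear)
reg.SetInitialTransform(initial, inPlace=False)

tx = reg.Execute(fixed, moving)

# Resampling onto `fixed` puts the stack on the atlas's isotropic grid; to keep the
# stack's own anisotropic grid, resample the atlas (or use the inverse transform) instead.
aligned = sitk.Resample(moving, fixed, tx, sitk.sitkLinear, 0.0)
sitk.WriteImage(aligned, 'stack_in_atlas_space.nii.gz')
```

My (possibly wrong) understanding is that, since the registration works in world coordinates, the spacing difference should be handled by the image headers rather than by manually downsampling the atlas, but please correct me if that's not how you intend it to be used.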
Anyway, I've tried adding some padding around the slice stack and also changing the parameter `spacing` to 5 to match our voxel spacing, but the result was still terrible (I'll put it below).
For your information, here's my code for inference. I've customized my dataset, and it can be created in the code by calling the function `meng3d_svr()` (I'll put that code below as well). I don't think I really understand the "add some padding around the data" part you mentioned in your reply.
Many many many thanks for taking the time to read this.
import torch
import torch.nn.functional as F
import models
import models.losses
import cornucopia as cc
import os.path
import numpy as np
import matplotlib.pyplot as plt
from datasets import feta3d0_4_svr
from datasets.mt import meng3d_svr
from datasets.transforms import BoundingBox3d
from pytorch_lightning import seed_everything
import nibabel as nib
def get_padding(size, multiple):
if size % multiple == 0:
return 0, 0
else:
pad_size = multiple - (size % multiple)
return pad_size // 2, pad_size - pad_size // 2
# save as png
def visualize_slices(volume, title, filename):
if volume.dim() == 5:
mid_slice = volume[0, 0, volume.size(2) // 2, :, :].cpu().numpy() # (batch, channels, depth, height, width)
elif volume.dim() == 4:
        mid_slice = volume[0, volume.size(1) // 2, :, :].cpu().numpy()  # (channels, depth, height, width)
else:
raise ValueError(f"Unsupported tensor dimension: {volume.dim()}")
plt.figure(figsize=(10, 5))
plt.imshow(mid_slice, cmap='gray')
plt.title(title)
plt.axis('off')
plt.savefig(filename)
print(f"Image saved to {filename}")
ckpt_path_motion = 'checkpoints/feta3d0_4_svr_flow_SNet3d0_1024_l22_loss_affine_invariant_bulktrans0_bulkrot180_trans10_rot20_250k/last.ckpt'
ckpt_path_interp = 'checkpoints/feta3d_4_inpaint_unet3d_320_l2_loss_250k_192/last.ckpt'
torch.set_grad_enabled(False)
seed_everything(2, workers=True)
# motion estimate model
trainee = models.segment(model=models.flow_SNet3d0_1024())
trainee.load_state_dict(torch.load(ckpt_path_motion)['state_dict'])
motion_model = trainee.model.cuda()
# interpolate model
trainee = models.segment(model=models.unet3d_320(1,1))
trainee.load_state_dict(torch.load(ckpt_path_interp)['state_dict'])
interp_model = trainee.model.cuda()
val_set = meng3d_svr(spacing=5, subsample=1)[2]
l2loss_2 = 0
# for imgnum in range(96):
with torch.no_grad():
# get data from validation dataset
# [1,18,261,274]
volume, labels, _ = val_set.__getitem__(1, gpu=False)
# padding
v_depth = get_padding(volume.size(1), 4)
v_height = get_padding(volume.size(2), 4)
v_width = get_padding(volume.size(3), 4)
l_depth = get_padding(labels.size(1), 4)
l_height = get_padding(labels.size(2), 4)
l_width = get_padding(labels.size(3), 4)
    padding_v = (v_width[0], v_width[1], v_height[0], v_height[1], v_depth[0], v_depth[1])  # padding starts from the last dimension
padding_l = (l_width[0], l_width[1], l_height[0], l_height[1], l_depth[0], l_depth[1])
volume = F.pad(volume, padding_v)
labels = F.pad(labels, padding_l)
# transformations
slices, target = val_set.transforms(volume.cuda(), labels.cuda(), cpu=False, gpu=True)
# padding
pad_depth = get_padding(slices.size(1), 4)
pad_height = get_padding(slices.size(2), 4)
pad_width = get_padding(slices.size(3), 4)
target_depth = get_padding(target.size(1), 4)
target_height = get_padding(target.size(2), 4)
target_width = get_padding(target.size(3), 4)
padding_slice = (pad_width[0], pad_width[1], pad_height[0], pad_height[1], pad_depth[0], pad_depth[1])
padding_target = (target_width[0], target_width[1], target_height[0], target_height[1], target_depth[0], target_depth[1])
slices = F.pad(slices, padding_slice)
target = F.pad(target, padding_target)
# For efficient inference, we will subsample our slices and target by
# a factor of 2 and crop the volumes tighter to the foreground region
# slices_2, target_2 = BoundingBox3d(2)(slices[None,:,::2,::2,::2], 0.5 * target[None,:,::2,::2,::2], target[-1,::2,::2,::2])
# we don't need downsample for our dataset (?)
slices_2, target_2 = BoundingBox3d(2)(slices[None, :, :, :, :], 0.5 * target[None, :, :, :, :],
target[-1, :, :, :])
motion_2 = motion_model(slices_2)
# This is an optional step to factor out global rigid motion. You
# can alternatively align the splatted result to an atlas after
# interpolation.
motion_2 = motion_model.compensate(motion_2, target_2)
loss_2 = models.losses.l21_loss_affine_invariant(motion_2, target_2, eps=0).item()
# print('Motion error for image %2d: %f voxels' % (imgnum, loss_2))
print('Motion error for image %2d: %f voxels' % (1, loss_2))
l2loss_2 += loss_2
# Upsample the motion to the resolution of the original slice
# stack data and splat the slice data at the original resolution
motion = motion_model.upsample_flow(motion_2)
splat = motion_model.unet3.splat(slices[None,:1], motion.flip(1), mask=slices[None,1:])
splat = splat[:,:-1] / (splat[:,-1:] + 1e-12 * splat[:,-1:].max().item()) # normalize
truth = motion_model.unet3.splat(slices[None,:1], target[None][:,:3].flip(1), mask=slices[None,1:])
truth = truth[:,:-1] / (truth[:,-1:] + 1e-12 * truth[:,-1:].max().item()) # normalize
splat_inpaint = interp_model(splat)
truth_inpaint = interp_model(truth)
# visualize slices, splat_inpaint, truth_inpaint (insert your own code here)
visualize_slices(slices, "Input Slices", "input_slices.png")
visualize_slices(splat_inpaint, "Splat Inpaint", "splat_inpaint.png")
visualize_slices(truth_inpaint, "Truth Inpaint", "truth_inpaint.png")
affine = np.eye(4)
# just testing one stack, use a hardcoded affine.
t_affine = np.array(
[[5, 0, 0, 0],
[0, 1.25, 0, 0],
[0, 0, 1.25, 0],
[0, 0, 0, 1]])
ori_slice = volume.cpu()
    ori_slices = ori_slice.squeeze(0)
ori_array = ori_slices.numpy()
ori_nii = nib.Nifti1Image(ori_array, affine=t_affine)
nib.save(ori_nii, '/opt/data/private/svr_mt/results/'+os.path.basename('ori_padded_14_spacing5_taf'))
slice_input = slices.cpu()
slices_input = slice_input.squeeze(0)
slice_array0 = slices_input.numpy()
nii_img = nib.Nifti1Image(slice_array0[0], affine=t_affine)
nib.save(nii_img, '/opt/data/private/svr_mt/results/'+os.path.basename('motion_padded_14_spacing5_taf'))
slice_in = splat_inpaint.cpu()
slices_3 = slice_in.squeeze(0).squeeze(0)
slice_array = slices_3.numpy()
nii_img2 = nib.Nifti1Image(slice_array, affine=t_affine)
nib.save(nii_img2, '/opt/data/private/svr_mt/results/'+os.path.basename('splated_padded_14_spacing5_taf'))
slice_truth = truth_inpaint.cpu()
slices_4 = slice_truth.squeeze(0).squeeze(0)
slice_array2 = slices_4.numpy()
nii_img3 = nib.Nifti1Image(slice_array2, affine=t_affine)
nib.save(nii_img3, '/opt/data/private/svr_mt/results/'+os.path.basename('truth_padded_14_spacing5_taf'))
# nishow([slices[0].cpu(), splat[0,0].cpu(), splat_inpaint[0,0].cpu(), truth[0,0].cpu(), truth_inpaint[0,0].cpu()])
# print('Motion error average is %f voxels' % (l2loss_2 / 96))
print('Motion error average is %f voxels' % (l2loss_2))
And here's my code for the customized dataset:
class MengTN(VisionDataset):
def __init__(
self,
root: str = '../data/lmt_fetal_nii_1_orishape',
image_set: str = 'test',
split: str ='',
stride: int = 1,
out_shape: list = [340,333,18], # todo
numinput: int = 1,
numclass: int = 1,
multiply: int = 1,
cut: int = 0,
weights = 1,
transforms: Optional[Callable] = None,
**kwargs
):
super().__init__(root, transforms)
image_sets = ['train', 'test', 'val']
self.stride = stride
self.multiply = multiply
self.image_set = image_set
self.image_file = '%s_%s.nii.gz'
self.label_file = '%s_%s_mask.nii.gz'
        self.numinput = numinput
        self.numclass = numclass
        # store these so __outshape__() and __weights__() below do not raise AttributeError
        self.out_shape = out_shape
        self.weights = weights
with open(os.path.join('./datasets/lmt_fetal', image_set), 'r') as f:
path_names = [p.strip() for p in f.readlines()]
# path_names = [path_names[i] for i in self.split] if isinstance(split, list) else path_names
self.images = [os.path.join(self.root, self.image_file % (p.split('_')[0], p.split('_')[1])) for p in path_names]
self.labels = [os.path.join(self.root, self.label_file % (p.split('_')[0], p.split('_')[1])) for p in path_names]
def __getitem__(self, index: int, cpu=True, gpu=False) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is the image segmentation.
"""
# make sure index is in range with %
image = self.images[self.stride*index % len(self.images)]
label = self.labels[self.stride*index % len(self.labels)]
# label = label if os.path.isfile(label) else self.labels[self.stride*index]
img = np.asarray(nib.load(image).dataobj, dtype=np.float32)[None] #fs.Volume.read(image).data[None]
img = img.transpose(0,3,1,2)
target = np.asarray(nib.load(label).dataobj, dtype=np.int8)[None] #fs.Volume.read(label).data[None]
target = target.transpose(0,3,1,2)
if self.transforms is not None:
img, target = self.transforms(img, target, cpu=cpu, gpu=gpu)
return img, target, index
def __len__(self) -> int:
return len(self.images) * self.multiply
def __outshape__(self) -> list:
return self.out_shape
def __numinput__(self) -> int:
return self.numinput
def __weights__(self):
return self.weights
def __numclass__(self) -> int:
return self.numclass
def meng3d_svr(root='/opt/data/private/svr-master/data/lmt_fetal_nii_1_orishape', slice=1, spacing=2, subsample=2, cut=0, **kwargs):
trainformer = transforms.Compose([transforms.ToTensor3d(), transforms.ScaleZeroOne(), transforms.RandAffine3dSlice(spacing=spacing, zooms=(-0.1,0.1), subsample=subsample, slice=slice)], gpuindex=1)
transformer = transforms.Compose([transforms.ToTensor3d(), transforms.ScaleZeroOne(), transforms.RandAffine3dSlice(spacing=spacing, zooms=(-0.1,0.1), subsample=subsample, slice=slice, augment=False)], gpuindex=1)
testsformer = transforms.Compose([transforms.ToTensor3d(), transforms.ScaleZeroOne()], gpuindex=1)
train = MengTN(root, image_set='train', multiply=5, transforms=trainformer, **kwargs)
valid = MengTN(root, image_set='val', multiply=8, transforms=transformer, **kwargs)
    tests = MengTN(root, transforms=transformer, **kwargs)  # note: testsformer above is currently unused
return train, valid, tests
Here's the result of my inference; it's just so strange. However, the model works fine on your data in the same environment.
Yes, I agree, super strange. Any way I can get a copy of the file splated_padded_14_spacing5_taf.nii to have a look at? (If it's been anonymized, etc.)
Of course! Here's the file and my input. splated_padded_14_spacing5_taf.nii.gz
I had a conference submission last week that took up the bulk of my time; I'm looking into this as we speak. Thanks for your patience!
Just take your time bro. Really appreciate your time!
It can still be a lot better, but at least it looks like a brain. We have a slightly updated model that uses coordinate convolutions; I'll ping you again once the weights for the updated model are available!
Oh, it works! I would say it has reasonable quality, given that the through-plane resolution of the thick-slice input is fairly low and the output resolution looks like a compromise between the thick- and thin-slice resolutions. Looking forward to your updated model! And again, many thanks for your effort on this issue! You're among the most responsible GitHub repository owners!
Greetings.
I've been studying your code for days and found that you add motion in the slicing process mainly via `RandAffine3dSlice`, using the cornucopia package.
However, I can't quite work out the difference between `spacing` and `subsample` in the parameters of `cc.RandomSlicewiseAffineTransform()`. I've debugged the code with your data and found that with `spacing=4` and `subsample=2` (as defined in `feta3d0_4_svr()`), the tensor size changes from [1,256,256,256] to [1,128,128,128] after calling `cc.RandomSlicewiseAffineTransform()`, which I guess is because `subsample=2`. But what role does `spacing` play? I've noticed you mentioned scaling the 2D input slices to a 4:1 slab-thickness-to-voxel-spacing ratio to handle different slice spacings in real data. Is that what `spacing` does here?
Anyway, I'm concerned about this because the slice spacing ratio of my data is not 1:4 but 1:5 (1.0mm:5.0mm). I've tested my data on your model with the pretrained weights posted on the GitHub page, simply downsampling the in-plane resolution to 1.25mm so that it becomes 1.25mm:5.0mm (roughly as in the sketch at the end of this message). Yet the result was unusable: it doesn't even have the shape of a brain, like the SVRnet result in Figure 9 of your paper, but worse. By the way, the model works on your data just fine in the same environment. So I'm starting to think that maybe I need to train a new model targeting a 1:5 slice spacing ratio, and I'm not sure how to change the voxel spacing ratio in training, hence the question above.
Lastly, I wonder if I can directly train the model with my data, which is anisotropic and has a 1:5 slice spacing ratio. I've noticed that during training we downsample the training data, and in inference we downsample the input too, so why not just use the anisotropic data directly?
Sorry for my tons of questions, and many thanks for taking the time to read this.
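For reference, the in-plane downsampling to 1.25 mm mentioned above was roughly of this form (a minimal sketch with `scipy.ndimage.zoom`, assuming the slice axis is the last array axis; file names are placeholders and the exact call I used may differ):

```python
import numpy as np
import nibabel as nib
from scipy.ndimage import zoom

img = nib.load('stack_1mm_5mm.nii.gz')                        # placeholder file name
data = img.get_fdata(dtype=np.float32)                        # in-plane axes first, slices last
resampled = zoom(data, (1 / 1.25, 1 / 1.25, 1.0), order=1)    # 1.0 mm -> 1.25 mm in-plane
new_affine = img.affine @ np.diag([1.25, 1.25, 1.0, 1.0])     # keep the world-space extent
nib.save(nib.Nifti1Image(resampled, new_affine), 'stack_1p25mm_5mm.nii.gz')
```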
Zack.