seannz / svr

[CVPR2024] Fully convolutional slice-to-volume reconstruction for single-stack MRI
https://seaniyoung.com/publications.html
MIT License

Question about the parameters of RandAffine3dSlice and more. #7

Closed Noob-PeNguiN closed 3 weeks ago

Noob-PeNguiN commented 2 months ago

Greetings.

I've been studying your code for days and found that you add motion to the slicing process mainly via RandAffine3dSlice, which uses the cornucopia package.

However, I can't quite work out the difference between spacing and subsample in the parameters of cc.RandomSlicewiseAffineTransform(). I've debugged the code with your data and found that with spacing=4 and subsample=2 (as defined in feta3d0_4_svr()), the tensor size changes from [1,256,256,256] to [1,128,128,128] after calling cc.RandomSlicewiseAffineTransform(), which I guess is because subsample=2. But what role does spacing play? I've noticed you mention scaling the 2D input slices to a 4:1 slab-thickness-to-voxel-spacing ratio to handle different slice spacings in real data. Is spacing here doing exactly that?

Anyway, I'm concerned about this because the slice spacing ratio of my data is not 1:4 but 1:5 (1.0mm:5.0mm). I've tested my data on your model with the pretrained weights posted on the GitHub page by simply downsampling the in-slice resolution to 1.25mm, so that the ratio becomes 1.25mm:5.0mm. Yet the result was unusable (it doesn't even have the shape of a brain, like the SVRnet result in Figure 9 of your paper but worse). By the way, the model works just fine on your data in the same environment. So I've started to think that I may need to train a new model targeting a 1:5 slice spacing ratio, and I'm unsure how to change the voxel spacing ratio in training, hence the question above.
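
For reference, my resampling step was roughly the following (just a sketch; the filenames, axis order, and exact spacings here are placeholders for my actual data):

import nibabel as nib
import numpy as np
from scipy.ndimage import zoom

img = nib.load('stack_1mm_inplane_5mm_through.nii.gz')   # placeholder filename
data = img.get_fdata(dtype=np.float32)

# assume axes are (x, y, z) with z as the slice (through-plane) direction;
# shrink the in-plane axes from 1.0 mm to 1.25 mm voxels (factor 0.8)
data_lr = zoom(data, (1.0 / 1.25, 1.0 / 1.25, 1.0), order=1)

# scale the in-plane columns of the affine so the header matches 1.25 mm
new_affine = img.affine.copy()
new_affine[:3, 0] *= 1.25
new_affine[:3, 1] *= 1.25
nib.save(nib.Nifti1Image(data_lr, new_affine), 'stack_1p25mm_inplane.nii.gz')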

Lastly, I wonder if I can directly train the model with my data, which is anisotropic and has a 1:5 slice spacing ratio. I've noticed that we downsample the training data during training and also downsample the input at inference, so why not just use the anisotropic data directly?

Sorry for the pile of questions, and many thanks for taking the time to read this.

Zack.

seannz commented 2 months ago

Hello Zack,

Thanks for your interest in our work.

As you will have noticed, cc.RandomSlicewiseAffineTransform() is used in the synthesis of simulated slice stacks, and spacing specifies the spacing of the slices in the MRI acquisition forward model. In a true MRI forward model, there would also be an acquisition PSF, but we assume a box-car pre-filter is used -- unweighted integration across the width of each slice.

While we could have worked directly with a stack of slices (each of which is 1 voxel thin), we found it useful to represent each slice as an MR "slab" of thickness spacing voxels, so cc.RandomSlicewiseAffineTransform() essentially samples a stack of slices every spacing voxels and also replicates these slices spacing times, in a back-projection-esque manner. The subsample factor simply subsamples all voxel data by the specified factor; it is there because we do not actually need full-resolution slice stacks to predict good underlying motion fields.
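
Roughly, the shape bookkeeping looks like this (an illustration only, with toy tensors; the real transform also applies the random per-slice motion and integrates across each slab width rather than point-sampling):

import torch

# toy isotropic training volume, [channels, D, H, W]
volume = torch.zeros(1, 256, 256, 256)
spacing, subsample = 4, 2

# take one slice every `spacing` voxels along the slice axis ...
slices = volume[:, ::spacing]                      # -> [1, 64, 256, 256]

# ... and replicate each slice `spacing` times to form slabs, so the stack
# lives on the same grid as the original volume (back-projection style)
slabs = slices.repeat_interleave(spacing, dim=1)   # -> [1, 256, 256, 256]

# `subsample` then decimates everything by that factor for efficiency,
# which is why you see [1, 256, 256, 256] become [1, 128, 128, 128]
out = slabs[:, ::subsample, ::subsample, ::subsample]
print(out.shape)                                   # torch.Size([1, 128, 128, 128])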

I'm not sure what may be happening with your data, as once you have correctly resampled the data to 1:4 it should technically look like the FeTA training data (you could check this). I would make sure that your slice stack has some padding around the data and that the padding is in multiples of 4 voxels. You could have a look at the inference script again to see how the padding should look.
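
For the padding, something along these lines should do it (a minimal sketch, assuming a [C, D, H, W] tensor with the slice axis in D):

import torch
import torch.nn.functional as F

def pad_to_multiple(x, multiple=4):
    # symmetric zero-padding of the last three (spatial) dims up to multiples of `multiple`
    pads = []
    for size in reversed(x.shape[-3:]):            # F.pad wants last-dim-first order
        extra = (-size) % multiple
        pads += [extra // 2, extra - extra // 2]
    return F.pad(x, pads)

stack = torch.rand(1, 30, 250, 250)
print(pad_to_multiple(stack).shape)                # torch.Size([1, 32, 252, 252])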

Working with anisotropic data directly is a possibility, but I think that would significantly complicate the splatting and slicing logic. As a related example, a typical 4x image super-resolution network takes in a 4x bilinearly or trilinearly upsampled image and tries to recover the missing details, rather than taking in an original-resolution image and spitting out an image with a 4x larger shape. The slice replication is a similar idea.
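
As a rough illustration of that replication idea (this is not the actual splat/slice code in the repo, just the concept):

import torch
import torch.nn.functional as F

# anisotropic stack with a 1:4 through-plane to in-plane spacing ratio
stack = torch.rand(1, 1, 64, 256, 256)             # [B, C, n_slices, H, W]

# replicate each slice 4x along the slice axis so the input lives on the
# same (near-isotropic) grid the network was trained on
slabs = F.interpolate(stack, scale_factor=(4, 1, 1), mode='nearest')
print(slabs.shape)                                 # torch.Size([1, 1, 256, 256, 256])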

Now that I'm back from my travels I'll try to add some pre-processing logic to my code so that things like the spacing ratio are handled automatically. Happy to answer any other questions you may have and also to just chat.

Best,

Noob-PeNguiN commented 1 month ago

Greetings.

Many thanks for your detailed reply and sorry for my late reply!

I've been doing further study on your model with our dataset, and have a few more questions.

To start with, continuing our last discussion: working with isotropic data is great and I love your insight about it, but what if we only have anisotropic data? Is it possible to train the model with it? As I understand the training process, we need to slice the data ourselves, e.g. slicing from a 1:1 spacing ratio down to 1:5, so that we know the ground truth of the motion. However, I wonder if we can still do that with anisotropic data. It seems we can't really run the slicing process in this case, since slicing again would increase the spacing ratio, which to my understanding would ruin the ratio we want as input. For example, if my data is already 1:5 and I still want a 1:5 ratio as input at inference, I shouldn't slice it further, because the training data would then no longer be 1:5 and hence would be unusable for inference on 1:5 data.

Of course, there is plenty of isotropic data out there that we could use for training, but it's very common to have only anisotropic data when working with fetal MRI. So maybe the question is: does using similar data in training and inference matter much for performance? If not, we could simply train the model on another isotropic dataset and then use it on our own data.

Additionally, is registering the data to the same template necessary both in training and in inference? I think it's necessary in training because the model has no loss on global rigid motion. But at inference we only use one stack as input, so do we actually need to remove the global rigid motion with registration at all? Another thing about registration: when registering anisotropic data to the CRL atlas, which is isotropic, should I downsample the CRL data so that its spacing matches our anisotropic input? I've tried both downsampling it and not downsampling it, and neither outcome was satisfying, which confused me.

Anyway, I've tried adding some padding around the slice stack and also changing the spacing parameter to 5 to match our voxel spacing, but the result was still terrible (I'll put it below).

For your information, here's my code for inference. I've customized my dataset, which can be created in the code by calling the function meng3d_svr() (code included below as well). I don't think I really understand the 'add some padding around the data' point you mentioned in your reply.

Many many many thanks for taking the time to read this.

import torch
import torch.nn.functional as F
import models
import models.losses
import cornucopia as cc
import os.path
import numpy as np
import matplotlib.pyplot as plt

from datasets import feta3d0_4_svr
from datasets.mt import meng3d_svr
from datasets.transforms import BoundingBox3d
from pytorch_lightning import seed_everything
import nibabel as nib

def get_padding(size, multiple):
    # symmetric padding needed to round `size` up to the next multiple of `multiple`
    if size % multiple == 0:
        return 0, 0
    else:
        pad_size = multiple - (size % multiple)
        return pad_size // 2, pad_size - pad_size // 2

# save as png
def visualize_slices(volume, title, filename):
    if volume.dim() == 5:
        mid_slice = volume[0, 0, volume.size(2) // 2, :, :].cpu().numpy()  # (batch, channels, depth, height, width)
    elif volume.dim() == 4:
        mid_slice = volume[0, volume.size(1) // 2, :, :].cpu().numpy()  # (batch, depth, height, width)
    else:
        raise ValueError(f"Unsupported tensor dimension: {volume.dim()}")

    plt.figure(figsize=(10, 5))
    plt.imshow(mid_slice, cmap='gray')
    plt.title(title)
    plt.axis('off')
    plt.savefig(filename)
    print(f"Image saved to {filename}")

ckpt_path_motion = 'checkpoints/feta3d0_4_svr_flow_SNet3d0_1024_l22_loss_affine_invariant_bulktrans0_bulkrot180_trans10_rot20_250k/last.ckpt'
ckpt_path_interp = 'checkpoints/feta3d_4_inpaint_unet3d_320_l2_loss_250k_192/last.ckpt'

torch.set_grad_enabled(False)
seed_everything(2, workers=True)

# motion estimate model
trainee = models.segment(model=models.flow_SNet3d0_1024())
trainee.load_state_dict(torch.load(ckpt_path_motion)['state_dict'])
motion_model = trainee.model.cuda()

# interpolate model
trainee = models.segment(model=models.unet3d_320(1,1))
trainee.load_state_dict(torch.load(ckpt_path_interp)['state_dict'])
interp_model = trainee.model.cuda()

val_set = meng3d_svr(spacing=5, subsample=1)[2]

l2loss_2 = 0

# for imgnum in range(96):
with torch.no_grad():
    # get data from validation dataset
    # [1,18,261,274]
    volume, labels, _ = val_set.__getitem__(1, gpu=False) 

    # padding
    v_depth = get_padding(volume.size(1), 4)
    v_height = get_padding(volume.size(2), 4)
    v_width = get_padding(volume.size(3), 4)
    l_depth = get_padding(labels.size(1), 4)
    l_height = get_padding(labels.size(2), 4)
    l_width = get_padding(labels.size(3), 4)
    padding_v = (v_width[0], v_width[1], v_height[0], v_height[1], v_depth[0], v_depth[1]) # padding starts from the last dimension
    padding_l = (l_width[0], l_width[1], l_height[0], l_height[1], l_depth[0], l_depth[1])
    volume = F.pad(volume, padding_v)
    labels = F.pad(labels, padding_l)

    # transformations
    slices, target = val_set.transforms(volume.cuda(), labels.cuda(), cpu=False, gpu=True)

    # padding
    pad_depth = get_padding(slices.size(1), 4)
    pad_height = get_padding(slices.size(2), 4)
    pad_width = get_padding(slices.size(3), 4)
    target_depth = get_padding(target.size(1), 4)
    target_height = get_padding(target.size(2), 4)
    target_width = get_padding(target.size(3), 4)
    padding_slice = (pad_width[0], pad_width[1], pad_height[0], pad_height[1], pad_depth[0], pad_depth[1])
    padding_target = (target_width[0], target_width[1], target_height[0], target_height[1], target_depth[0], target_depth[1])
    slices = F.pad(slices, padding_slice)
    target = F.pad(target, padding_target)

    # For efficient inference, we will subsample our slices and target by
    # a factor of 2 and crop the volumes tighter to the foreground region
    # slices_2, target_2 = BoundingBox3d(2)(slices[None,:,::2,::2,::2], 0.5 * target[None,:,::2,::2,::2], target[-1,::2,::2,::2])

    # we don't need to downsample our dataset (?)
    slices_2, target_2 = BoundingBox3d(2)(slices[None, :, :, :, :], 0.5 * target[None, :, :, :, :],
                                        target[-1, :, :, :]) 
    motion_2 = motion_model(slices_2)

    # This is an optional step to factor out global rigid motion. You
    # can alternatively align the splatted result to an atlas after
    # interpolation.
    motion_2 = motion_model.compensate(motion_2, target_2)
    loss_2 = models.losses.l21_loss_affine_invariant(motion_2, target_2, eps=0).item()
    # print('Motion error for image %2d: %f voxels' % (imgnum, loss_2))
    print('Motion error for image %2d: %f voxels' % (1, loss_2))

    l2loss_2 += loss_2

    # Upsample the motion to the resolution of the original slice
    # stack data and splat the slice data at the original resolution
    motion = motion_model.upsample_flow(motion_2)

    splat = motion_model.unet3.splat(slices[None,:1], motion.flip(1), mask=slices[None,1:])
    splat = splat[:,:-1] / (splat[:,-1:] + 1e-12 * splat[:,-1:].max().item()) # normalize

    truth = motion_model.unet3.splat(slices[None,:1], target[None][:,:3].flip(1), mask=slices[None,1:])
    truth = truth[:,:-1] / (truth[:,-1:] + 1e-12 * truth[:,-1:].max().item())  # normalize

    splat_inpaint = interp_model(splat)
    truth_inpaint = interp_model(truth)

    # visualize slices, splat_inpaint, truth_inpaint (insert your own code here)
    visualize_slices(slices, "Input Slices", "input_slices.png")
    visualize_slices(splat_inpaint, "Splat Inpaint", "splat_inpaint.png")
    visualize_slices(truth_inpaint, "Truth Inpaint", "truth_inpaint.png")

    affine = np.eye(4)
    # just testing one stack, use a hardcoded affine.
    t_affine = np.array(
        [[5,    0,    0, 0],
         [0, 1.25,    0, 0],
         [0,    0, 1.25, 0],
         [0,    0,    0, 1]])
    ori_slice = volume.cpu()
    ori_slices = ori_slice.squeeze(0)
    ori_array = ori_slices.numpy()
    ori_nii = nib.Nifti1Image(ori_array, affine=t_affine)
    nib.save(ori_nii, '/opt/data/private/svr_mt/results/'+os.path.basename('ori_padded_14_spacing5_taf.nii'))  # .nii extension so nibabel can infer the format

    slice_input = slices.cpu()
    slices_input = slice_input.squeeze(0)
    slice_array0 = slices_input.numpy()
    nii_img = nib.Nifti1Image(slice_array0[0], affine=t_affine)
    nib.save(nii_img, '/opt/data/private/svr_mt/results/'+os.path.basename('motion_padded_14_spacing5_taf.nii'))

    slice_in = splat_inpaint.cpu()
    slices_3 = slice_in.squeeze(0).squeeze(0)
    slice_array = slices_3.numpy()
    nii_img2 = nib.Nifti1Image(slice_array, affine=t_affine)
    nib.save(nii_img2, '/opt/data/private/svr_mt/results/'+os.path.basename('splated_padded_14_spacing5_taf.nii'))

    slice_truth = truth_inpaint.cpu()
    slices_4 = slice_truth.squeeze(0).squeeze(0)
    slice_array2 = slices_4.numpy()
    nii_img3 = nib.Nifti1Image(slice_array2, affine=t_affine)
    nib.save(nii_img3, '/opt/data/private/svr_mt/results/'+os.path.basename('truth_padded_14_spacing5_taf.nii'))
    # nishow([slices[0].cpu(), splat[0,0].cpu(), splat_inpaint[0,0].cpu(), truth[0,0].cpu(), truth_inpaint[0,0].cpu()])

# print('Motion error average is %f voxels' % (l2loss_2 / 96))
print('Motion error average is %f voxels' % (l2loss_2))

And here's my code for customized dataset

class MengTN(VisionDataset):
    def __init__(
        self,
        root: str = '../data/lmt_fetal_nii_1_orishape',
        image_set: str = 'test',
        split: str ='',
        stride: int = 1, 
        out_shape: list = [340,333,18], # todo
        numinput: int = 1,
        numclass: int = 1,
        multiply: int = 1,
        cut: int = 0,
        weights = 1,
        transforms: Optional[Callable] = None,
        **kwargs
    ):
        super().__init__(root, transforms)
        image_sets = ['train', 'test', 'val']
        self.stride = stride
        self.multiply = multiply
        self.image_set = image_set
        self.image_file = '%s_%s.nii.gz'
        self.label_file = '%s_%s_mask.nii.gz'
        self.numinput = numinput
        self.numclass = numclass
        self.out_shape = out_shape  # stored so __outshape__() works
        self.weights = weights      # stored so __weights__() works

        with open(os.path.join('./datasets/lmt_fetal', image_set), 'r') as f:
            path_names = [p.strip() for p in f.readlines()]
        # path_names = [path_names[i] for i in self.split] if isinstance(split, list) else path_names
        self.images = [os.path.join(self.root, self.image_file % (p.split('_')[0], p.split('_')[1])) for p in path_names]
        self.labels = [os.path.join(self.root, self.label_file % (p.split('_')[0], p.split('_')[1])) for p in path_names]

    def __getitem__(self, index: int, cpu=True, gpu=False) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is the image segmentation.
        """

        # make sure index is in range with %
        image = self.images[self.stride*index % len(self.images)]
        label = self.labels[self.stride*index % len(self.labels)]
        # label = label if os.path.isfile(label) else self.labels[self.stride*index]

        img = np.asarray(nib.load(image).dataobj, dtype=np.float32)[None] #fs.Volume.read(image).data[None]
        img = img.transpose(0,3,1,2)
        target = np.asarray(nib.load(label).dataobj, dtype=np.int8)[None] #fs.Volume.read(label).data[None]
        target = target.transpose(0,3,1,2)

        if self.transforms is not None:
            img, target = self.transforms(img, target, cpu=cpu, gpu=gpu)

        return img, target, index

    def __len__(self) -> int:
        return len(self.images) * self.multiply 

    def __outshape__(self) -> list:
        return self.out_shape

    def __numinput__(self) -> int:
        return self.numinput  

    def __weights__(self):
        return self.weights

    def __numclass__(self) -> int:
        return self.numclass

def meng3d_svr(root='/opt/data/private/svr-master/data/lmt_fetal_nii_1_orishape', slice=1, spacing=2, subsample=2, cut=0, **kwargs):
    trainformer = transforms.Compose([transforms.ToTensor3d(), transforms.ScaleZeroOne(), transforms.RandAffine3dSlice(spacing=spacing, zooms=(-0.1,0.1), subsample=subsample, slice=slice)], gpuindex=1)
    transformer = transforms.Compose([transforms.ToTensor3d(), transforms.ScaleZeroOne(), transforms.RandAffine3dSlice(spacing=spacing, zooms=(-0.1,0.1), subsample=subsample, slice=slice, augment=False)], gpuindex=1)
    testsformer = transforms.Compose([transforms.ToTensor3d(), transforms.ScaleZeroOne()], gpuindex=1)

    train = MengTN(root, image_set='train', multiply=5, transforms=trainformer, **kwargs)
    valid = MengTN(root, image_set='val',   multiply=8, transforms=transformer, **kwargs)
    tests = MengTN(root, transforms=transformer, **kwargs)

    return train, valid, tests

Here's the result of my inference; it's just so strange (screenshot attached). However, it works fine on your data in the same environment (screenshot attached).

seannz commented 1 month ago

Yes, I agree, super strange. Any way I can get a copy of the file splated_padded_14_spacing5_taf.nii to have a look at (if it's been anonymized, etc.)?

Noob-PeNguiN commented 1 month ago

Of course! Here's the file and my input. splated_padded_14_spacing5_taf.nii.gz

anonymized_stack.nii.gz

anonymized_mask.nii.gz

seannz commented 3 weeks ago

I had a conference submission last week that took up the bulk of my time; I'm looking into this as we speak. Thanks for your patience!

Noob-PeNguiN commented 3 weeks ago

Just take your time, bro. Really appreciate it!

seannz commented 3 weeks ago

(Screenshot 2024-10-10 attached.) It can be a lot better still, but at least it looks like a brain. We have a slightly updated model that uses coordinate convolutions; I'll ping you again once the weights for the updated model are available!
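
In case it's useful while you wait: the general idea behind coordinate convolutions is just to append normalized voxel coordinates as extra input channels before a regular convolution. A minimal sketch of that idea (not the actual layers in the updated model):

import torch
import torch.nn as nn

class CoordConv3d(nn.Module):
    # generic CoordConv sketch: concatenate normalized (z, y, x) coordinate
    # channels to the input, then apply an ordinary 3D convolution
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        self.conv = nn.Conv3d(in_channels + 3, out_channels, kernel_size, padding=padding)

    def forward(self, x):
        b, _, d, h, w = x.shape
        zz, yy, xx = torch.meshgrid(
            torch.linspace(-1, 1, d, device=x.device),
            torch.linspace(-1, 1, h, device=x.device),
            torch.linspace(-1, 1, w, device=x.device),
            indexing='ij')
        coords = torch.stack([zz, yy, xx]).expand(b, -1, -1, -1, -1)
        return self.conv(torch.cat([x, coords], dim=1))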

Noob-PeNguiN commented 3 weeks ago

Oh, it works! I would say it has reasonable quality, since the resolution of the thick slices in the input is fairly low and the output resolution looks like a compromise between the thick- and thin-slice resolutions. Looking forward to your updated model! And again, many thanks for your effort on this issue! You're among the most responsible GitHub repository owners!