cwmok / DIRAC

This is the official PyTorch implementation of "Unsupervised Deformable Image Registration with Absent Correspondences in Pre-operative and Post-Recurrence Brain Tumor MRI Scans" (MICCAI 2022), written by Tony C. W. Mok and Albert C. S. Chung.
MIT License

About the transformation of images vs points #3

Closed: dianatum closed this issue 1 year ago.

dianatum commented 1 year ago

Hello,

First, thanks for the open-source code! I have a question about transforming images versus transforming points (labels). When transforming an image, you seem to reverse the grid channel order in "Functions.py":

[Screenshot from 2022-11-23 13-10-17: grid channel reversal in "Functions.py"]

I assume this is because "torch.nn.functional.grid_sample", which is used in the spatial transformer, requires this order.
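The convention behind this reversal can be illustrated without PyTorch. According to the PyTorch documentation, for a 5-D input of shape (N, C, D, H, W), `grid_sample` reads `grid[..., 0]` as the x coordinate (indexing W), `grid[..., 1]` as y (indexing H) and `grid[..., 2]` as z (indexing D), so the channel order is the reverse of the array axis order. The NumPy sketch below builds an identity grid in that layout. It assumes `align_corners=True` normalization, and `unit_grid_for_grid_sample` is a hypothetical helper, not the repository's `generate_grid_unit`.

```python
import numpy as np


def unit_grid_for_grid_sample(shape):
    """Identity sampling grid for a 3-D volume of array shape (D, H, W).

    Hypothetical helper. Returns an array of shape (D, H, W, 3) whose last
    axis is ordered (x, y, z) = (W-axis, H-axis, D-axis), i.e. the reverse of
    the array axis order, with coordinates in [-1, 1] (align_corners=True).
    """
    d, h, w = shape
    zs = np.linspace(-1.0, 1.0, d)  # positions along D (array axis 0)
    ys = np.linspace(-1.0, 1.0, h)  # positions along H (array axis 1)
    xs = np.linspace(-1.0, 1.0, w)  # positions along W (array axis 2)
    # indexing="ij" gives arrays indexed as [d, h, w]
    zz, yy, xx = np.meshgrid(zs, ys, xs, indexing="ij")
    # Stack in reversed order: channel 0 varies along W, channel 2 along D
    return np.stack([xx, yy, zz], axis=-1)


grid = unit_grid_for_grid_sample((4, 5, 6))
print(grid.shape)  # (4, 5, 6, 3)
```

A voxel displacement stored in array-axis order (along D, H, W) therefore has to be normalized and have its channels reversed before it is added to such a grid.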

If I understand correctly, you then also reverse the order of the channels in the deformation field when transforming the labels (points):

[Screenshot from 2022-11-23 13-16-19: channel reversal of the deformation field when transforming the labels]

Since this step no longer involves the function used above, I don't understand why the reversal is necessary here. Shouldn't the original deformation field already have the correct channel order with respect to the input image? Could you please help me with this issue?
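To make the question concrete: when points are warped by sampling a displacement field at their locations, channel k of the field is added to coordinate k of the point, so the field's channel order must match the point's coordinate order, independently of what `grid_sample` expects. Below is a minimal sketch using SciPy's `map_coordinates`. It assumes the field is stored as (3, D, H, W) with channel a holding the displacement (in voxels) along array axis a; `warp_points` is a hypothetical helper, not part of DIRAC.

```python
import numpy as np
from scipy.ndimage import map_coordinates


def warp_points(points, disp):
    """Move points by a dense displacement field (hypothetical helper).

    points: (K, 3) voxel coordinates in array-axis order.
    disp:   (3, D, H, W) displacements in voxels. ASSUMED convention:
            disp[a] is the displacement along array axis a.
    """
    pts = np.asarray(points, dtype=float)
    coords = pts.T  # (3, K), the layout map_coordinates expects
    # Sample each displacement channel at the point locations (linear interpolation)
    offsets = np.stack(
        [map_coordinates(disp[a], coords, order=1) for a in range(3)], axis=1
    )
    return pts + offsets


# Example: a constant field that moves every voxel by +1, +2, +3 voxels
# along array axes 0, 1, 2 respectively
disp = np.zeros((3, 8, 8, 8))
disp[0] = 1.0
disp[1] = 2.0
disp[2] = 3.0
print(warp_points(np.array([[2, 3, 4]]), disp))  # [[3. 5. 7.]]
```

If a field instead stores in channel 0 the displacement along the last array axis (the reversed order), its channels would need to be reversed first, e.g. `warp_points(points, disp[::-1])`.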

Thanks a lot!

cwmok commented 1 year ago

Hi @dianatum,

You're almost there. "torch.nn.functional.grid_sample" expects the grid to have shape (N, D, H, W, 3) = (N, D, H, W, [d, w, h]). The "generate_grid_unit" function generates the grid in this specific format.

Similarly, the deformation field F_X_Y has shape (N, [d, w, h], D, H, W). However, if you want to visualize it, the default convention of ITK-SNAP and other visualization tools is [h, w, d]. Therefore, we swap the channel order of F_X_Y to match this convention.
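Swapping the channel order of a field laid out as (N, 3, D, H, W) amounts to reversing the channel axis. The minimal NumPy sketch below uses a random array as a stand-in for F_X_Y and shows that `np.flip` along axis 1 matches an explicit per-channel assignment.

```python
import numpy as np

# Random stand-in for a field laid out as (N, 3, D, H, W)
rng = np.random.default_rng(0)
F_X_Y = rng.normal(size=(1, 3, 4, 5, 6))

# Reverse the channel axis (axis 1) to switch between the two orderings
F_swapped = np.flip(F_X_Y, axis=1)

# Equivalent explicit assignment: first and last channels exchanged
F_manual = np.empty_like(F_X_Y)
F_manual[:, 0] = F_X_Y[:, 2]
F_manual[:, 1] = F_X_Y[:, 1]
F_manual[:, 2] = F_X_Y[:, 0]
```

This only permutes which displacement component is stored in which channel. It does not transpose the spatial axes, so how a viewer interprets the saved field also depends on how the array axes are mapped to image axes when the file is written.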

dianatum commented 1 year ago

Thanks a lot for your quick response; this makes sense to me. However, I have one more question. I applied a random affine transformation to see how the visualization behaves when the points are transformed with and without the channel swap. I used this code (parts originally provided by VoxelMorph):


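    import numpy as np
    import SimpleITK as sitk
    from scipy.ndimage import map_coordinates
    # `layers` and `utils` are VoxelMorph's TensorFlow modules; the exact
    # import path depends on the installed VoxelMorph version.

    # Load the image. GetArrayFromImage returns axes in (z, y, x) order,
    # so transpose to (x, y, z).
    img_sitk = sitk.ReadImage("../img.nii.gz")
    img = sitk.GetArrayFromImage(img_sitk)
    img = np.transpose(img, [2, 1, 0])

    # Random affine: identity plus a small perturbation, and a random translation
    aff_3d = np.eye(4)
    aff_3d[:3, :3] += np.random.randn(3, 3) * 0.1
    aff_3d[:3, 3] = np.random.uniform(-10, 10, (3,))
    aff_inv_3d = np.linalg.inv(aff_3d)

    aff_sitk_3d = aff_3d[:3, :3]

    # A single landmark, in voxel coordinates of the transposed image
    annotations_3d = np.array([[122, 106, 73]])

    # Add batch (and channel) axes for the Keras layer
    im_keras_3d = img[np.newaxis, ..., np.newaxis]
    aff_keras_3d = aff_3d[np.newaxis, :3, :]
    annotations_keras_3d = annotations_3d[np.newaxis, ...]

    # Warp the image with the affine matrix
    im_warped = layers.SpatialTransformer(add_identity=False)([im_keras_3d, aff_keras_3d])
    im_warped = im_warped[0, ..., 0]

    # Transpose back to (z, y, x) and save with the original image geometry
    im_warped = np.transpose(im_warped, [2, 1, 0])
    im_warped_sitk = sitk.GetImageFromArray(im_warped)
    im_warped_sitk.CopyInformation(img_sitk)
    sitk.WriteImage(im_warped_sitk, "../warped_img.nii.gz")

    # Dense displacement field of the inverse affine (channels last)
    field_inv_3d = utils.affine_to_shift(aff_inv_3d, img.shape, shift_center=True)[np.newaxis, ...]
    field = field_inv_3d[0, ...].numpy()

    # Move channels first, then reverse the channel order
    Vol = np.transpose(field, (3, 0, 1, 2))
    full_flowVol = np.zeros(Vol.shape)
    full_flowVol[0] = Vol[2]
    full_flowVol[1] = Vol[1]
    full_flowVol[2] = Vol[0]

    # "noswap": read the channels so that the reversal above is undone
    fixed_disp_x = map_coordinates(full_flowVol[2], annotations_3d.transpose())
    fixed_disp_y = map_coordinates(full_flowVol[1], annotations_3d.transpose())
    fixed_disp_z = map_coordinates(full_flowVol[0], annotations_3d.transpose())
    lms_fixed_disp = np.array((fixed_disp_x, fixed_disp_y, fixed_disp_z)).transpose()
    warped_noswap = annotations_3d + lms_fixed_disp

    # "swap": read the channels in the reversed order
    fixed_disp_x = map_coordinates(full_flowVol[0], annotations_3d.transpose())
    fixed_disp_y = map_coordinates(full_flowVol[1], annotations_3d.transpose())
    fixed_disp_z = map_coordinates(full_flowVol[2], annotations_3d.transpose())
    lms_fixed_disp = np.array((fixed_disp_x, fixed_disp_y, fixed_disp_z)).transpose()
    warped_swap = annotations_3d + lms_fixed_disp

    print("warped_noswap:", warped_noswap)
    print("warped_swap:", warped_swap)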
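It may help to note that the two branches in the snippet above differ only in whether the earlier channel reversal is undone. The NumPy sketch below uses a random array in place of `Vol` (an assumption for illustration) and reproduces the channel selection of both branches.

```python
import numpy as np

rng = np.random.default_rng(0)
Vol = rng.normal(size=(3, 4, 5, 6))  # random stand-in for the channels-first field

# Channel reversal as written in the snippet
full_flowVol = np.zeros(Vol.shape)
full_flowVol[0] = Vol[2]
full_flowVol[1] = Vol[1]
full_flowVol[2] = Vol[0]

# Channels read by the "noswap" branch: x, y, z <- full_flowVol[2], [1], [0]
noswap_channels = np.stack([full_flowVol[2], full_flowVol[1], full_flowVol[0]])
# Channels read by the "swap" branch: x, y, z <- full_flowVol[0], [1], [2]
swap_channels = np.stack([full_flowVol[0], full_flowVol[1], full_flowVol[2]])
```

So "noswap" samples the field in the channel order returned by `affine_to_shift`, and "swap" samples it reversed. Which one is geometrically correct depends on whether channel k of that field is the displacement along array axis k of `img`; this check alone does not establish that.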


This leads to the following output:

    warped_noswap: [[122.738174 95.65549 83.18481 ]]
    warped_swap: [[132.18480682 95.65548706 73.73817825]]
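Subtracting the original landmark [122, 106, 73] from both printed results shows that the two variants apply the same displacement vector, with its first and third components exchanged:

```python
import numpy as np

original = np.array([122.0, 106.0, 73.0])
warped_noswap = np.array([122.738174, 95.65549, 83.18481])
warped_swap = np.array([132.18480682, 95.65548706, 73.73817825])

# Displacement applied to the landmark in each variant
d_noswap = warped_noswap - original  # approx [0.738, -10.345, 10.185]
d_swap = warped_swap - original      # approx [10.185, -10.345, 0.738]
```

The difference between the two results is therefore purely which axis receives which displacement component, consistent with the channel reversal being the only change between the two branches.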

In ITK-SNAP, this gives the following visualization:

original: [Screenshot from 2022-11-23 16-00-35]

warped_noswap: [Screenshot from 2022-11-23 16-01-14]

warped_swap: [Screenshot from 2022-11-23 16-01-42]

Here, the "noswap" version appears to be correct. Does that mean that, in this case, the deformation field already has the correct channel order and does not need to be reversed?

Thank you very much for your help!!

cwmok commented 1 year ago

Hi @dianatum,

I am not familiar with Keras and the functions you used in your example. Maybe you should open an issue in VoxelMorph's repository?

dianatum commented 1 year ago

I will do this! Thanks a lot for your help!