dianatum closed this issue 1 year ago.
Hi @dianatum,
You're almost there. "torch.nn.functional.grid_sample" expects the grid to have the shape (N, D, H, W, 3) = (N, D, H, W, [d, w, h]). The "generate_grid_unit" function generates the grid in this specific format.
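For reference, here is a minimal NumPy sketch of the grid layout that grid_sample expects. It is based on the PyTorch documentation for 5-D inputs with align_corners=True. `make_identity_grid` is a hypothetical stand-in and is not the repository's "generate_grid_unit". The sketch shows that the last axis stores (x, y, z), where x runs along W and z runs along D. That is the reverse of the tensor's (D, H, W) axis order. So a displacement field whose channels follow the tensor axis order has to have its channels reversed before it can be added to such a grid.

```python
import numpy as np


def make_identity_grid(d: int, h: int, w: int) -> np.ndarray:
    """Hypothetical helper: identity sampling grid of shape (1, D, H, W, 3).

    Follows the layout documented for torch.nn.functional.grid_sample with
    align_corners=True: coordinates are normalized to [-1, 1] (voxel centres
    of the corner voxels) and the last axis holds (x, y, z), where x runs
    along W, y along H and z along D.
    """
    z = np.linspace(-1.0, 1.0, d)  # normalized positions along D
    y = np.linspace(-1.0, 1.0, h)  # normalized positions along H
    x = np.linspace(-1.0, 1.0, w)  # normalized positions along W
    zz, yy, xx = np.meshgrid(z, y, x, indexing="ij")  # each of shape (D, H, W)
    # Stack in (x, y, z) order, i.e. the reverse of the axis order (D, H, W).
    grid = np.stack([xx, yy, zz], axis=-1)
    return grid[np.newaxis]  # add the batch axis N


grid = make_identity_grid(4, 5, 6)
print(grid.shape)  # (1, 4, 5, 6, 3)
# One step along W changes only the first coordinate (x); one step along D only the last (z).
print(grid[0, 0, 0, 1] - grid[0, 0, 0, 0])
print(grid[0, 1, 0, 0] - grid[0, 0, 0, 0])
```

With align_corners=False, the values -1 and 1 refer to the outer edges of the corner voxels instead of their centres, so the linspace endpoints above would need adjusting.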
Similarly, the deformation field F_X_Y has the shape (N, [d, w, h], D, H, W). However, if you want to visualize it, the default convention of ITK-SNAP and other visualization tools is [h, w, d]. Therefore, we swap the channel order of F_X_Y to match that convention.
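The swap itself is a reversal along the channel axis. Below is a minimal NumPy sketch of it. It assumes a field of shape (N, 3, D, H, W) and a viewer that wants the three components in the reverse of the stored order, with channels last. `flow_for_viewer` is a hypothetical helper and does not export a file.

```python
import numpy as np


def flow_for_viewer(flow: np.ndarray) -> np.ndarray:
    """Hypothetical helper: reorder a (N, 3, D, H, W) displacement field.

    Assumes the viewer expects the three displacement components in the
    reverse of the stored order. Returns a channels-last (D, H, W, 3) array
    for the first item in the batch.
    """
    reversed_channels = flow[:, ::-1]  # [c0, c1, c2] -> [c2, c1, c0]
    return np.moveaxis(reversed_channels[0], 0, -1)  # (3, D, H, W) -> (D, H, W, 3)


flow = np.zeros((1, 3, 2, 3, 4))
flow[0, 0] = 1.0  # only the first stored component is non-zero
out = flow_for_viewer(flow)
print(out.shape)     # (2, 3, 4, 3)
print(out[0, 0, 0])  # [0. 0. 1.]
```

Reversing twice restores the original order. The same operation therefore converts in either direction, but it should be applied exactly once between two conventions.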
Thanks a lot for your quick response; this makes sense to me. However, I have one more question. I applied a random affine transformation to see how the visualization behaves when the points are transformed once with the swap and once without it. I used this code (parts originally provided by VoxelMorph):
```python
import numpy as np
import SimpleITK as sitk
from scipy.ndimage import map_coordinates
# Assumption: VoxelMorph's TensorFlow modules; the exact path and function names depend on the VoxelMorph version.
from voxelmorph.tf import layers, utils

# Load the image and reorder the array axes from (z, y, x) to (x, y, z)
img_sitk = sitk.ReadImage("../img.nii.gz")
img = sitk.GetArrayFromImage(img_sitk)
img = np.transpose(img, [2, 1, 0])

# Random affine (identity plus a small perturbation, random translation) and its inverse
aff_3d = np.eye(4)
aff_3d[:3, :3] += np.random.randn(3, 3) * 0.1
aff_3d[:3, 3] = np.random.uniform(-10, 10, (3,))
aff_inv_3d = np.linalg.inv(aff_3d)

aff_sitk_3d = aff_3d[:3, :3]

# One landmark, in the same (x, y, z) order as img
annotations_3d = np.array([[122, 106, 73]])

# Add batch (and channel) axes for the Keras layer
im_keras_3d = img[np.newaxis, ..., np.newaxis]
aff_keras_3d = aff_3d[np.newaxis, :3, :]
annotations_keras_3d = annotations_3d[np.newaxis, ...]

# Warp the image and save it with the original header information
im_warped = layers.SpatialTransformer(add_identity=False)([im_keras_3d, aff_keras_3d])
im_warped = im_warped[0, ..., 0]
im_warped = np.transpose(im_warped, [2, 1, 0])
im_warped_sitk = sitk.GetImageFromArray(im_warped)
im_warped_sitk.CopyInformation(img_sitk)
sitk.WriteImage(im_warped_sitk, "../warped_img.nii.gz")

# Dense displacement field of the inverse affine, shape (X, Y, Z, 3)
field_inv_3d = utils.affine_to_shift(aff_inv_3d, img.shape, shift_center=True)[np.newaxis, ...]
field = field_inv_3d[0, ...].numpy()

# Move channels first, then reverse the channel order
Vol = np.transpose(field, (3, 0, 1, 2))
full_flowVol = np.zeros(Vol.shape)
full_flowVol[0] = Vol[2]
full_flowVol[1] = Vol[1]
full_flowVol[2] = Vol[0]

# "noswap": undo the reversal, so channel k of field is added to coordinate k
fixed_disp_x = map_coordinates(full_flowVol[2], annotations_3d.transpose())
fixed_disp_y = map_coordinates(full_flowVol[1], annotations_3d.transpose())
fixed_disp_z = map_coordinates(full_flowVol[0], annotations_3d.transpose())
lms_fixed_disp = np.array((fixed_disp_x, fixed_disp_y, fixed_disp_z)).transpose()
warped_noswap = annotations_3d + lms_fixed_disp

# "swap": keep the reversed channel order
fixed_disp_x = map_coordinates(full_flowVol[0], annotations_3d.transpose())
fixed_disp_y = map_coordinates(full_flowVol[1], annotations_3d.transpose())
fixed_disp_z = map_coordinates(full_flowVol[2], annotations_3d.transpose())
lms_fixed_disp = np.array((fixed_disp_x, fixed_disp_y, fixed_disp_z)).transpose()
warped_swap = annotations_3d + lms_fixed_disp
```
This leads to the following output:

```
warped_noswap: [[122.738174 95.65549 83.18481 ]]
warped_swap: [[132.18480682 95.65548706 73.73817825]]
```
Resulting visualization in ITK-SNAP (screenshots not preserved): original, warped_noswap, warped_swap.
Here, the "noswap" version appears to be correct. Does that mean that, in this case, the deformation field already has the correct order and does not need to be reversed?
Thank you very much for your help!!
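The behaviour described above can be reproduced without TensorFlow. Below is a minimal NumPy sketch. It assumes the dense field is built the way VoxelMorph's `affine_to_shift` appears to build it with shift_center=True and 'ij' indexing: coordinates are centred on the volume, and channel k holds the displacement along array axis k. Please check the source of your installed version. `affine_to_displacement` is a hypothetical helper, and the affine is a fixed example rather than a random one. Under these assumptions, adding component k to coordinate k (the "noswap" variant) reproduces the affine applied directly to the landmark. The reversed order does not.

```python
import numpy as np


def affine_to_displacement(aff: np.ndarray, shape: tuple) -> np.ndarray:
    """Hypothetical helper: dense displacement field of a 4x4 affine.

    Assumes centred coordinates (shift_center=True) and 'ij' indexing, so
    channel k of the result is the displacement along array axis k.
    Returns an array of shape (*shape, 3).
    """
    axes = [np.arange(n, dtype=float) for n in shape]
    mesh = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1)  # (*shape, 3)
    centred = mesh - (np.array(shape) - 1) / 2.0                 # shift to volume centre
    moved = centred @ aff[:3, :3].T + aff[:3, 3]                 # A x + t
    return moved - centred                                       # u(x) = A x + t - x


aff = np.array([[1.1, 0.0, 0.0, 5.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 0.9, -2.0],
                [0.0, 0.0, 0.0, 1.0]])
shape = (9, 10, 11)
field = affine_to_displacement(aff, shape)

point = np.array([4, 6, 2])     # integer landmark, same axis order as the field
disp = field[tuple(point)]      # (d0, d1, d2); integer point, so no interpolation needed

no_swap = point + disp          # component k added to coordinate k
swap = point + disp[::-1]       # components in reversed order

# Reference: apply the affine directly to the centred landmark.
centre = (np.array(shape) - 1) / 2.0
expected = aff[:3, :3] @ (point - centre) + aff[:3, 3] + centre
print(np.allclose(no_swap, expected))  # True
print(np.allclose(swap, expected))     # False
```

For sub-voxel landmarks, each component would be interpolated separately, for example with `map_coordinates` as in the code above. The pairing of component k with coordinate k stays the same.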
Hi @dianatum,
I am not familiar with Keras and the functions you used in your example. Maybe you should open an issue in VoxelMorph's repository?
I will do this! Thanks a lot for your help!
Hello,
First, thanks for the open-source code! I have a question about transforming images versus transforming points (labels). When transforming an image, you seem to reverse the channel order of the grid in "Functions.py":
I assume this is because "torch.nn.functional.grid_sample" in the Spatial Transformer requires it.
If I understand correctly, you then also reverse the channel order of the deformation field when transforming the labels (points):
Since this step no longer involves the function above, I don't really understand why the reversal is necessary here. Shouldn't the original deformation field already have the correct order with respect to the input image? Could you please help me with this issue?
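The question about images versus points can be illustrated with a minimal NumPy sketch. It assumes a displacement field stored in array-axis order, so that channel k is the displacement along axis k. `warp_nearest` is a hypothetical helper, and nearest-neighbour sampling stands in for the interpolation grid_sample performs. In this setting, the image is warped by pulling values from p + u(p). A landmark at fixed position p therefore corresponds to p + u(p) in the moving image, with the same channel order and no reversal. Under this assumption, a reversal only becomes necessary when the field is reformatted for an API such as grid_sample that expects a different component order.

```python
import numpy as np


def warp_nearest(moving: np.ndarray, field: np.ndarray) -> np.ndarray:
    """Hypothetical helper: backward-warp a 3-D volume with nearest-neighbour sampling.

    `field` has shape (*moving.shape, 3); channel k is the displacement along
    array axis k. Output voxel p takes the value of `moving` at p + field[p],
    rounded and clipped to the volume.
    """
    axes = [np.arange(n) for n in moving.shape]
    grid = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1)
    src = np.rint(grid + field).astype(int)
    src = np.clip(src, 0, np.array(moving.shape) - 1)
    return moving[src[..., 0], src[..., 1], src[..., 2]]


shape = (6, 7, 8)
moving = np.zeros(shape)
moving[4, 2, 5] = 1.0          # a single bright "landmark" voxel

field = np.zeros(shape + (3,))
field[..., 0] = 2.0            # pull from 2 voxels further along axis 0
field[..., 2] = -1.0           # and 1 voxel back along axis 2

warped = warp_nearest(moving, field)
fixed_point = np.argwhere(warped == 1.0)[0]  # where the landmark lands in the warped image
# The same field, in the same channel order, maps that point back to the moving landmark.
print(fixed_point, fixed_point + field[tuple(fixed_point)])  # [2 2 6] [4. 2. 5.]
```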
Thanks a lot!