BailiangJ closed this issue 8 months ago.
antsApplyTransforms uses ITK's ResampleImageFilter. Given the overlap between the MONAI (PyTorch) and ITK communities, they might be able to provide more informed guidance over at the ITK Discourse forum.
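For readers coming from PyTorch: conceptually, ITK-style resampling walks over the output (fixed) grid, maps each voxel index to a physical point, applies the transform to that point, maps the result back to a continuous index in the moving image, and interpolates there. A minimal NumPy sketch of the two index/physical mappings, assuming ITK's convention point = origin + direction · diag(spacing) · index (the helper names are hypothetical, not ANTs or ITK API):

```python
import numpy as np


def index_to_physical(index, origin, spacing, direction):
    """Continuous voxel index -> physical point (ITK convention)."""
    n = len(origin)
    D = np.asarray(direction, dtype=float).reshape(n, n)
    M = D @ np.diag(np.asarray(spacing, dtype=float))
    return np.asarray(origin, dtype=float) + M @ np.asarray(index, dtype=float)


def physical_to_index(point, origin, spacing, direction):
    """Physical point -> continuous voxel index (inverse of the mapping above)."""
    n = len(origin)
    D = np.asarray(direction, dtype=float).reshape(n, n)
    M = D @ np.diag(np.asarray(spacing, dtype=float))
    return np.linalg.solve(M, np.asarray(point, dtype=float) - np.asarray(origin, dtype=float))


# Geometry of the displacement field reported later in this thread
origin = (-80.0, 112.0, 96.0)
spacing = (1.0, 1.0, 1.0)
direction = (1, 0, 0, 0, 0, -1, 0, -1, 0)  # flat, row-major, as GetDirection() returns

p = index_to_physical([10, 20, 30], origin, spacing, direction)   # [-70, 82, 76]
idx = physical_to_index(p, origin, spacing, direction)            # back to [10, 20, 30]
```

A resampler then evaluates, for every fixed index, physical_to_index(T(index_to_physical(i, fixed geometry)), moving geometry) and interpolates the moving image at that continuous index.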
@ntustison Thank you for the answer.
I have figured out the solution. I will later attach a link to a Jupyter notebook with the solution in case someone else has the same problem.
Can you please share the code and solution, @BailiangJ? I am facing the same issue!
@mahimoksha Yes, sure. I generalized the code from the special case to the general case, but I haven't tested the general version.
from typing import Literal

import numpy as np
import SimpleITK as sitk
import torch

# Special case with isotropic 1mm spacing
def load_flow(flow_path: str, pkg: Literal["demons", "ants", "niftyreg"]):
    disp = sitk.ReadImage(flow_path)
    direction = torch.tensor(disp.GetDirection()).reshape(3, 3)
    # sitk: (x,y,z), numpy: (z,y,x)
    disp_arr = sitk.GetArrayFromImage(disp)
    disp_arr = np.transpose(disp_arr, axes=(3, 2, 1, 0))  # (3, H, W, D)
    disp_tensor = torch.from_numpy(disp_arr).float()
    if pkg == "niftyreg":
        nifty_to_sitk = torch.tensor([-1.0, 0, 0, 0, -1.0, 0, 0, 0, 1.0]).reshape(3, 3)
        # from nifty space to sitk space
        # the x, y axes are mirrored
        disp_tensor = torch.einsum("ij,jhwd->ihwd", nifty_to_sitk, disp_tensor)
    # transform displacement in physical point space to
    # image index space (pytorch)
    # should compute the inverse of the affine matrix,
    # but since spacing is isotropic 1mm,
    # the inverse of the affine is the inverse of the direction matrix;
    # direction matrices are orthonormal, so their inverse is the transpose
    # (for a symmetric direction matrix this is the matrix itself)
    direction_inv = direction.T
    disp_tensor = torch.einsum("ij,jhwd->ihwd", direction_inv, disp_tensor)
    return disp_tensor.unsqueeze(0)

# General case
def load_flow(flow_path: str, pkg: Literal["demons", "ants", "niftyreg"]):
    # displacement fields of ANTs and SimpleITK are both in Physical Point coordinates
    disp = sitk.ReadImage(flow_path)
    direction = torch.tensor(disp.GetDirection()).reshape(3, 3)
    spacing = torch.diag(torch.tensor(disp.GetSpacing()))
    # the computed affine matrix excludes the origin:
    # we are transforming displacement vectors (differences of points) from
    # Physical Point coordinates to Image Index coordinates, so the origin cancels out
    affine = torch.matmul(direction, spacing)
    # affine maps Image Index coordinates to Physical Point coordinates, so we need its inverse
    affine_inv = torch.linalg.inv(affine)
    # sitk: (x,y,z) -> numpy: (z,y,x)
    disp_arr = sitk.GetArrayFromImage(disp)
    disp_arr = np.transpose(disp_arr, axes=(3, 2, 1, 0))  # (3, H, W, D)
    disp_tensor = torch.from_numpy(disp_arr).float()
    if pkg == "niftyreg":
        nifty_to_sitk = torch.tensor([-1.0, 0, 0, 0, -1.0, 0, 0, 0, 1.0]).reshape(3, 3)
        # from nifty space to sitk space
        # the x, y axes are mirrored
        disp_tensor = torch.einsum("ij,jhwd->ihwd", nifty_to_sitk, disp_tensor)
    # Physical Point space displacement to Image Index space displacement
    disp_tensor = torch.einsum("ij,jhwd->ihwd", affine_inv, disp_tensor)
    return disp_tensor.unsqueeze(0)
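The core of the general case is a per-voxel multiplication by inv(direction · diag(spacing)). A NumPy-only sketch of that step on a toy field (the function name, spacing and field values are made up for illustration); the origin never appears because a displacement is a difference of two points:

```python
import numpy as np


def physical_to_index_disp(disp_phys, spacing, direction):
    """Convert a channels-first displacement field (3, X, Y, Z) from physical
    units (mm, world axes) to voxel-index units along the image axes."""
    D = np.asarray(direction, dtype=float).reshape(3, 3)
    affine_inv = np.linalg.inv(D @ np.diag(np.asarray(spacing, dtype=float)))
    # apply the same 3x3 matrix to the vector stored at every voxel
    return np.einsum("ij,jxyz->ixyz", affine_inv, disp_phys)


spacing = (1.0, 1.0, 2.0)                    # hypothetical anisotropic spacing
direction = (1, 0, 0, 0, 0, -1, 0, -1, 0)    # direction reported in this thread
disp_phys = np.zeros((3, 4, 5, 6))
disp_phys[1] = 2.0                           # every voxel moves 2 mm along physical y

disp_idx = physical_to_index_disp(disp_phys, spacing, direction)
# each voxel now holds [0, 0, -1]: one voxel (2 mm) backwards along the third index axis
```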
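disp_tensor = load_flow(flow_path, pkg="ants")  # flow_path: path to your displacement field
# use the SpatialTransformer class I posted before
resampler = SpatialTransformer(size=disp_tensor.shape[2:], mode='bilinear')
warped_image = resampler(image, disp_tensor)  # image: (1, C, H, W, D) tensor on the same grid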
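For reference, the normalization in SpatialTransformer.forward maps voxel index 0 to -1 and index n - 1 to +1, which is the convention grid_sample uses with align_corners=True; the channels are reversed because grid_sample reads the last grid dimension as (x, y[, z]), with x indexing the last tensor dimension. A NumPy sketch of just that step on a small 2D grid (the helper name is hypothetical):

```python
import numpy as np


def normalize_for_grid_sample(locs, shape):
    """Map voxel-index locations (C, *shape) to grid_sample coordinates in [-1, 1]
    (align_corners=True convention), move channels last, and reverse them."""
    locs = np.asarray(locs, dtype=float).copy()
    for i, n in enumerate(shape):
        locs[i] = 2.0 * (locs[i] / (n - 1) - 0.5)
    locs = np.moveaxis(locs, 0, -1)
    # grid_sample expects (x, y[, z]) where x indexes the LAST tensor dimension
    return locs[..., ::-1]


shape = (3, 5)  # hypothetical 2D image, H=3, W=5
ii, jj = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing="ij")
identity = np.stack([ii, jj])            # (2, H, W) identity grid, like SpatialTransformer.grid
grid = normalize_for_grid_sample(identity, shape)
# grid[0, -1] is [1, -1]: top-right pixel, x = +1 (last dim), y = -1
```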
Hi, how do I get the parameters of torch.nn.functional.affine_grid() from ANTsTransform.parameters? I cannot figure out what ANTsTransform.fixed_parameters is. Is it the rotation center? When I use torch.nn.functional.affine_grid() and grid_sample() to get the warped result, it is not the same as the ANTs output.
Hi @pypi20200320,
From the link, I think the fixed parameters are the rotation center point (I am not 100% sure).
If that is the case, I would suggest computing the displacement field directly by multiplying the matrix with an image grid rather than using affine_grid().
You can refer to airlab and take a look at how they compute a dense flow from an affine transformation.
Hope this will help.
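A minimal NumPy sketch of that suggestion, computing the dense flow by multiplying an index grid with the matrix instead of calling affine_grid(). It assumes the affine is already expressed in voxel-index space and maps fixed indices to moving indices (the direction a backward-warping resampler needs); it illustrates the idea and is not airlab's implementation:

```python
import numpy as np


def affine_to_dense_flow(matrix, shape):
    """Dense displacement field (ndim, *shape) from a homogeneous
    (ndim + 1) x (ndim + 1) affine mapping fixed voxel indices to moving voxel indices."""
    ndim = len(shape)
    grids = np.meshgrid(*[np.arange(s) for s in shape], indexing="ij")
    pts = np.stack(grids).reshape(ndim, -1).astype(float)      # (ndim, N) voxel indices
    moved = matrix[:ndim, :ndim] @ pts + matrix[:ndim, ndim:]  # T(i) for every voxel
    return (moved - pts).reshape(ndim, *shape)                 # displacement u(i) = T(i) - i


# Example: 90-degree rotation about the centre of a 5x5 grid, in voxel-index space
c = np.array([2.0, 2.0])
A = np.array([[0.0, -1.0], [1.0, 0.0]])
M = np.eye(3)
M[:2, :2] = A
M[:2, 2] = c - A @ c          # keeps the centre voxel fixed
flow = affine_to_dense_flow(M, (5, 5))
```

After adding a batch dimension, the result has the (B, ndim, *spatial) layout that the SpatialTransformer in this thread consumes.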
Dear @BailiangJ, thanks so much. I forgot to mention that what I am currently doing is working out the relationship between the ANTs and PyTorch transformation parameters for rigid/affine registration.
I have read that information, and it seems that the fixed parameters really are the rotation center in ANTs. So I compute the physical offset obtained when the rotation center is moved to the origin of the fixed image in ANTs, convert that physical offset to a pixel offset, and then compute the offset in PyTorch, where the rotation center is the center of the image.
But I still cannot get the right result: my outputs "ants_warped.mha" and "torch_warped.mha" are not the same.
Could you help me check what the problem is? Here is my code:
import ants
import numpy as np
import torch
import torch.nn.functional as F

base_path = '/data/result/test_data/'
name = '0000338846'
fixed = ants.image_read(base_path + f"{name}/{name}_fixed.mha")
moving = ants.image_read(base_path + f"{name}/{name}_moving.mha")
tx = ants.registration(fixed, moving, syn_metric='CC',
                       syn_sampling=3,  # radius
                       type_of_transform='Rigid',
                       grad_step=0.25,
                       flow_sigma=0.9,
                       total_sigma=0.2,
                       reg_iterations=(210, 210, 210),
                       verbose=False)
warped = tx['warpedmovout']
ants.image_write(warped, base_path + f"{name}/{name}_ants_warped.mha")
matrix = ants.read_transform(tx['fwdtransforms'][0])
center = matrix.fixed_parameters
matrix = matrix.parameters.reshape(4, 3)
rotate = matrix[:3, :3]
b_offset = matrix[3, :]
phy_offset = b_offset + center - rotate @ center
moving_data = moving.numpy()
moving_data = np.expand_dims(np.expand_dims(moving_data, 0), 0)
moving_data = torch.tensor(moving_data, dtype=torch.float)
## phy_offset to pixel_offset
direction = fixed.direction.reshape(3, 3)
spacing = np.array(fixed.spacing)
origin = np.array(fixed.origin)
img_size = np.array(fixed.numpy().shape)
print(origin, spacing, direction, img_size)
pixel_offset = phy_offset / spacing
print('pixel_offset:', pixel_offset)
## rotate center is the image_center in pytorch
torch_center = 0.5 * img_size
torch_offset = pixel_offset + torch_center - rotate @ torch_center
print('torch offset:', torch_offset)
torch_matrix = np.eye(4)
torch_matrix[:3, :3] = rotate
torch_matrix[:3, 3] = torch_offset
print('torch_matrix:', torch_matrix)
# offset normalize to [-1, 1]
norm_m = np.eye(4)
norm_m[0, 0] = 2 / img_size[0]
norm_m[1, 1] = 2 / img_size[1]
norm_m[2, 2] = 2 / img_size[2]
norm_m[:3, 3] = -1
inv_norm = np.linalg.inv(norm_m)
torch_matrix = norm_m @ torch_matrix @ inv_norm
print('torch_matrix:', torch_matrix)
torch_matrix = torch.tensor(torch_matrix[:3, :], dtype=torch.float).unsqueeze(0)
# Interpolation
grid = F.affine_grid(torch_matrix,
                     moving_data.size(),
                     align_corners=False)
x_aligned = F.grid_sample(moving_data,
                          grid=grid,
                          mode='bilinear',
                          padding_mode='border',
                          align_corners=False)
aligned = ants.from_numpy(x_aligned.detach().squeeze().numpy().astype(np.float32), has_components=False)
ants.set_origin(aligned, list(origin))
ants.set_spacing(aligned, list(spacing))
ants.image_write(aligned, base_path + f"{name}/{name}_torch_warped.mha")
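For reference, the offset formula in the code above (phy_offset = b_offset + center - rotate @ center) matches ITK's affine convention T(x) = A(x - c) + t + c, where A and t come from the 12 parameters (9 row-major matrix entries, then 3 translations) and c is the center stored in fixed_parameters. A NumPy sketch that packs this into a homogeneous matrix (hypothetical helper name; the matrix lives in physical space and maps fixed points to moving points):

```python
import numpy as np


def itk_affine_to_matrix(parameters, fixed_parameters):
    """4x4 homogeneous matrix (physical space, fixed -> moving) from ITK/ANTs
    affine parameters: 9 row-major matrix entries, then 3 translations;
    fixed_parameters holds the center of rotation c."""
    params = np.asarray(parameters, dtype=float)
    A = params[:9].reshape(3, 3)
    t = params[9:12]
    c = np.asarray(fixed_parameters, dtype=float)
    offset = t + c - A @ c        # A(x - c) + t + c == A x + offset
    M = np.eye(4)
    M[:3, :3] = A
    M[:3, 3] = offset
    return M


# Example: 90-degree rotation about z around the center (10, 0, 0), no translation
params = [0, -1, 0,
          1,  0, 0,
          0,  0, 1,
          0,  0, 0]
M = itk_affine_to_matrix(params, [10.0, 0.0, 0.0])
# the center maps to itself, and (11, 0, 0) maps to (10, 1, 0)
```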
Hi @pypi20200320,
I have never used affine_grid and I don't know its exact behaviour, e.g., whether the rotation center is fixed to the center of the image, or whether the translation passed to affine_grid should be normalized to [-1, 1].
I can see that you are trying to turn the affine transformation given by ANTs, which is defined in physical space, into an affine transformation that is compatible with PyTorch, which is defined in image voxel space.
If the direction of the images is not the identity matrix, I think that takes more work than just converting the offset from mm (physical point) to voxel units (image/voxel index), because the direction is part of the transformation between physical space and image space.
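To make the point about direction concrete: dividing a physical offset by the spacing only yields voxel units when the direction is the identity. In general the offset has to be multiplied by the inverse of direction · diag(spacing). A small NumPy check with a made-up oblique direction (30 degrees about z), spacing, and offset:

```python
import numpy as np

theta = np.deg2rad(30.0)
direction = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                      [np.sin(theta),  np.cos(theta), 0.0],
                      [0.0,            0.0,           1.0]])
spacing = np.array([0.8, 0.8, 2.0])
phy_offset = np.array([4.0, 1.6, 0.0])  # mm

naive = phy_offset / spacing  # ignores direction: [5, 2, 0]
correct = np.linalg.inv(direction @ np.diag(spacing)) @ phy_offset
print("naive:", naive, "correct:", correct)
```

Only correct maps back to the original millimetre offset through direction · diag(spacing); the naive version is off as soon as the image axes are rotated relative to the world axes.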
If you just want to get PyTorch to transform the image using the ANTs affine transformation, you could ask ANTs compute the dense displacement field from it and reuse my code.
If you want to compute the corresponding PyTorch affine matrix from the ANTs affine matrix, you can try to solve the following equation and figure out what should be the matrix applied to the image/voxel index.
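Under ITK's conventions, one way to write that relationship is sketched below. Let D, S and o be an image's direction matrix, diagonal spacing matrix and origin (subscript f for fixed, m for moving), and A, t, c the ANTs matrix, translation and center. A fixed voxel index maps to a moving voxel index as follows, which gives the voxel-space matrix and translation:

$$
\mathbf{p} = \mathbf{o}_f + M_f\,\mathbf{i}_f, \qquad M_f = D_f S_f, \qquad M_m = D_m S_m
$$

$$
\mathbf{q} = A(\mathbf{p} - \mathbf{c}) + \mathbf{t} + \mathbf{c}, \qquad \mathbf{i}_m = M_m^{-1}(\mathbf{q} - \mathbf{o}_m)
$$

$$
\mathbf{i}_m = \underbrace{M_m^{-1} A M_f}_{A_{\text{vox}}}\,\mathbf{i}_f + \underbrace{M_m^{-1}\big(A(\mathbf{o}_f - \mathbf{c}) + \mathbf{t} + \mathbf{c} - \mathbf{o}_m\big)}_{\mathbf{t}_{\text{vox}}}
$$

Feeding these to affine_grid() would additionally require mapping voxel indices to the normalized [-1, 1] range and reversing the axis order, since affine_grid treats the last tensor dimension as x.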
@BailiangJ
Hi,
I was trying to use torch.nn.functional.grid_sample to resample the moving image with the displacement field obtained from ants.registration.
However, I can't get the correct resampled result. This might be due to my not fully understanding the arrangement (direction / internal fixed parameters) of the ANTs displacement transform. It would be very nice if you could provide some insights. Thanks a lot!
Here is the code:
import ants
import torch
import torch.nn as nn
import torch.nn.functional as nnf

# pytorch resampler class
class SpatialTransformer(nn.Module):
    """
    N-D Spatial Transformer
    """

    def __init__(self, size, mode='bilinear'):
        super().__init__()

        self.mode = mode

        # create sampling grid
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors, indexing='ij')
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0)
        grid = grid.type(torch.FloatTensor)

        # registering the grid as a buffer cleanly moves it to the GPU, but it also
        # adds it to the state dict. this is annoying since everything in the state dict
        # is included when saving weights to disk, so the model files are way bigger
        # than they need to be. so far, there does not appear to be an elegant solution.
        # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict
        self.register_buffer('grid', grid)

    def forward(self, src, flow):
        # new locations
        new_locs = self.grid + flow
        shape = flow.shape[2:]

        # need to normalize grid values to [-1, 1] for resampler
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)

        # move channels dim to last position
        # the channels need to be reversed because grid_sample reads the last dim
        # as (x, y[, z]), where x indexes the last tensor dimension
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        return nnf.grid_sample(src, new_locs, align_corners=True, mode=self.mode)

# ants registration
fixed_image = ants.image_read(fixed_path)
moving_image = ants.image_read(moving_path)
tx = ants.registration(fixed=fixed_image, moving=moving_image,
                       syn_metric='CC',
                       syn_sampling=3,  # radius
                       type_of_transform='SyNOnly',
                       grad_step=0.25,
                       flow_sigma=0.9,
                       total_sigma=0.2,
                       reg_iterations=(210, 210, 210),
                       verbose=True)

# save the displacement field
disp = ants.image_read(tx['fwdtransforms'][0])
offset = ants.read_transform(tx['fwdtransforms'][1])
offset = offset.parameters.reshape(4, 3)[-1, :]
print(offset)
disp_arr = disp.numpy()
disp_arr[..., [0, 1, 2]] += offset
disp = ants.from_numpy(disp_arr, origin=disp.origin, spacing=disp.spacing,
                       direction=disp.direction, has_components=disp.has_components,
                       is_rgb=disp.is_rgb)
ants.image_write(disp, "ants_flow.nii.gz")

# resample with pytorch
resampler = SpatialTransformer(size=(160, 192, 224), mode='bilinear')
flow = ants.image_read("ants_flow.nii.gz")
flow_arr = flow.numpy()
warped_moving_oh = resampler(moving_oh, flow_tensor.unsqueeze(0))
compute_dice(fixed_oh, warped_moving_oh).mean()
I understand that the displacement field produced by ants.registration is in physical point space, while the PyTorch resampler works in image index space. But since the spacing is (1.0, 1.0, 1.0), they are the same from my point of view.
disp.origin: (-80.0, 112.0, 96.0)
disp.direction: array([[ 1.,  0.,  0.],
                       [ 0., -0., -1.],
                       [ 0., -1.,  0.]])
disp.spacing: (1.0, 1.0, 1.0)
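With 1 mm isotropic spacing only the scaling drops out; this direction matrix is still not the identity. It swaps the second and third axes and flips their sign, so a displacement expressed along the physical axes is a different vector along the voxel-index axes. A small NumPy check with a made-up displacement:

```python
import numpy as np

# Direction of the displacement field reported above (spacing is 1 mm)
direction = np.array([[1.0, 0.0, 0.0],
                      [0.0, 0.0, -1.0],
                      [0.0, -1.0, 0.0]])
spacing = np.array([1.0, 1.0, 1.0])

# A hypothetical physical-space displacement: 3 mm along physical y
u_phys = np.array([0.0, 3.0, 0.0])

# Index-space displacement: solve (direction @ diag(spacing)) @ u_idx = u_phys
u_idx = np.linalg.solve(direction @ np.diag(spacing), u_phys)
# u_idx is [0, 0, -3]: the motion lands on the third index axis with flipped sign
```

This particular matrix is symmetric and orthonormal, so it is also its own inverse.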
The initial mean Dice of the data pair is 0.54. By doing
warped_label = ants.apply_transforms(fixed=fixed_label, moving=moving_label, transformlist=tx['fwdtransforms'], interpolator='nearestNeighbor')
I get ~0.76, but loading the flow field and using the PyTorch resampler I only get ~0.5.
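compute_dice is not shown in the post. For anyone reproducing the comparison, a per-label Dice on one-hot arrays could look like this (a hypothetical NumPy helper implementing Dice = 2|A ∩ B| / (|A| + |B|)):

```python
import numpy as np


def dice_per_label(a_onehot, b_onehot, eps=1e-8):
    """Per-label Dice for one-hot arrays shaped (C, *spatial)."""
    axes = tuple(range(1, a_onehot.ndim))
    inter = (a_onehot * b_onehot).sum(axis=axes)
    denom = a_onehot.sum(axis=axes) + b_onehot.sum(axis=axes)
    return 2.0 * inter / (denom + eps)


labels_a = np.array([[0, 1], [1, 1]])
labels_b = np.array([[0, 1], [0, 1]])
onehot_a = np.moveaxis(np.eye(2)[labels_a], -1, 0)  # (C, H, W)
onehot_b = np.moveaxis(np.eye(2)[labels_b], -1, 0)
scores = dice_per_label(onehot_a, onehot_b)         # about [0.667, 0.8]
print(scores.mean())
```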
Hi, thanks for your work! My ultimate goal is to get an affine deformation field that works correctly with SpatialTransformer. Although my work focuses on 2D images, it seems that I have run into the same problem.
I have some questions:
You can read the description of the affine transform here. You might want to ask over at the ITK Discourse forum, as someone might already have torch code to do this.
Ok, thank you very much for your help. I wish you a happy life and work!
Hi @anamazingclown ,
This link might be helpful regarding getting the displacement field from ANTs affine transformation. After getting the displacement field, you can modify the load_flow function I posted from 3D to 2D and use it with the SpatialTransformer I posted.
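For the 2D affine case, a sketch of going straight from the affine parameters to an index-space displacement field that a SpatialTransformer-style resampler can consume. Assumptions: ITK's 2D affine parameter layout (4 row-major matrix entries, then 2 translations) with the center taken from the fixed parameters, and fixed and moving images sharing the same origin, spacing and direction (if they differ, the moving image's geometry must be used for the last step). The function name is hypothetical:

```python
import numpy as np


def itk_affine_to_index_flow_2d(parameters, center, shape, origin, spacing, direction):
    """Index-space displacement field (2, *shape) for a 2D ITK-style affine.

    Assumes fixed and moving images share origin, spacing and direction.
    Axis 0 of the grid is ITK's x index, axis 1 is the y index.
    """
    params = np.asarray(parameters, dtype=float)
    A, t = params[:4].reshape(2, 2), params[4:6]
    c = np.asarray(center, dtype=float)
    M = np.asarray(direction, dtype=float).reshape(2, 2) @ np.diag(np.asarray(spacing, dtype=float))

    ii, jj = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing="ij")
    idx = np.stack([ii, jj]).reshape(2, -1).astype(float)      # (2, N) voxel indices
    p = np.asarray(origin, dtype=float)[:, None] + M @ idx      # fixed physical points
    q = A @ (p - c[:, None]) + (t + c)[:, None]                 # T(p) = A(p - c) + t + c
    u_idx = np.linalg.solve(M, q - p)                           # mm -> voxel units
    return u_idx.reshape(2, *shape)


# Example: identity matrix plus a 2 mm shift along x, with 2 mm spacing along x
flow = itk_affine_to_index_flow_2d(
    parameters=[1, 0, 0, 1, 2.0, 0.0],
    center=[0.0, 0.0],
    shape=(4, 3),
    origin=[0.0, 0.0],
    spacing=[2.0, 1.0],
    direction=[1, 0, 0, 1],
)
# flow[0] is 1 everywhere (one voxel along x), flow[1] is 0
```

To use it with a 2D SpatialTransformer, add a batch dimension (e.g. torch.from_numpy(flow).float().unsqueeze(0)), assuming the image tensor is laid out with ITK's x index first, as in the load_flow code in this thread.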
This might also be relevant.
Ok, I will try, thank you very much!!!!!!