wentaozhu opened this issue 5 years ago
Or maybe the issue is the conversion between `reg_resample` and ITK's `ResampleImageFilter`?
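For reference, the SimpleITK/ITK side of that comparison would look roughly like this (a minimal sketch with placeholder file names, assuming the deformation field is already in ITK's physical-space convention):

```python
import SimpleITK as sitk

moving = sitk.ReadImage("floating.nii.gz")
# the field must be a float64 vector image for DisplacementFieldTransform
disp = sitk.ReadImage("deformation.nii.gz", sitk.sitkVectorFloat64)

# note: the transform takes ownership of the field image (disp becomes empty)
tx = sitk.DisplacementFieldTransform(disp)
warped = sitk.Resample(moving, tx, sitk.sitkLinear, 0.0)
```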
Hi @mmodat, are there any updates on this issue?
Hi @wentaozhu, did you manage to figure out a solution? I am facing the same problem.
Thanks a lot.
Hi, I have figured out the solution. I will attach a link to a Jupyter notebook with the solution later, in case someone is interested in the future. Thanks a lot.
```python
import torch
import torch.nn as nn
import torch.nn.functional as nnf


class SpatialTransformer(nn.Module):
    """
    N-D Spatial Transformer
    """

    def __init__(self, size, mode='bilinear'):
        super().__init__()

        self.mode = mode

        # create sampling grid
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors)
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0)
        grid = grid.type(torch.FloatTensor)

        # registering the grid as a buffer cleanly moves it to the GPU, but it also
        # adds it to the state dict. this is annoying since everything in the state dict
        # is included when saving weights to disk, so the model files are way bigger
        # than they need to be. so far, there does not appear to be an elegant solution.
        # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict
        self.register_buffer('grid', grid)

    def forward(self, src, flow):
        # new locations
        new_locs = self.grid + flow
        shape = flow.shape[2:]

        # need to normalize grid values to [-1, 1] for the resampler
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)

        # move the channel dim to the end and reverse the order of the flow
        # components (x,y,z -> z,y,x), because grid_sample expects coordinates
        # where x indexes the *last* tensor dimension
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        return nnf.grid_sample(src, new_locs, align_corners=True, mode=self.mode)
```
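A side note on the state-dict comment in `__init__`: if you can require PyTorch 1.6 or newer, `register_buffer` accepts a `persistent` flag that keeps the buffer out of the state dict while still moving it with the module:

```python
# PyTorch >= 1.6: excluded from state_dict, still moved by .to()/.cuda()
self.register_buffer('grid', grid, persistent=False)
```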
```python
from typing import Literal

import numpy as np
import SimpleITK as sitk
import torch


def load_flow(flow_path: str, pkg: Literal["demons", "ants", "niftyreg"]):
    disp = sitk.ReadImage(flow_path)
    direction = torch.tensor(disp.GetDirection()).reshape(3, 3)

    # sitk indexes as (x, y, z), numpy as (z, y, x)
    disp_arr = sitk.GetArrayFromImage(disp)
    disp_arr = np.transpose(disp_arr, axes=(3, 2, 1, 0))  # (3, X, Y, Z)
    disp_tensor = torch.from_numpy(disp_arr).float()

    if pkg == "niftyreg":
        nifty_to_sitk = torch.tensor([-1.0, 0, 0, 0, -1.0, 0, 0, 0, 1.0]).reshape(3, 3)
        # from NIfTI (RAS) space to sitk (LPS) space: the x and y axes are mirrored
        disp_tensor = torch.einsum("ij,jhwd->ihwd", nifty_to_sitk, disp_tensor)

    # transform the displacement from physical-point space to image-index
    # space (pytorch). strictly this requires the inverse of the affine
    # matrix, but since the spacing here is isotropic 1 mm, the inverse of
    # the affine is the inverse of the direction matrix, and a typical
    # axis-aligned direction matrix (diagonal with +/-1 entries) is its
    # own inverse:
    # direction = torch.linalg.inv(direction)
    disp_tensor = torch.einsum("ij,jhwd->ihwd", direction, disp_tensor)

    return disp_tensor.unsqueeze(0)
```
The version above assumes 1 mm isotropic spacing. Here is the general version, which builds the full affine from direction and spacing and applies its inverse:

```python
def load_flow(flow_path: str, pkg: Literal["demons", "ants", "niftyreg"]):
    # displacement fields of ANTs and SimpleITK are both in physical-point coordinates
    disp = sitk.ReadImage(flow_path)
    direction = torch.tensor(disp.GetDirection()).reshape(3, 3)
    spacing = torch.diag(torch.tensor(disp.GetSpacing()))

    # the affine matrix excludes the origin: we are transforming displacement
    # *vectors* from physical-point coordinates to image-index coordinates,
    # so the translation part cancels out and is not needed
    affine = torch.matmul(direction, spacing)
    # the affine maps image-index coordinates to physical-point coordinates,
    # so we need its inverse
    affine_inv = torch.linalg.inv(affine)

    # sitk indexes as (x, y, z), numpy as (z, y, x)
    disp_arr = sitk.GetArrayFromImage(disp)
    disp_arr = np.transpose(disp_arr, axes=(3, 2, 1, 0))  # (3, X, Y, Z)
    disp_tensor = torch.from_numpy(disp_arr).float()

    if pkg == "niftyreg":
        nifty_to_sitk = torch.tensor([-1.0, 0, 0, 0, -1.0, 0, 0, 0, 1.0]).reshape(3, 3)
        # from NIfTI (RAS) space to sitk (LPS) space: the x and y axes are mirrored
        disp_tensor = torch.einsum("ij,jhwd->ihwd", nifty_to_sitk, disp_tensor)

    # physical-point space displacement to image-index space displacement
    disp_tensor = torch.einsum("ij,jhwd->ihwd", affine_inv, disp_tensor)

    return disp_tensor.unsqueeze(0)
```
```python
# example usage (the path, pkg, and image tensor are placeholders)
disp_tensor = load_flow("deformation.nii.gz", pkg="niftyreg")
resampler = SpatialTransformer(size=image.shape[2:], mode='bilinear')
warped_image = resampler(image, disp_tensor)
```
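A quick sanity check is to compare against the image that `reg_resample` itself produces (a sketch; the path is a placeholder, and the transpose assumes the image tensor was loaded in the same (X, Y, Z) order as the flow):

```python
# placeholder path for reg_resample's own output on the same floating image
ref = sitk.GetArrayFromImage(sitk.ReadImage("warped_by_niftyreg.nii.gz"))  # (Z, Y, X)
out = warped_image.detach().squeeze().numpy().transpose(2, 1, 0)  # (X, Y, Z) -> (Z, Y, X)
print("max abs difference:", np.abs(out - ref).max())
```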
When I use `map_coordinates` with the generated warp deformation field to get the transformed image, I find that the transformed image and the one produced by NiftyReg are different.
Could you please provide the implementation and the main framework of `reg_resample -def`? Could you please provide an example using `map_coordinates` to warp the image?
I have converted the control-point (CPP) field to a deformation field, but I still fail to reproduce the same transformed image.
Thank you!
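(For later readers: warping with `map_coordinates` generally looks like the sketch below. It assumes the displacement field has already been converted to voxel/index space and numpy (z, y, x) axis order, which is exactly the conversion the functions above perform; names are illustrative.)

```python
import numpy as np
from scipy.ndimage import map_coordinates


def warp_with_map_coordinates(image, disp):
    # image: (Z, Y, X) array; disp: (3, Z, Y, X) displacement in voxel units
    grid = np.meshgrid(*[np.arange(s) for s in image.shape], indexing='ij')
    coords = np.stack([g + d for g, d in zip(grid, disp)])
    # order=1 -> trilinear interpolation, matching grid_sample's 'bilinear' mode
    return map_coordinates(image, coords, order=1, mode='nearest')
```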