KCL-BMEIS / niftyreg

This project contains command line tools to perform rigid, affine and non-linear registration of NIfTI or Analyze images, as well as utilities
BSD 3-Clause "New" or "Revised" License

How to use scipy map_coordinates and the generated warp deformation field to get the transformed image? #55

Open wentaozhu opened 5 years ago

wentaozhu commented 5 years ago

When I use map_coordinates with the generated warp deformation field to get the transformed image, I find that the transformed image differs from the one produced by NiftyReg.

Could you please provide the implementation and the main framework of reg_resample -def? Could you please provide an example of using map_coordinates to warp an image?

I have converted the control point (CPP) grid to a deformation field, but I still fail to reproduce the same transformed image.

Thank you!
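For readers hitting the same mismatch, here is a minimal sketch of the map_coordinates route, assuming the deformation field lives on the reference grid and stores absolute world-space (mm) positions (e.g. a dense field obtained from the CPP grid with reg_transform); the file names are placeholders and the snippet is untested:

import nibabel as nib
import numpy as np
from scipy.ndimage import map_coordinates

# hypothetical file names
flo = nib.load("floating.nii.gz")
def_img = nib.load("deformation.nii.gz")   # dense deformation field on the reference grid

flo_data = flo.get_fdata()
def_mm = np.squeeze(def_img.get_fdata())   # (X, Y, Z, 1, 3) -> (X, Y, Z, 3)

# world (mm) coordinates -> floating-image voxel indices
coords = nib.affines.apply_affine(np.linalg.inv(flo.affine), def_mm)

# map_coordinates interpolates at voxel-index positions, given as an (ndim, ...) array
warped = map_coordinates(flo_data, coords.transpose(3, 0, 1, 2), order=1)

The usual source of a mismatch is passing the world-space (mm) positions straight to map_coordinates, which expects voxel indices of the floating image, hence the inverse affine above.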

wentaozhu commented 5 years ago

Or maybe an explanation of the conversion between reg_resample and the ITK ResampleImageFilter?
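For the ITK/SimpleITK route, a rough sketch under the following assumptions: SimpleITK's resampler wants a displacement field (offsets in physical LPS coordinates on the reference grid), not an absolute deformation field, so the NiftyReg field first has to be converted (identity grid subtracted, x and y components negated for the RAS-to-LPS flip, as in the load_flow helper later in this thread). The file names are placeholders and the snippet is untested:

import SimpleITK as sitk

# hypothetical file names; "displacement.nii.gz" is assumed to already hold
# offsets in physical LPS coordinates defined on the reference grid
moving = sitk.ReadImage("floating.nii.gz", sitk.sitkFloat32)
reference = sitk.ReadImage("reference.nii.gz", sitk.sitkFloat32)
disp = sitk.ReadImage("displacement.nii.gz", sitk.sitkVectorFloat64)

# DisplacementFieldTransform requires a sitkVectorFloat64 image
transform = sitk.DisplacementFieldTransform(disp)
warped = sitk.Resample(moving, reference, transform, sitk.sitkLinear, 0.0)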

BailiangJ commented 1 year ago

Hi @mmodat , are there any updates on this issue?

Hi @wentaozhu , did you manage to figure out a solution? I am facing the same problem.

Thanks a lot.

BailiangJ commented 1 year ago

Hi, I have figured out the solution. I will later attach a link to the jupyter notebook with the solution in case someone is interested in the future. Thanks a lot.

BailiangJ commented 1 year ago
import numpy as np
import SimpleITK as sitk
import torch
import torch.nn as nn
import torch.nn.functional as nnf


class SpatialTransformer(nn.Module):
    """
    N-D Spatial Transformer
    """

    def __init__(self, size, mode='bilinear'):
        super().__init__()

        self.mode = mode

        # create sampling grid
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors)
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0)
        grid = grid.type(torch.FloatTensor)

        # registering the grid as a buffer cleanly moves it to the GPU, but it also
        # adds it to the state dict. this is annoying since everything in the state dict
        # is included when saving weights to disk, so the model files are way bigger
        # than they need to be. so far, there does not appear to be an elegant solution.
        # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict
        self.register_buffer('grid', grid)

    def forward(self, src, flow):
        # new locations
        new_locs = self.grid + flow
        shape = flow.shape[2:]

        # need to normalize grid values to [-1, 1] for resampler
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)

        # reverse the order of flow components
        # due to the property of grid_sample
        # x,y,z -> z,y,x
        # flow: [X, Y, Z, [z,y,x]]
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        return nnf.grid_sample(src, new_locs, align_corners=True, mode=self.mode)

from typing import Literal
def load_flow(flow_path:str, pkg:Literal["demons", "ants", "niftyreg"]):
    disp = sitk.ReadImage(flow_path)
    direction = torch.tensor(disp.GetDirection()).reshape(3,3)

    # sitk: (x,y,z), numpy(z,y,x)
    disp_arr = sitk.GetArrayFromImage(disp)
    disp_arr = np.transpose(disp_arr, axes=(3,2,1,0)) #(3,H,W,D)
    disp_tensor = torch.from_numpy(disp_arr).float()

    if pkg == "niftyreg":
        nifty_to_sitk = torch.tensor([-1.0,0,0,0,-1.0,0,0,0,1.0]).reshape(3,3)
        # from NIfTI (RAS) space to SimpleITK (LPS) space:
        # the x and y axes are mirrored
        disp_tensor = torch.einsum("ij,jhwd->ihwd", nifty_to_sitk, disp_tensor)

    # transform the displacement from physical-point space to
    # image-index space (what grid_sample works in)
    # strictly this needs the inverse of the affine matrix, but with
    # isotropic 1 mm spacing the affine reduces to the direction matrix,
    # and the usual direction matrix (a diagonal of +/-1) is its own inverse
    # direction = torch.linalg.inv(direction)
    disp_tensor = torch.einsum("ij,jhwd->ihwd", direction, disp_tensor)

    return disp_tensor.unsqueeze(0)

# refined version of load_flow: it uses the full affine (direction @ spacing),
# so it also handles non-unit voxel spacing
def load_flow(flow_path:str, pkg:Literal["demons", "ants", "niftyreg"]):
    # displacement fields of ANTs and SimpleITK are both in Physical Point coordinate
    disp = sitk.ReadImage(flow_path)
    direction = torch.tensor(disp.GetDirection()).reshape(3,3)
    spacing = torch.diag(torch.tensor(disp.GetSpacing()))
    # the computed affine matrix excludes the origin:
    # since we are transforming displacement vectors (not points) from
    # Physical Point coordinates to Image Index coordinates, the origin is not needed
    affine = torch.matmul(direction, spacing)
    # the affine maps Image Index coordinates to Physical Point coordinates, so we need its inverse
    affine_inv = torch.linalg.inv(affine)

    # sitk: (x,y,z) -> numpy:(z,y,x)
    disp_arr = sitk.GetArrayFromImage(disp)
    disp_arr = np.transpose(disp_arr, axes=(3,2,1,0)) #(3,H,W,D)
    disp_tensor = torch.from_numpy(disp_arr).float()

    if pkg == "niftyreg":
        nifty_to_sitk = torch.tensor([-1.0,0,0,0,-1.0,0,0,0,1.0]).reshape(3,3)
        # from NIfTI (RAS) space to SimpleITK (LPS) space:
        # the x and y axes are mirrored
        disp_tensor = torch.einsum("ij,jhwd->ihwd", nifty_to_sitk, disp_tensor)

    # Physical Point space displacement to Image Index space displacement
    disp_tensor = torch.einsum("ij,jhwd->ihwd", affine_inv, disp_tensor)
    return disp_tensor.unsqueeze(0)

# example usage: flow_path, pkg and image are placeholders for your own
# deformation-field file and (1, 1, H, W, D) image tensor
disp_tensor = load_flow(flow_path, pkg="niftyreg")
resampler = SpatialTransformer(size=disp_tensor.shape[2:], mode='bilinear')
warped_image = resampler(image, disp_tensor)