ANTsX / ANTsPy

A fast medical imaging analysis library in Python with algorithms for registration, segmentation, and more.
https://antspyx.readthedocs.io
Apache License 2.0

Reproduce transform.apply_to_image by torch.nn.grid_sample #427

Closed: BailiangJ closed this issue 8 months ago

BailiangJ commented 1 year ago

Hi,

I was trying to use torch.nn.grid_sample to resample the moving image with the displacement field obtained from ants.registration.

However, I can't get the correct resampled result. This might be due to my incomplete understanding of the layout (direction / internal fixed parameters) of the ANTs displacement transform. It would be very nice if you could provide some insight. Thanks a lot!

Here is the code:

# pytorch resampler class
import torch
import torch.nn as nn
import torch.nn.functional as nnf

class SpatialTransformer(nn.Module):
    """
    N-D Spatial Transformer
    """

    def __init__(self, size, mode='bilinear'):
        super().__init__()

        self.mode = mode

        # create sampling grid
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors, indexing='ij')  # explicit 'ij' matches array axis order
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0)
        grid = grid.type(torch.FloatTensor)

        # registering the grid as a buffer cleanly moves it to the GPU, but it also
        # adds it to the state dict. this is annoying since everything in the state dict
        # is included when saving weights to disk, so the model files are way bigger
        # than they need to be. so far, there does not appear to be an elegant solution.
        # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict
        self.register_buffer('grid', grid)

    def forward(self, src, flow):
        # new locations
        new_locs = self.grid + flow
        shape = flow.shape[2:]

        # normalize grid values to [-1, 1] for the resampler (this formula
        # matches align_corners=True, where -1/1 map to corner voxel centres)
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)

        # move channels dim to last position and reverse the channel order:
        # grid_sample expects the last dimension ordered (x, y, z), i.e. the
        # reverse of the array axis order used to build the grid
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        return nnf.grid_sample(src, new_locs, align_corners=True, mode=self.mode)

# ants registration
fixed_image = ants.image_read(fixed_path)
moving_image = ants.image_read(moving_path)
tx = ants.registration(fixed=fixed_image,
                       moving=moving_image,
                       syn_metric='CC',
                       syn_sampling=3,  # radius
                       type_of_transform='SyNOnly',
                       grad_step=0.25,
                       flow_sigma=0.9,
                       total_sigma=0.2,
                       reg_iterations=(210, 210, 210),
                       verbose=True
                       )
# save the displacement field
disp = ants.image_read(tx['fwdtransforms'][0])
offset = ants.read_transform(tx['fwdtransforms'][1])
offset = offset.parameters.reshape(4,3)[-1,:]  # translation row of the affine
print(offset)
disp_arr = disp.numpy()
disp_arr[...,[0,1,2]] += offset  # fold the affine translation into the field
disp = ants.from_numpy(disp_arr,origin=disp.origin,spacing=disp.spacing,direction=disp.direction,has_components=disp.has_components,is_rgb=disp.is_rgb)
ants.image_write(disp, "ants_flow.nii.gz")

# resample with pytorch
resampler = SpatialTransformer(size=(160,192,224), mode='bilinear')
flow = ants.image_read( "ants_flow.nii.gz")
flow_arr = flow.numpy()                                               # (X, Y, Z, 3)
flow_tensor = torch.from_numpy(flow_arr).permute(3, 0, 1, 2).float()  # (3, X, Y, Z)

# moving_oh / fixed_oh are one-hot label tensors (defined elsewhere)
warped_moving_oh = resampler(moving_oh, flow_tensor.unsqueeze(0))
compute_dice(fixed_oh, warped_moving_oh).mean()

I understand that the displacement field produced by ants.registration is in physical-point space, while the PyTorch resampler works in image-index space. But since the spacing is (1.0, 1.0, 1.0), I expected them to be the same.

disp.origin:(-80.0, 112.0, 96.0)
disp.direction: array([[ 1.,  0.,  0.],
                       [ 0., -0., -1.],
                       [ 0., -1.,  0.]])
disp.spacing:(1.0,1.0,1.0)
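
For reference, a minimal sketch of the ITK/ANTs index-to-physical convention with the values above (my own illustration, not code from this thread). With unit spacing the scaling is the identity, but the non-identity direction matrix still rotates displacement vectors, so physical-space and index-space displacements differ here:

import numpy as np

D = np.array([[1.,  0.,  0.],
              [0., -0., -1.],
              [0., -1.,  0.]])     # disp.direction
s = np.array([1., 1., 1.])         # disp.spacing
o = np.array([-80., 112., 96.])    # disp.origin

def index_to_physical(i):
    # ITK convention: p = o + D @ (s * i)
    return o + D @ (s * np.asarray(i, dtype=float))

# displacement vectors transform without the origin: v_phys = D @ (s * v_idx),
# so v_idx = inv(D @ diag(s)) @ v_phys, which differs from v_phys when D != I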

The initial mean Dice of the data pair is 0.54. With warped_label = ants.apply_transforms(fixed=fixed_label, moving=moving_label, transformlist=tx['fwdtransforms'], interpolator='nearestNeighbor') I get ~0.76, but loading the flow field and using the PyTorch resampler I only get ~0.5.

ntustison commented 1 year ago

antsApplyTransforms uses ITK's ResampleImageFilter. Given the overlap with MONAI (PyTorch) and the ITK community, they might be able to provide more informed guidance over at the ITK discourse forum.

BailiangJ commented 1 year ago

@ntustison Thank you for the answer.

I have figured out the solution. I will attach a link to a Jupyter notebook with the solution later, in case someone else has the same problem.

mahimoksha commented 1 year ago

> @ntustison Thank you for the answer.
>
> I have figured out the solution. I will attach a link to a Jupyter notebook with the solution later, in case someone else has the same problem.

Can you please share the code and solution, @BailiangJ? I am facing the same issue!

BailiangJ commented 1 year ago

@mahimoksha Yes, sure. I modified the code from the special case to the general case, but haven't tested it.

from typing import Literal

import numpy as np
import SimpleITK as sitk
import torch

# Special case with isotropic 1mm spacing
def load_flow(flow_path: str, pkg: Literal["demons", "ants", "niftyreg"]):
    disp = sitk.ReadImage(flow_path)
    direction = torch.tensor(disp.GetDirection()).reshape(3,3)

    # sitk: (x,y,z), numpy(z,y,x)
    disp_arr = sitk.GetArrayFromImage(disp)
    disp_arr = np.transpose(disp_arr, axes=(3,2,1,0)) #(3,H,W,D)
    disp_tensor = torch.from_numpy(disp_arr).float()

    if pkg == "niftyreg":
        nifty_to_sitk = torch.tensor([-1.0,0,0,0,-1.0,0,0,0,1.0]).reshape(3,3)
        # from nifty space to sitk space
        # the x, y axes are mirrored
        disp_tensor = torch.einsum("ij,jhwd->ihwd", nifty_to_sitk, disp_tensor)

    # transform the displacement from physical-point space to
    # image-index space (pytorch): strictly this needs the inverse
    # of the affine matrix, but with isotropic 1mm spacing the
    # inverse of the affine is just the inverse of the direction;
    # for a direction matrix like the one above (orthogonal and
    # symmetric) the inverse is the matrix itself, so the explicit
    # inversion can be skipped:
    # direction = torch.linalg.inv(direction)
    disp_tensor = torch.einsum("ij,jhwd->ihwd", direction, disp_tensor)

    return disp_tensor.unsqueeze(0)

# General case
def load_flow(flow_path:str, pkg:Literal["demons", "ants", "niftyreg"]):
    # displacement fields of ANTs and SimpleITK are both in Physical Point coordinate
    disp = sitk.ReadImage(flow_path)
    direction = torch.tensor(disp.GetDirection()).reshape(3,3)
    spacing = torch.diag(torch.tensor(disp.GetSpacing()))
    # the affine matrix here excludes the origin: since we are transforming
    # displacement *vectors* from physical-point coordinates to image-index
    # coordinates, the origin cancels out
    affine = torch.matmul(direction, spacing)
    # direction @ spacing maps image-index to physical-point coordinates, so we need its inverse
    affine_inv = torch.linalg.inv(affine)

    # sitk: (x,y,z) -> numpy:(z,y,x)
    disp_arr = sitk.GetArrayFromImage(disp)
    disp_arr = np.transpose(disp_arr, axes=(3,2,1,0)) #(3,H,W,D)
    disp_tensor = torch.from_numpy(disp_arr).float()

    if pkg == "niftyreg":
        nifty_to_sitk = torch.tensor([-1.0,0,0,0,-1.0,0,0,0,1.0]).reshape(3,3)
        # from NiftyReg space to sitk space: the x and y axes are mirrored
        disp_tensor = torch.einsum("ij,jhwd->ihwd", nifty_to_sitk, disp_tensor)

    # Physical Point space displacement to Image Index space displacement
    disp_tensor = torch.einsum("ij,jhwd->ihwd", affine_inv, disp_tensor)
    return disp_tensor.unsqueeze(0)

# example usage (flow path from the earlier snippet)
disp_tensor = load_flow("ants_flow.nii.gz", pkg="ants")
# use the SpatialTransformer class I posted before; the grid size must
# match the spatial shape of the flow
resampler = SpatialTransformer(size=disp_tensor.shape[2:], mode='bilinear')
warped_image = resampler(image, disp_tensor)
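
As a sanity check, one could compare this against the Dice obtained from ants.apply_transforms (a sketch; fixed_oh / moving_oh are assumed one-hot label tensors of shape (1, C, H, W, D) and dice is a hypothetical helper):

def dice(a, b, eps=1e-6):
    # soft Dice per channel over one-hot tensors of shape (1, C, H, W, D)
    inter = (a * b).sum(dim=(2, 3, 4))
    denom = a.sum(dim=(2, 3, 4)) + b.sum(dim=(2, 3, 4))
    return (2 * inter + eps) / (denom + eps)

warped_oh = resampler(moving_oh, disp_tensor)
print(dice(fixed_oh, warped_oh).mean())  # should approach the ants.apply_transforms Dice
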
pypi20200320 commented 1 year ago

Hi, how do I get the parameters of torch.nn.functional.affine_grid() from ANTsTransform.parameters? I cannot figure out what ANTsTransform.fixed_parameters is. Is it the rotation center? When I use torch.nn.functional.affine_grid() and grid_sample() to get the warped result, it is not the same as the ANTs output.

BailiangJ commented 1 year ago

Hi, @pypi20200320

From this link, I think the fixed parameters are the rotation center point. (I am not 100% sure.)
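
If that is the case, the ITK AffineTransform convention would be (with parameters the flattened matrix R plus the translation t, and fixed parameters the centre c):

    y = R (x - c) + c + t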

If that is the case, I would suggest computing the displacement field directly by matrix multiplication with an image grid rather than using affine_grid().

You can refer to airlab and take a look at how they compute a dense flow field from an affine transformation.
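
For example, a minimal sketch (my own, assuming the affine has already been expressed in image-index space as a 3x4 matrix [R | t]):

import torch

def affine_to_dense_flow(matrix, size):
    # matrix: (3, 4) index-space affine [R | t]; size: spatial shape (H, W, D)
    vectors = [torch.arange(0, s, dtype=torch.float32) for s in size]
    grid = torch.stack(torch.meshgrid(*vectors, indexing='ij'))  # (3, H, W, D)
    R, t = matrix[:, :3], matrix[:, 3]
    new_grid = torch.einsum('ij,jhwd->ihwd', R, grid) + t.view(3, 1, 1, 1)
    return (new_grid - grid).unsqueeze(0)  # displacement, (1, 3, H, W, D)

The result can then be fed to the SpatialTransformer posted above.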

Hope this will help.

pypi20200320 commented 1 year ago

Dear @BailiangJ, thanks so much. I forgot to mention that what I am currently doing is calculating the relationship between the ANTs and PyTorch transformation parameters for rigid/affine registration.

I have referred to this information, and it seems that the fixed parameters really are the rotation center in ANTs. So I calculate the physical offset when moving the center to the origin of the fixed image in ANTs, transform the physical offset into a pixel offset, and then get the offset in PyTorch, where the rotation center is the center of the image.

But I still cannot get the right result; my outputs "ants_warped.mha" and "torch_warped.mha" are not the same.

Could you help to check what's the problem ? Here is my code:

import ants
import numpy as np
import torch
import torch.nn.functional as F

base_path = '/data/result/test_data/'
name = '0000338846'
fixed = ants.image_read(base_path + f"{name}/{name}_fixed.mha")
moving = ants.image_read(base_path + f"{name}/{name}_moving.mha")
tx = ants.registration(fixed, moving, syn_metric='CC',
                 syn_sampling=3,  # radius
                 type_of_transform='Rigid',
                 grad_step=0.25,
                 flow_sigma=0.9,
                 total_sigma=0.2,
                 reg_iterations=(210, 210, 210),
                 verbose=False)

warped = tx['warpedmovout']
ants.image_write(warped, base_path + f"{name}/{name}_ants_warped.mha")
matrix = ants.read_transform(tx['fwdtransforms'][0])
center = matrix.fixed_parameters
matrix = matrix.parameters.reshape(4, 3)
rotate = matrix[:3, :3]
b_offset = matrix[3, :]
# ITK convention: y = R (x - c) + c + t, so the origin-centred offset is t + c - R c
phy_offset = b_offset + center - rotate @ center

moving_data = moving.numpy()
moving_data = np.expand_dims(np.expand_dims(moving_data, 0), 0)
moving_data = torch.tensor(moving_data, dtype=torch.float)

## phy_offset to pixel_offset
direction = fixed.direction.reshape(3, 3)
spacing = np.array(fixed.spacing)
origin = np.array(fixed.origin)
img_size = np.array(fixed.numpy().shape)
print(origin, spacing, direction, img_size)

pixel_offset = phy_offset / spacing
print('pixel_offset:', pixel_offset)

## rotate center is the image_center in pytorch
torch_center = 0.5 * img_size
torch_offset = pixel_offset + torch_center - rotate @ torch_center
print('torch offset:', torch_offset)

torch_matrix = np.eye(4)
torch_matrix[:3, :3] = rotate
torch_matrix[:3, 3] = torch_offset
print('torch_matrix:', torch_matrix)

# offset normalize to [-1, 1]
norm_m = np.eye(4)
norm_m[0, 0] = 2 / img_size[0]
norm_m[1, 1] = 2 / img_size[1]
norm_m[2, 2] = 2 / img_size[2]
norm_m[:3, 3] = -1
inv_norm = np.linalg.inv(norm_m)
torch_matrix = norm_m @ torch_matrix @ inv_norm
print('torch_matrix:', torch_matrix)

torch_matrix = torch.tensor(torch_matrix[:3, :], dtype=torch.float).unsqueeze(0)

# Interpolation
grid = F.affine_grid(torch_matrix,
                   moving_data.size(),
                   align_corners=False)

x_aligned = F.grid_sample(moving_data,
                        grid=grid,
                        mode='bilinear',
                        padding_mode='border',
                        align_corners=False)

aligned = ants.from_numpy(x_aligned.detach().squeeze().numpy().astype(np.float32), has_components=False)
ants.set_origin(aligned, list(origin))
ants.set_spacing(aligned, list(spacing))
ants.image_write(aligned, base_path + f"{name}/{name}_torch_warped.mha")
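
One possible pitfall here (a guess on my part, not verified on this data): moving_data has shape (1, 1, X, Y, Z), which torch reads as (N, C, D, H, W), while affine_grid / grid_sample order the normalized coordinates as (x, y, z) = (W, H, D), the reverse of the array axes. A sketch of a corresponding fix, applied to the 4x4 voxel-space matrix before the [-1, 1] normalization:

P = np.eye(4)[[2, 1, 0, 3]]          # permutation that reverses the three spatial axes
torch_matrix = P @ torch_matrix @ P  # P is its own inverse
# ... and build norm_m from img_size[::-1] instead of img_size
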
BailiangJ commented 1 year ago

Hi @pypi20200320 ,

I have never used affine_grid and I don't know its exact behaviour, e.g., whether the rotation center is fixed to the center of the image, or whether the input translation to affine_grid should be normalized to [-1, 1].

I can see that you are trying to turn

the affine transformation given by ANTs, which is defined in physical space,

into

the affine transformation compatible with PyTorch, which is defined in image/voxel index space.

I think that is more work than just converting the offset from mm (physical points) into voxels (image indices) when the direction matrix of the images is not the identity, since the direction enters the transformation between physical space and image space.


If you just want PyTorch to warp the image with the ANTs affine transformation, you could ask ANTs to compute the dense displacement field from it and reuse my code.

If you want to compute the corresponding PyTorch affine matrix from the ANTs affine matrix, you can try to solve the following equation to figure out what matrix should be applied to the image/voxel index.

[equation image]
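
A plausible reconstruction of that equation, using the conventions above (D the direction matrix, S = diag(spacing), o the origin, A and t the ANTs affine matrix and offset, i the voxel index and i' the warped index):

    i' = (S^-1 D^-1 A D S) i + S^-1 D^-1 ((A - I) o + t)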

anamazingclown commented 1 month ago

@BailiangJ

Hi, thanks for your work! My ultimate goal is to get an affine deformation field that works correctly with SpatialTransformer. Even though I am working with 2D images, it seems that I have run into the same problem.

I have some questions:

  1. disp" in your work is the displacement filed . But My "fwdtransforms" only have the below item . How can I get the displacement field under the inherent premise of affine registration, and can it be used directly and correctly by SpatialTransformer? image 2.when I read it by "ants.read_transform" I can get image In my opinion ,the parameters in my work is the affine matrix. I have try many work to cal the deformation field for the SpatialTransformer. but it is a pity that I failed . It would be nice if you can give me some help . I am looking foward to recevicing your reply!!
ntustison commented 1 month ago

You can read the description of the affine transform here. You might want to ask over at the ITK discourse forum, as someone might already have torch code to do this.

anamazingclown commented 1 month ago

> You can read the description of the affine transform here. You might want to ask over at the ITK discourse forum, as someone might already have torch code to do this.

Ok, thank you very much for your help. I wish you a happy life and work!

BailiangJ commented 1 month ago

Hi @anamazingclown ,

This link might be helpful for getting the displacement field from the ANTs affine transformation. After getting the displacement field, you can modify the load_flow function I posted from 3D to 2D and use it with the SpatialTransformer I posted.

This might also be relevant.
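
For example, a sketch (my own, untested; if I remember correctly, ants.apply_transforms accepts a compose argument, a filename prefix, and returns the path of the composed warp field it writes to disk):

import ants

fixed = ants.image_read(fixed_path)    # fixed_path / moving_path assumed defined
moving = ants.image_read(moving_path)
tx = ants.registration(fixed=fixed, moving=moving, type_of_transform='Affine')

# compose the affine into a dense displacement field sampled on the fixed grid
comptx_path = ants.apply_transforms(fixed=fixed, moving=moving,
                                    transformlist=tx['fwdtransforms'],
                                    compose='/tmp/affine_')
disp_tensor = load_flow(comptx_path, pkg="ants")  # load_flow from above, adapted to 2D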

anamazingclown commented 1 month ago

> Hi @anamazingclown ,
>
> This link might be helpful for getting the displacement field from the ANTs affine transformation. After getting the displacement field, you can modify the load_flow function I posted from 3D to 2D and use it with the SpatialTransformer I posted.
>
> This might also be relevant.

OK, I will try. Thank you very much!