uncbiag / ICON

A library for performing image registration using deep learning, regularized by inverse consistency
Other
42 stars 9 forks source link

GradICON and ICON loss without using network wrapper #79

Open iyerkrithika21 opened 3 months ago

iyerkrithika21 commented 3 months ago

I am interested in calculating the GradICON and ICON loss for a network I am working on, and it does not follow the typical registration workflow with two input images and predicting the deformation fields, so I cannot use InverseConsistentNet or GradientICON or GradientICONSparse. Unfortunately, I cannot share my code.

Is it possible to provide a standalone loss function that calculates the losses, given two deformation field matrices (equivalent to your phi_AB.vectorfield and phi_BA.vectorfield variables)?

Also, could you briefly explain the format of the GradICON output? I am confused about the use of phi_AB as a function vs. the actual deformation field in the image coordinates. Thank you!

iyerkrithika21 commented 3 months ago

@HastingsGreer : Here is my attempt at it, you can see if the logic makes sense and matches the original implementation. I have used the same helper functions.

import torch.nn.functional as F
import torch
import glob
import os
import numpy as np

def create_identity_map(image_size, spacing=None):
    """
    Create an identity (coordinate) map for a given image size.

    :param image_size: Sequence specifying the size of the image (X, Y, Z) for 3D, (X, Y) for 2D, or (X,) for 1D.
    :param spacing: Sequence specifying the spacing between elements in each dimension. Defaults to 1.0 along every axis.
    :return: Identity map as a ``torch.float32`` tensor of shape ``(*image_size, ndim)``
        (coordinate axis last, no batch axis). The entry at grid index
        ``(i, j, ...)`` holds ``(i * spacing[0], j * spacing[1], ...)``.
    """
    ndim = len(image_size)

    # Set default spacing if not provided
    if spacing is None:
        spacing = (1.0,) * ndim

    # Generate a grid of integer coordinates, shape (ndim, *image_size).
    # A tuple index keeps the leading coordinate axis even in the 1D case.
    coordinates = np.mgrid[tuple(slice(0, s) for s in image_size)]

    # Convert the coordinates to float and apply spacing
    identity_map = np.array(coordinates, dtype=np.float32)
    for axis in range(ndim):
        identity_map[axis] *= spacing[axis]

    # Move the coordinate axis to the end for any dimensionality; the previous
    # hard-coded permute([1, 2, 0]) only worked for 2D inputs.
    return torch.from_numpy(identity_map).permute(*range(1, ndim + 1), 0)

def inverse_consistency_loss(phi_AB_tensor, phi_BA_tensor, identity_map, delta=1e-6):
    """
    Calculate the inverse consistency loss between two deformation fields tensors
    by converting them to function form.

    Parameters:
    - phi_AB_tensor: Deformation field from space A to space B (tensor of shape [B, 2, H, W]).
    - phi_BA_tensor: Deformation field from space B to space A (tensor of shape [B, 2, H, W]).
    - identity_map: The identity map (grid) (tensor of shape [B, C, H, W]).
      Rank 3, 4 or 5 is accepted (1, 2 or 3 spatial axes after batch and channel).
    - delta: Step used for the finite-difference gradient of the inverse
      consistency error. Defaults to 1e-6.
      NOTE(review): direction vectors are float32; if the coordinates in
      `identity_map` are much larger than 1, a step of 1e-6 can fall below
      floating-point resolution -- consider passing a larger value.

    Returns:
    - loss: Tuple ``(inverse_consistency_loss, grad_inverse_consistency_loss)``.
      The first term sums the mean squared error of both compositions
      (phi_AB o phi_BA and phi_BA o phi_AB); the second sums, over the
      coordinate directions, the mean squared finite-difference gradient of
      the phi_AB o phi_BA error.

    Raises:
    - ValueError: if `identity_map` does not have rank 3, 4 or 5.

    NOTE:
    vectorfields = deformation field + identity_map
    Deformation field: output of the U-Net in the case of gradicon model

    """
    # Validate up front so an unsupported rank fails with a clear message
    # instead of an unbound-variable error further down.
    n_spatial = identity_map.dim() - 2
    if n_spatial not in (1, 2, 3):
        raise ValueError(
            "identity_map must have rank 3, 4 or 5 (batch, channel, spatial...); "
            "got rank {}".format(identity_map.dim())
        )

    # Spatial extent of the fields (was mistakenly a slice of the tensor itself).
    size = phi_AB_tensor.shape[2:]
    phi_AB = tensor2function(phi_AB_tensor, size)
    phi_BA = tensor2function(phi_BA_tensor, size)

    # Perturb the sampling grid with Gaussian noise scaled by an estimate of
    # the voxel spacing derived from the largest coordinate in the map.
    voxel_spacing = (torch.max(identity_map) + 1) / identity_map.shape[-1]
    noise = torch.randn(identity_map.shape, device=identity_map.device)
    Iepsilon = identity_map + noise * voxel_spacing

    approximate_Iepsilon1 = phi_AB(phi_BA(Iepsilon))
    approximate_Iepsilon2 = phi_BA(phi_AB(Iepsilon))

    # Error of the A->B->A round trip; reused below for the gradient term
    # rather than recomputing the same composition.
    inverse_consistency_error = Iepsilon - approximate_Iepsilon1

    icon_loss = torch.mean(inverse_consistency_error ** 2) + torch.mean(
        (Iepsilon - approximate_Iepsilon2) ** 2
    )

    direction_losses = []
    for axis in range(n_spatial):
        # Offset of size delta along one coordinate channel, shaped
        # (1, n_spatial, 1, ..., 1) so it broadcasts over batch and space.
        d = torch.zeros(
            (1, n_spatial) + (1,) * n_spatial, device=identity_map.device
        )
        d[0, axis] = delta

        approximate_Iepsilon_d = phi_AB(phi_BA(Iepsilon + d))
        inverse_consistency_error_d = Iepsilon + d - approximate_Iepsilon_d
        grad_d_icon_error = (
            inverse_consistency_error - inverse_consistency_error_d
        ) / delta
        direction_losses.append(torch.mean(grad_d_icon_error ** 2))

    grad_inverse_consistency_loss = sum(direction_losses)

    return icon_loss, grad_inverse_consistency_loss

def scale_map(map, sz, spacing):
    """
    Rescale a coordinate map into the [-1,1]^d convention.

    :param map: map in BxCxXxYxZ format
    :param sz: size of the image being interpolated, including the leading
        batch and channel entries (spatial extents are read from index 2 on)
    :param spacing: spacing of image in XxYxZ format; its length sets how many
        coordinate channels are processed
    :return: a new tensor shaped like `map` holding the scaled map
    """
    rescaled = torch.zeros_like(map)

    # Inverts the construction of a [-1,1] identity map:
    #   id[d] *= 2. / (sz[d] - 1)
    #   id[d] -= 1.
    for axis, step in enumerate(spacing):
        extent = sz[axis + 2]
        channel = map[:, axis, ...]
        if extent > 1:
            channel = channel * (2.0 / (extent - 1.0) / step) - 1.0
        # Axes of extent 1 cannot be normalised and are copied through as-is.
        rescaled[:, axis, ...] = channel

    return rescaled

class STNFunction_ND_BCXYZ:
    """
    Spatial transform function for 1D, 2D, and 3D. In BCXYZ format (this IS the format used in the current toolbox).
    """

    def __init__(
        self, spacing, zero_boundary=False, using_bilinear=True, using_01_input=True):
        """
        Constructor
        :param spacing: per-axis spacing; its length defines the spatial dimension
        :param zero_boundary: if True sample zeros outside the image
            (``padding_mode="zeros"``), otherwise clamp to the border
        :param using_bilinear: if True use ``"bilinear"`` interpolation,
            otherwise ``"nearest"``
        :param using_01_input: if True the map is passed through `scale_map`
            before sampling; if False it is handed to `grid_sample` unchanged
        """
        self.spacing = spacing
        self.ndim = len(spacing)
        self.zero_boundary = "zeros" if zero_boundary else "border"
        self.mode = "bilinear" if using_bilinear else "nearest"
        self.using_01_input = using_01_input

    @staticmethod
    def _expand_to_batch(grid, batch_size):
        """Broadcast a single-sample grid to `batch_size` samples (no copy)."""
        if grid.shape[0] == 1 and batch_size != 1:
            return grid.expand(batch_size, *([-1] * (grid.dim() - 1)))
        return grid

    def forward_stn(self, input1, input2, ndim):
        """
        Sample `input1` at the locations given by `input2` using `grid_sample`.

        :param input1: image in BCXYZ format
        :param input2: sampling map in BdimXYZ format, already in the
            coordinate convention `grid_sample` expects
        :param ndim: number of spatial dimensions (1, 2 or 3)
        :return: the sampled image
        :raises ValueError: if `ndim` is not 1, 2 or 3
        """
        if ndim == 1:
            # Use 2D interpolation to mimic 1D interpolation: append a
            # trailing axis of length one to both image and map.
            phi_rs = input2.reshape(list(input2.size()) + [1])
            input1_rs = input1.reshape(list(input1.size()) + [1])

            phi_rs_size = list(phi_rs.size())
            phi_rs_size[1] = 2
            phi_rs_ordered = torch.zeros(
                phi_rs_size, dtype=phi_rs.dtype, device=phi_rs.device
            )
            # The real coordinate goes into channel 1; channel 0 (the added
            # dummy axis) stays at zero.
            phi_rs_ordered[:, 1, ...] = phi_rs[:, 0, ...]

            output_rs = torch.nn.functional.grid_sample(
                input1_rs,
                phi_rs_ordered.permute([0, 2, 3, 1]),
                mode=self.mode,
                padding_mode=self.zero_boundary,
                align_corners=True,
            )
            return output_rs[:, :, :, 0]

        if ndim == 2:
            # todo double check, it seems no transpose is need for 2d, already in height width design
            # input = [N,C,H,W], grid = [N,H,W,2], output = [N,C,H,W]
            grid = self._expand_to_batch(input2, input1.shape[0])
            return torch.nn.functional.grid_sample(
                input=input1,
                grid=grid.permute([0, 2, 3, 1]),
                mode=self.mode,
                padding_mode=self.zero_boundary,
                align_corners=True,
            )

        if ndim == 3:
            # Reverse the order of the coordinate channels (0,1,2 -> 2,1,0)
            # before handing the map to grid_sample.
            grid = self._expand_to_batch(
                torch.flip(input2, dims=[1]), input1.shape[0]
            )
            return torch.nn.functional.grid_sample(
                input1,
                grid.permute([0, 2, 3, 4, 1]),
                mode=self.mode,
                padding_mode=self.zero_boundary,
                align_corners=True,
            )

        # Previously an unsupported ndim fell through every branch and failed
        # with an unbound `output` variable.
        raise ValueError("ndim must be 1, 2 or 3; got {!r}".format(ndim))

    def __call__(self, input1, input2):
        """
        Perform the actual spatial transform
        :param input1: image in BCXYZ format
        :param input2: spatial transform in BdimXYZ format
        :return: spatially transformed image in BCXYZ format
        """

        assert len(self.spacing) + 2 == len(input2.size())
        if self.using_01_input:
            # Convert the map to the [-1,1] convention first.
            input2 = scale_map(input2, input1.shape, self.spacing)
        return self.forward_stn(input1, input2, self.ndim)

def compute_warped_image_multiNC(I0, phi, zero_boundary=False, spacing=None):
    """Warps image.
    :param I0: image to warp, image size BxCxXxYxZ
    :param phi: map for the warping, size BxdimxXxYxZ
    :param zero_boundary: if True, sample zeros outside the image instead of
        clamping to the border
    :param spacing: image spacing [dx,dy,dz] handed to the spatial transformer.
        When omitted, the spatial size of `I0` (``I0.shape[2:]``) is used,
        which is the historical behaviour of this function.
        NOTE(review): that default passes the image *size* as the spacing;
        confirm it matches the coordinate convention of `phi`, otherwise pass
        the spacing explicitly.
    :return: returns the warped image of size BxCxXxYxZ
    """
    if spacing is None:
        spacing = I0.shape[2:]
    # The transformer takes the image in BCXYZ format and the map in BdimXYZ
    # format and returns the transformed image.
    f = STNFunction_ND_BCXYZ(spacing, zero_boundary)
    return f(I0, phi)

def as_function(image):
    """Wrap an image tensor as a callable that can be sampled at coordinates.

    The returned function takes a tensor of coordinates
    [batch x N_dimensions x ...] and returns the result of warping `image`
    with them via `compute_warped_image_multiNC`.
    """

    def sample(coordinates):
        return compute_warped_image_multiNC(I0=image, phi=coordinates)

    return sample

def tensor2function(tensor_of_displacements, spacing):
    """Turn a displacement tensor into a coordinate transform function.

    :param tensor_of_displacements: displacement field tensor
    :param spacing: accepted for interface compatibility; not used here
    :return: ``transform(coordinates, isIdentity=False)``, which returns
        ``coordinates`` plus the displacement at those coordinates. When
        `isIdentity` is true and `coordinates` has exactly the shape of the
        displacement tensor, the tensor is added directly without
        interpolation; otherwise the displacement is sampled at `coordinates`.
    """
    sample_displacement = as_function(tensor_of_displacements)

    def transform(coordinates, isIdentity=False):
        on_grid = isIdentity and coordinates.shape == tensor_of_displacements.shape
        if on_grid:
            displacement = tensor_of_displacements
        else:
            displacement = sample_displacement(coordinates)
        return coordinates + displacement

    return transform
BailiangJ commented 2 months ago

Hi @iyerkrithika21 ,

I have created standalone ICON and GradICON losses in PR #80.

I hope it will also help. :)