Open iyerkrithika21 opened 3 months ago
@HastingsGreer: Here is my attempt at it; could you check whether the logic makes sense and matches the original implementation? I have used the same helper functions.
import torch.nn.functional as F
import torch
import glob
import os
import numpy as np
def create_identity_map(image_size, spacing=None):
"""
Create an identity map for a given image size.
:param image_size: Tuple specifying the size of the image (X, Y, Z) for 3D, (X, Y) for 2D, or (X,) for 1D.
:param spacing: Tuple specifying the spacing between elements in each dimension. Defaults to (1, 1, 1).
    :return: Identity map as a torch tensor of shape [1, ndim, *image_size].
"""
# Set default spacing if not provided
if spacing is None:
spacing = tuple([1.0] * len(image_size))
# Generate a grid of coordinates
coordinates = np.mgrid[[slice(0, s) for s in image_size]]
# Convert the coordinates to float and apply spacing
identity_map = np.array(coordinates, dtype=np.float32)
for i in range(len(image_size)):
identity_map[i] *= spacing[i]
    # Add a leading batch dimension so the map matches the [B, dim, ...] layout
    # expected by the loss and warping helpers below
    return torch.from_numpy(identity_map).unsqueeze(0)
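# A minimal sanity-check sketch for create_identity_map on a tiny 2D grid:
# with the default unit spacing, channel 0 holds the row coordinate and
# channel 1 the column coordinate of every pixel.
def _demo_identity_map():
    idmap = create_identity_map((4, 5))
    print(idmap.shape)   # torch.Size([1, 2, 4, 5]) once the batch dimension is added
    print(idmap[0, 0])   # row coordinates, 0..3
    print(idmap[0, 1])   # column coordinates, 0..4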
def inverse_consistency_loss(phi_AB_tensor, phi_BA_tensor, identity_map):
"""
Calculate the inverse consistency loss between two deformation fields tensors
by converting them to function form.
Parameters:
- phi_AB_tensor: Deformation field from space A to space B (tensor of shape [B, 2, H, W]).
- phi_BA_tensor: Deformation field from space B to space A (tensor of shape [B, 2, H, W]).
    - identity_map: The identity map (grid) (tensor of shape [B, C, H, W]).
    Returns:
    - (inverse_consistency_loss, grad_inverse_consistency_loss): the ICON term and the gradient inverse consistency (GradICON) term.
NOTE:
vectorfields = deformation field + identity_map
    Deformation field: output of the U-Net in the case of the GradICON model
"""
    # Spatial size of the displacement fields (used to build the warping functions)
    batch_size = phi_AB_tensor.shape[0]
    device = phi_AB_tensor.device
    size = phi_AB_tensor.shape[2:]
phi_AB = tensor2function(phi_AB_tensor, size)
phi_BA = tensor2function(phi_BA_tensor, size)
voxel_spacing = (torch.max(identity_map)+1)/identity_map.shape[-1]
Iepsilon = (identity_map + torch.randn(*identity_map.shape).to(identity_map.device)* voxel_spacing)
    approximate_Iepsilon1 = phi_AB(phi_BA(Iepsilon))
    approximate_Iepsilon2 = phi_BA(phi_AB(Iepsilon))
    inverse_consistency_loss = torch.mean(
        (Iepsilon - approximate_Iepsilon1) ** 2
    ) + torch.mean((Iepsilon - approximate_Iepsilon2) ** 2)
    direction_losses = []
    # Reuse the A->B->A composition computed above for the finite-difference gradient term
    inverse_consistency_error = Iepsilon - approximate_Iepsilon1
delta = 1e-6
if len(identity_map.shape) == 4:
dx = torch.Tensor([[[[delta]], [[0.0]]]]).to(identity_map.device)
dy = torch.Tensor([[[[0.0]], [[delta]]]]).to(identity_map.device)
direction_vectors = (dx, dy)
elif len(identity_map.shape) == 5:
dx = torch.Tensor([[[[[delta]]], [[[0.0]]], [[[0.0]]]]]).to(
identity_map.device
)
dy = torch.Tensor([[[[[0.0]]], [[[delta]]], [[[0.0]]]]]).to(
identity_map.device
)
        dz = torch.Tensor([[[[[0.0]]], [[[0.0]]], [[[delta]]]]]).to(
            identity_map.device
        )
direction_vectors = (dx, dy, dz)
elif len(identity_map.shape) == 3:
dx = torch.Tensor([[[delta]]]).to(identity_map.device)
direction_vectors = (dx,)
for d in direction_vectors:
approximate_Iepsilon_d = phi_AB(phi_BA(Iepsilon + d))
inverse_consistency_error_d = Iepsilon + d - approximate_Iepsilon_d
grad_d_icon_error = (
inverse_consistency_error - inverse_consistency_error_d
) / delta
direction_losses.append(torch.mean(grad_d_icon_error**2))
grad_inverse_consistency_loss = sum(direction_losses)
return inverse_consistency_loss, grad_inverse_consistency_loss
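# A minimal usage sketch for the standalone loss: random 2D displacement fields
# in the assumed [B, 2, H, W] convention. Zero displacement fields would make
# both returned terms (numerically) zero.
def _demo_inverse_consistency_loss():
    B, H, W = 2, 32, 32
    identity = create_identity_map((H, W)).expand(B, -1, -1, -1)
    phi_AB_field = 0.5 * torch.randn(B, 2, H, W)
    phi_BA_field = 0.5 * torch.randn(B, 2, H, W)
    icon_term, gradicon_term = inverse_consistency_loss(phi_AB_field, phi_BA_field, identity)
    print(float(icon_term), float(gradicon_term))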
def scale_map(map, sz, spacing):
"""
Scales the map to the [-1,1]^d format
:param map: map in BxCxXxYxZ format
:param sz: size of image being interpolated in XxYxZ format
:param spacing: spacing of image in XxYxZ format
:return: returns the scaled map
"""
map_scaled = torch.zeros_like(map)
ndim = len(spacing)
# This is to compensate to get back to the [-1,1] mapping of the following form
# id[d]*=2./(sz[d]-1)
# id[d]-=1.
for d in range(ndim):
if sz[d + 2] > 1:
map_scaled[:, d, ...] = (
map[:, d, ...] * (2.0 / (sz[d + 2] - 1.0) / spacing[d])
- 1.0
# map[:, d, ...] * 2.0 - 1.0
)
else:
map_scaled[:, d, ...] = map[:, d, ...]
return map_scaled
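# A small sketch of scale_map: with unit spacing it sends voxel coordinates
# 0 .. size-1 to the [-1, 1] range that grid_sample expects.
def _demo_scale_map():
    H, W = 4, 5
    coords = create_identity_map((H, W))                  # [1, 2, H, W], voxel units
    scaled = scale_map(coords, (1, 2, H, W), (1.0, 1.0))
    print(scaled[0, 0].min().item(), scaled[0, 0].max().item())  # -1.0 1.0
    print(scaled[0, 1].min().item(), scaled[0, 1].max().item())  # -1.0 1.0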
class STNFunction_ND_BCXYZ:
"""
Spatial transform function for 1D, 2D, and 3D. In BCXYZ format (this IS the format used in the current toolbox).
"""
def __init__(
self, spacing, zero_boundary=False, using_bilinear=True, using_01_input=True):
"""
Constructor
        :param spacing: spacing of the image being warped; its length determines the spatial dimensionality of the transform
"""
self.spacing = spacing
self.ndim = len(spacing)
# zero_boundary = False
self.zero_boundary = "zeros" if zero_boundary else "border"
self.mode = "bilinear" if using_bilinear else "nearest"
self.using_01_input = using_01_input
def forward_stn(self, input1, input2, ndim):
if ndim == 1:
            # use 2D interpolation to mimic 1D interpolation
# now test this for 1D
phi_rs = input2.reshape(list(input2.size()) + [1])
input1_rs = input1.reshape(list(input1.size()) + [1])
phi_rs_size = list(phi_rs.size())
phi_rs_size[1] = 2
phi_rs_ordered = torch.zeros(
phi_rs_size, dtype=phi_rs.dtype, device=phi_rs.device
)
# keep dimension 1 at zero
phi_rs_ordered[:, 1, ...] = phi_rs[:, 0, ...]
output_rs = torch.nn.functional.grid_sample(
input1_rs,
phi_rs_ordered.permute([0, 2, 3, 1]),
mode=self.mode,
padding_mode=self.zero_boundary,
align_corners=True,
)
output = output_rs[:, :, :, 0]
if ndim == 2:
            # TODO: double check; it seems no transpose is needed for 2D, the grid is already in (height, width) order
# input2_ordered = torch.zeros_like(input2)
# input2_ordered[:, 0, ...] = input2[:, 1, ...]
# input2_ordered[:, 1, ...] = input2[:, 0, ...]
input2_ordered = input2
if input2_ordered.shape[0] == 1 and input1.shape[0] != 1:
input2_ordered = input2_ordered.expand(input1.shape[0], -1, -1, -1)
            # grid_sample shapes:
            #   input:  [N, C, H, W]
            #   grid:   [N, H, W, 2]
            #   output: [N, C, H, W]
output = torch.nn.functional.grid_sample(
input=input1,
grid=input2_ordered.permute([0, 2, 3, 1]),
mode=self.mode,
padding_mode=self.zero_boundary,
align_corners=True,
)
if ndim == 3:
input2_ordered = torch.zeros_like(input2)
input2_ordered[:, 0, ...] = input2[:, 2, ...]
input2_ordered[:, 1, ...] = input2[:, 1, ...]
input2_ordered[:, 2, ...] = input2[:, 0, ...]
if input2_ordered.shape[0] == 1 and input1.shape[0] != 1:
input2_ordered = input2_ordered.expand(input1.shape[0], -1, -1, -1, -1)
output = torch.nn.functional.grid_sample(
input1,
input2_ordered.permute([0, 2, 3, 4, 1]),
mode=self.mode,
padding_mode=self.zero_boundary,
align_corners=True,
)
return output
def __call__(self, input1, input2):
"""
Perform the actual spatial transform
:param input1: image in BCXYZ format
:param input2: spatial transform in BdimXYZ format
:return: spatially transformed image in BCXYZ format
"""
assert len(self.spacing) + 2 == len(input2.size())
if self.using_01_input:
output = self.forward_stn(
input1, scale_map(input2, input1.shape, self.spacing), self.ndim
)
else:
output = self.forward_stn(input1, input2, self.ndim)
return output
def compute_warped_image_multiNC(I0, phi, zero_boundary=False):
"""Warps image.
:param I0: image to warp, image size BxCxXxYxZ
:param phi: map for the warping, size BxdimxXxYxZ
    :param zero_boundary: if True, use zero padding outside the image; otherwise border padding
:return: returns the warped image of size BxCxXxYxZ
"""
    # NOTE (assumption): phi is expected in voxel coordinates, matching
    # create_identity_map with unit spacing, so a per-dimension spacing of 1.0
    # is passed to the STN; scale_map then maps [0, size-1] onto the [-1, 1]
    # range expected by grid_sample.
    spacing = (1.0,) * len(I0.shape[2:])
    f = STNFunction_ND_BCXYZ(spacing, zero_boundary)
    # Simply returns the transformed (resampled) image
    return f(I0, phi)
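# A sanity-check sketch: resampling an image at the identity coordinates should
# give the image back up to floating point error (assuming voxel-unit coordinates
# as noted above). If the printed difference is large, the 2D channel ordering in
# forward_stn (see the TODO about transposing) may still need the swap.
def _demo_identity_warp():
    H, W = 16, 24
    image = torch.rand(1, 1, H, W)
    identity = create_identity_map((H, W))                # [1, 2, H, W], voxel units
    warped = compute_warped_image_multiNC(image, identity)
    print(torch.max(torch.abs(warped - image)).item())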
def as_function(image):
"""image is a tensor
Returns a python function that maps a tensor of coordinates [batch x N_dimensions x ...]
into a tensor of intensities.
"""
return lambda coordinates: compute_warped_image_multiNC(
I0=image, phi=coordinates
)
def tensor2function(tensor_of_displacements, spacing=None):
    # NOTE: `spacing` is currently unused; compute_warped_image_multiNC assumes
    # voxel-unit coordinates (see the note there).
    displacement_field = as_function(tensor_of_displacements)

    def transform(coordinates, isIdentity=False):
if isIdentity and coordinates.shape == tensor_of_displacements.shape:
return coordinates + tensor_of_displacements
return coordinates + displacement_field(coordinates)
return transform
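# A small sketch of how tensor2function is meant to be used: it turns a
# displacement tensor [B, dim, ...] into a callable map phi(x) = x + u(x),
# so phi_AB and phi_BA can be composed by ordinary function application,
# as done in inverse_consistency_loss above.
def _demo_composition():
    B, H, W = 1, 32, 32
    identity = create_identity_map((H, W))
    u_AB = torch.zeros(B, 2, H, W)                 # zero displacement -> phi_AB is the identity map
    phi_AB = tensor2function(u_AB, u_AB.shape[2:])
    composed = phi_AB(identity)                    # warping a zero field adds nothing
    print(torch.max(torch.abs(composed - identity)).item())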
Hi @iyerkrithika21,
I have created standalone ICON and GradICON losses in PR #80.
I hope it will also help. :)
I am interested in calculating the GradICON and ICON loss for a network I am working on, and it does not follow the typical registration workflow with two input images and predicting the deformation fields, so I cannot use `InverseConsistentNet`, `GradientICON`, or `GradientICONSparse`. Unfortunately, I cannot share my code. Is it possible to provide a standalone loss function that calculates the losses, given two deformation field matrices (equivalent to your `phi_AB.vectorfield` and `phi_BA.vectorfield` variables)?

Also, could you briefly explain the format of the GradICON output? I am confused about the use of `phi_AB` as a function vs. the actual deformation field in the image coordinates. Thank you!