uncbiag / uniGradICON

The official website for uniGradICON: A Foundation Model for Medical Image Registration
Apache License 2.0

Obtain Displacement Field at Original Input dimension #21

Open anudeepk17 opened 1 month ago

anudeepk17 commented 1 month ago

Dear Authors, thank you for this great work. I was using your model from source and trying to register my own data. The results are great, but I want to obtain phi_AB (or net.phi_AB_vectorfield) in the original input image dimensions instead of 175x175x175. I tried resampling the vector field with itk.resample_image_filter and with the torch.nn.functional.interpolate function, but in both cases the resulting tensor did not register the images as well as the field at the network's size did. I could not figure out how to modify phi_AB, since it is an itk.CompositeTransform object.

Could you help me understand how to get net.phi_AB_vectorfield or phi_AB in the dimensions of the input image? My issue is similar to #15, but I could not find a solution there. Thanks again in advance for your help.
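One likely reason plain resampling of the 175x175x175 field fails is that resizing a displacement field changes more than its grid: the vectors are expressed in some unit (voxels of the network grid, normalized coordinates, or millimetres) and must be rescaled together with the grid. The 1D NumPy sketch below assumes displacements measured in voxels of the grid they are stored on, with the first and last voxel centers aligned between the two grids. Whether uniGradICON's `phi_AB_vectorfield` uses this convention is an assumption not confirmed in this thread, and `resize_displacement_1d` is a hypothetical helper.

```python
import numpy as np


def resize_displacement_1d(disp: np.ndarray, new_len: int) -> np.ndarray:
    """Resample a 1D displacement field to a grid with new_len samples.

    Hypothetical helper. Assumes disp[i] is a displacement measured in voxels
    of the original grid, and that the first and last voxel centers of the
    old and new grids coincide.
    """
    old_len = disp.shape[0]
    # Positions of the new grid's voxel centers, in old-grid voxel coordinates.
    new_in_old = np.linspace(0.0, old_len - 1, new_len)
    # Linearly interpolate the displacement values onto the new grid.
    resampled = np.interp(new_in_old, np.arange(old_len), disp)
    # Rescale the vectors: one old voxel spans (new_len-1)/(old_len-1) new voxels.
    return resampled * (new_len - 1) / (old_len - 1)


# A field that maps x -> 1.1 * x on a 175-voxel grid...
old = 0.1 * np.arange(175)
# ...should still map x -> 1.1 * x after resampling to 40 voxels.
new = resize_displacement_1d(old, 40)
print(np.allclose(new, 0.1 * np.arange(40)))  # True
```

Skipping the final rescaling leaves every vector about 4.5 times too long on the 40-voxel grid, which is one way an interpolated field can look plausible yet register poorly.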

HastingsGreer commented 1 month ago

Hi! Are you looking for the displacement field at original resolution with displacements in physical coordinates? If so, this is doable by converting the itkCompositeTransform to a displacement field as follows:

https://colab.research.google.com/drive/1bo_CWdI4PC7YdMmlVb2jYp0Fd1ee5gW_?usp=sharing

```
!pip install unigradicon
!wget https://www.hgreer.com/assets/slicer_mirror/RegLib_C01_1.nrrd
!wget https://www.hgreer.com/assets/slicer_mirror/RegLib_C01_2.nrrd
!unigradicon-register --fixed=RegLib_C01_2.nrrd --fixed_modality=mri --moving=RegLib_C01_1.nrrd --moving_modality=mri --transform_out=trans.hdf5 --warped_moving_out=warped_C01_1.nrrd
```

```python
import itk
import matplotlib.pyplot as plt
import numpy as np

fixed_image = itk.imread("RegLib_C01_2.nrrd")
moving_image = itk.imread("RegLib_C01_1.nrrd")

# Read the transform written by unigradicon-register.
transform = itk.transformread("trans.hdf5")[0]

# Sample the transform on the fixed image's grid to get a displacement
# field (in physical units) at the original resolution.
dispfield_filter = itk.TransformToDisplacementFieldFilter[itk.Image[itk.Vector[itk.F, 3], 3], itk.D].New()
dispfield_filter.SetTransform(transform)
dispfield_filter.SetReferenceImage(fixed_image)
dispfield_filter.SetUseReferenceImage(True)
dispfield_filter.Update()

displacement_field = dispfield_filter.GetOutput()

print(displacement_field.GetLargestPossibleRegion().GetSize())
print(np.array(displacement_field).shape)

# Warp the moving image onto the fixed image's grid with the displacement field.
warped_moving_image = itk.warp_image_filter(
    moving_image,
    output_origin=fixed_image.GetOrigin(),
    output_direction=fixed_image.GetDirection(),
    output_spacing=fixed_image.GetSpacing(),
    displacement_field=displacement_field)

# Checkerboard of the fixed and warped moving images (slice index 50).
plt.imshow(itk.checker_board_image_filter(fixed_image, warped_moving_image)[50])
plt.show()
```
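For intuition, `TransformToDisplacementFieldFilter` evaluates the transform at the physical position of every voxel of the reference image and stores the difference, D(p) = T(p) - p, so that warping with the field (sampling the moving image at p + D(p)) reproduces the transform. The NumPy sketch below mimics that on a small grid. `transform_to_displacement_field` is a hypothetical helper, `transform_point` stands in for something like `transform.TransformPoint`, and the Python loop is far too slow for real volumes.

```python
import numpy as np


def transform_to_displacement_field(transform_point, size, origin, spacing, direction):
    """Sample D(p) = T(p) - p at every voxel of a reference grid.

    size, origin, spacing: length-3 sequences in ITK (x, y, z) order.
    direction: 3x3 direction-cosine matrix.
    Returns an array of shape (*size, 3) indexed as [i, j, k, component].
    """
    origin = np.asarray(origin, dtype=float)
    spacing = np.asarray(spacing, dtype=float)
    direction = np.asarray(direction, dtype=float)
    # All voxel indices (i, j, k) as rows of an (N, 3) array.
    idx = np.indices(size).reshape(3, -1).T
    # Physical point of each voxel: origin + direction @ (spacing * index).
    points = origin + (idx * spacing) @ direction.T
    # Apply the transform point by point, then subtract the original points.
    mapped = np.array([transform_point(p) for p in points])
    return (mapped - points).reshape(*size, 3)


# A pure translation gives the same displacement everywhere.
field = transform_to_displacement_field(
    lambda p: p + np.array([1.0, 2.0, 3.0]),
    size=(4, 5, 6), origin=(0, 0, 0), spacing=(1, 1, 1), direction=np.eye(3))
print(field.shape)  # (4, 5, 6, 3)
```

Because the field is sampled on the reference image's own grid, its resolution and units (physical, not voxel) follow the fixed image's metadata rather than the network's 175x175x175 grid.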

Does this work for your use case?

anudeepk17 commented 1 month ago

Hello, thank you for your reply. While this did give me a vector field at the original dimensions, the Dice scores are much lower than with the vector field obtained at 175x175x175. Below is the code I use to calculate Dice, with comments to clarify my approach as best I can. I obtain phi_AB from the fixed and moving images and then use that phi_AB to warp masks of substructures in my data.

My issue is that the Dice I obtain with phi_AB at the network shape (175x175x175) is very good, but when I use displacement_field to obtain the warped mask, the Dice drops drastically.
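Since the comparison hinges on Dice scores, here is a minimal NumPy definition for reference, Dice = 2|A ∩ B| / (|A| + |B|) on binary masks. This is a hypothetical stand-in, not the `compute_metric_dsc` helper used in the code that follows, whose exact behavior (for example its `auto_crop` option) is not shown in this thread.

```python
import numpy as np


def dice(a, b) -> float:
    """Dice similarity coefficient of two masks (nonzero voxels = foreground)."""
    a = np.asarray(a) != 0
    b = np.asarray(b) != 0
    denom = a.sum() + b.sum()
    if denom == 0:
        return 1.0  # both masks empty: treated as perfect agreement (a convention)
    return float(2.0 * np.logical_and(a, b).sum() / denom)


print(dice([255, 255, 0, 0], [255, 0, 0, 0]))  # 0.6666666666666666
```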

```python
import itk
import nibabel as nib
import numpy as np
import SimpleITK as sitk

# --- Obtain phi_AB and the warped label according to phi_AB ---

phi_AB, phi_BA, net = get_dvf(fixed_path, moving_path)

# Use the obtained phi_AB to register the label mask of a substructure of our data.
# warp_image is similar to warp_command in your code; it uses
# itk.resample_image_filter to return the warped label image.
# I pass the original-size label and phi_AB (network size, 175x175x175) as the
# transform to obtain a warped image at the original input size.
# NOTE: warp_image's signature is warp_image(moving, fixed, ...), so this call
# passes fixed_label_path as `moving` and moving_label_path as `fixed`.
warped_label = warp_image(fixed_label_path, moving_label_path, phi_AB)

# Load labels and calculate Dice
fixed_label = nib.load(fixed_label_path).get_fdata()  # path to label
moving_label = nib.load(moving_label_path).get_fdata()
# dim0: R-L; dim1: A-P; dim2: S-I
warped_label = itk.GetArrayFromImage(warped_label)
# warped_label = np.array(warped_label.cpu())
warped_label = warped_label.swapaxes(0, 2)
warped_label = sitk.GetImageFromArray(warped_label)
fixed_label = sitk.GetImageFromArray(fixed_label)
moving_label = sitk.GetImageFromArray(moving_label)
warped_label = sitk.GetArrayFromImage(warped_label)
# Binarize: labels are stored as 0/255, threshold at half intensity.
warped_label_new = np.zeros(np.shape(warped_label))
warped_label_new[warped_label > 255.0 * 0.5] = 255.0
warped_label = sitk.GetImageFromArray(warped_label_new)
dice_175 = compute_metric_dsc(warped_label, fixed_label, auto_crop=False)
dice_pre = compute_metric_dsc(moving_label, fixed_label, auto_crop=False)

# --- Solution provided by the author to get a displacement field at the original input size ---
# Read the label, calculate the new displacement field and obtain the warped image.
moving_label = itk.imread(moving_label_path)
fixed_image = itk.imread(fixed_path)
dispfield_filter = itk.TransformToDisplacementFieldFilter[itk.Image[itk.Vector[itk.F, 3], 3], itk.D].New()

dispfield_filter.SetTransform(phi_AB)
dispfield_filter.SetReferenceImage(fixed_image)
dispfield_filter.SetUseReferenceImage(True)

dispfield_filter.Update()

displacement_field = dispfield_filter.GetOutput()

displacement_field.GetLargestPossibleRegion().GetSize()

# print(np.array(displacement_field).shape)

warped_moving_image = itk.warp_image_filter(
    moving_label,
    output_origin=fixed_image.GetOrigin(),
    output_direction=fixed_image.GetDirection(),
    output_spacing=fixed_image.GetSpacing(),
    displacement_field=displacement_field)

# Dice calculation as above, this time using warped_moving_image
fixed_label = nib.load(fixed_label_path).get_fdata()  # path to label
moving_label = nib.load(moving_label_path).get_fdata()

# dim0: R-L; dim1: A-P; dim2: S-I

warped_label = itk.array_from_image(warped_moving_image)
warped_label = warped_label.swapaxes(0, 2)
warped_label = sitk.GetImageFromArray(warped_label)
fixed_label = sitk.GetImageFromArray(fixed_label)
moving_label = sitk.GetImageFromArray(moving_label)

warped_label = sitk.GetArrayFromImage(warped_label)
warped_label_new = np.zeros(np.shape(warped_label))
warped_label_new[warped_label > 255.0 * 0.5] = 255.0
warped_label = sitk.GetImageFromArray(warped_label_new)
dice_interpolated = compute_metric_dsc(warped_label, fixed_label, auto_crop=False)
```
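The `swapaxes(0, 2)` calls above rely on a layout difference: nibabel's `get_fdata()` returns the voxel array in [i, j, k] order (the same order as an ITK index), while `itk.GetArrayFromImage` and `sitk.GetArrayFromImage` return [k, j, i]. A quick check with a synthetic array (no image files) makes the convention explicit; for 3D arrays `swapaxes(0, 2)` and `transpose(2, 1, 0)` are the same operation.

```python
import numpy as np

# Synthetic volume in nibabel-style [i, j, k] order, with the size
# reported in this thread (40 x 67 x 69).
ijk = np.random.default_rng(0).random((40, 67, 69))

# ITK/SimpleITK NumPy arrays use the reversed [k, j, i] order.
kji = ijk.transpose(2, 1, 0)

# swapaxes(0, 2) undoes the reversal for 3D arrays.
back = kji.swapaxes(0, 2)
print(kji.shape, back.shape)  # (69, 67, 40) (40, 67, 69)
print(np.array_equal(back, ijk))  # True
```

Applying the swap an even number of times, or to only one of the two arrays being compared, silently misaligns the masks, which is why each conversion step is worth checking on its own.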

Here is a table of the samples I tried. PreRegistration Dice is the Dice before registering the two masks, Dice_175 is the Dice after registering the masks with the network's phi_AB unmodified, and Dice_SameSize is the Dice after obtaining the warped masks with the displacement_field method.

(Table image not reproduced: PreRegistration Dice, Dice_175 and Dice_SameSize for each sample.)
HastingsGreer commented 1 month ago

Could you provide the definitions of the functions get_dvf and warp_image? My suspicion is that somewhere in the pipeline the image metadata (spacing, orientation, and origin) is getting lost.

Also, could you provide the output of the following script for the fixed image, the fixed label, the moving image, and the moving label? This will help me understand the image metadata.

I'm sorry that this is taking so much effort to clear up!

```python
import itk

print(itk.imread(fixed_image_path))
print(itk.imread(moving_image_path))
print(itk.imread(fixed_label_path))
print(itk.imread(moving_label_path))
```
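Why the metadata matters: a warp is applied in physical space, so each fixed voxel index is converted to a physical point with the fixed image's origin, spacing and direction, displaced, and converted back to a continuous index in the moving image with the moving image's metadata. The NumPy sketch below follows ITK's `origin + direction @ (spacing * index)` convention; the helper names are hypothetical. It assumes a field computed with a true 2 mm spacing and shows where the same 4 mm vector lands if the images' spacing has been reset to the default 1 mm.

```python
import numpy as np


def index_to_point(index, origin, spacing, direction):
    """Physical point of a (continuous) voxel index, ITK-style."""
    return np.asarray(origin, float) + np.asarray(direction, float) @ (
        np.asarray(spacing, float) * np.asarray(index, float))


def point_to_index(point, origin, spacing, direction):
    """Inverse of index_to_point: continuous voxel index of a physical point."""
    # direction * spacing scales column j of the direction matrix by spacing[j].
    m = np.asarray(direction, float) * np.asarray(spacing, float)
    return np.linalg.solve(m, np.asarray(point, float) - np.asarray(origin, float))


def moving_index_for(fixed_index, displacement, fixed_geom, moving_geom):
    """Continuous moving-image index sampled by a displacement-field warp (sketch)."""
    p = index_to_point(fixed_index, *fixed_geom)
    return point_to_index(p + np.asarray(displacement, float), *moving_geom)


eye = np.eye(3)
true_geom = ((0, 0, 0), (2, 2, 2), eye)  # origin, spacing (mm), direction
lost_geom = ((0, 0, 0), (1, 1, 1), eye)  # metadata reset to ITK defaults
d = (4.0, 0.0, 0.0)  # a 4 mm displacement along x

print(moving_index_for((10, 5, 5), d, true_geom, true_geom))  # [12.  5.  5.]
print(moving_index_for((10, 5, 5), d, lost_geom, lost_geom))  # [14.  5.  5.]
```

The same vector samples a voxel two positions further away once the spacing is lost, so every image and field in the pipeline needs consistent metadata.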
anudeepk17 commented 4 weeks ago

Hello, no issues at all! I am just glad and thankful for your help and guidance. Here is the code you asked for:

```python
import icon_registration.itk_wrapper
import itk
from unigradicon import get_unigradicon, preprocess


def get_dvf(fixed, moving, save_dvf=None, transform_out=None, fixed_segmentation=None, moving_segmentation=None, io_iterations="None", moving_modality='mri', fixed_modality='mri'):
    ''' fixed               : Path of the fixed image
        moving              : Path of the moving image
        save_dvf            : True to save the transform as an hdf5 file
        transform_out       : Path of the hdf5 file
        fixed_segmentation  : Path of the segmentation map of the fixed image
        moving_segmentation : Path of the segmentation map of the moving image
        io_iterations       : Number of fine-tuning iterations (passed as finetune_steps); default "None"
        moving_modality     : 'ct' or 'mri'
        fixed_modality      : 'ct' or 'mri'
    '''
    net = get_unigradicon()
    fixed = itk.imread(fixed)
    moving = itk.imread(moving)

    if fixed_segmentation is not None:
        fixed_segmentation = itk.imread(fixed_segmentation)
    else:
        fixed_segmentation = None

    if moving_segmentation is not None:
        moving_segmentation = itk.imread(moving_segmentation)
    else:
        moving_segmentation = None

    if io_iterations == "None":
        io_iterations = None
    else:
        io_iterations = int(io_iterations)

    phi_AB, phi_BA = icon_registration.itk_wrapper.register_pair(
        net,
        preprocess(moving, moving_modality, moving_segmentation),
        preprocess(fixed, fixed_modality, fixed_segmentation),
        finetune_steps=io_iterations)
    if save_dvf is not None:
        if transform_out is None:
            transform_out = "trans.hdf5"
        itk.transformwrite([phi_AB], transform_out)
    return phi_AB, phi_BA, net


def warp_image(moving, fixed, phi_AB=None, transform=None, interpolator=None, save_img=None, warped_moving_out=None):
    '''
    fixed               : Path of the fixed image
    moving              : Path of the moving image
    phi_AB              : phi_AB returned from get_dvf()
    interpolator        : "linear" or "nearest_neighbor"
    save_img            : True to save the warped image
    warped_moving_out   : Path of the image to be saved
    transform           : Path of an hdf5 saved transform
    '''
    fixed = itk.imread(fixed)
    moving = itk.imread(moving)
    if interpolator == "linear" or interpolator is None:
        interpolator = itk.LinearInterpolateImageFunction.New(moving)
    elif interpolator == "nearest_neighbor":
        interpolator = itk.NearestNeighborInterpolateImageFunction.New(moving)
    else:
        raise Exception("Specify --nearest_neighbor or --linear")
    if transform is not None and phi_AB is None:
        phi_AB = itk.transformread(transform)[0]
    elif phi_AB is None and transform is None:
        raise Exception("Specify either transform path or Phi_AB as returned from get_dvf()")
    # NOTE: this line overwrites the interpolator selected above,
    # so linear interpolation is always used.
    interpolator = itk.LinearInterpolateImageFunction.New(moving)
    warped_moving_image = itk.resample_image_filter(
        moving,
        transform=phi_AB,
        interpolator=interpolator,
        use_reference_image=True,
        reference_image=fixed
    )
    if save_img is not None:
        if warped_moving_out is None:
            warped_moving_out = "warp.nii.gz"
        itk.imwrite(warped_moving_image, warped_moving_out)
    else:
        return warped_moving_image
```

The output of the script:

Image (0x6420ee118250)
  RTTI typeinfo:   itk::Image<double, 3u>
  Reference Count: 1
  Modified Time: 241236
  Debug: Off
  Object Name: 
  Observers: 
    none
  Source: (none)
  Source output name: (none)
  Release Data: Off
  Data Released: False
  Global Release Data: Off
  PipelineMTime: 241057
  UpdateMTime: 241235
  RealTimeStamp: 0 seconds 
  LargestPossibleRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  BufferedRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  RequestedRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  Spacing: [1, 1, 1]
  Origin: [0, 0, 0]
  Direction: 
1 0 0
0 1 0
0 0 1

  IndexToPointMatrix: 
1 0 0
0 1 0
0 0 1

  PointToIndexMatrix: 
1 0 0
0 1 0
0 0 1

  Inverse Direction: 
1 0 0
0 1 0
0 0 1

  PixelContainer: 
    ImportImageContainer (0x6420d41bdfd0)
      RTTI typeinfo:   itk::ImportImageContainer<unsigned long, double>
      Reference Count: 1
      Modified Time: 241232
      Debug: Off
      Object Name: 
      Observers: 
        none
      Pointer: 0x6420e2ce1380
      Container manages memory: true
      Size: 184920
      Capacity: 184920

Image (0x6420d3e46e10)
  RTTI typeinfo:   itk::Image<double, 3u>
  Reference Count: 1
  Modified Time: 241603
  Debug: Off
  Object Name: 
  Observers: 
    none
  Source: (none)
  Source output name: (none)
  Release Data: Off
  Data Released: False
  Global Release Data: Off
  PipelineMTime: 241424
  UpdateMTime: 241602
  RealTimeStamp: 0 seconds 
  LargestPossibleRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  BufferedRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  RequestedRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  Spacing: [1, 1, 1]
  Origin: [0, 0, 0]
  Direction: 
1 0 0
0 1 0
0 0 1

  IndexToPointMatrix: 
1 0 0
0 1 0
0 0 1

  PointToIndexMatrix: 
1 0 0
0 1 0
0 0 1

  Inverse Direction: 
1 0 0
0 1 0
0 0 1

  PixelContainer: 
    ImportImageContainer (0x6420d41bdfd0)
      RTTI typeinfo:   itk::ImportImageContainer<unsigned long, double>
      Reference Count: 1
      Modified Time: 241599
      Debug: Off
      Object Name: 
      Observers: 
        none
      Pointer: 0x6420e2ce1380
      Container manages memory: true
      Size: 184920
      Capacity: 184920

Image (0x642014a483d0)
  RTTI typeinfo:   itk::Image<double, 3u>
  Reference Count: 1
  Modified Time: 241970
  Debug: Off
  Object Name: 
  Observers: 
    none
  Source: (none)
  Source output name: (none)
  Release Data: Off
  Data Released: False
  Global Release Data: Off
  PipelineMTime: 241791
  UpdateMTime: 241969
  RealTimeStamp: 0 seconds 
  LargestPossibleRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  BufferedRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  RequestedRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  Spacing: [1, 1, 1]
  Origin: [0, 0, 0]
  Direction: 
1 0 0
0 1 0
0 0 1

  IndexToPointMatrix: 
1 0 0
0 1 0
0 0 1

  PointToIndexMatrix: 
1 0 0
0 1 0
0 0 1

  Inverse Direction: 
1 0 0
0 1 0
0 0 1

  PixelContainer: 
    ImportImageContainer (0x6420d41bdfd0)
      RTTI typeinfo:   itk::ImportImageContainer<unsigned long, double>
      Reference Count: 1
      Modified Time: 241966
      Debug: Off
      Object Name: 
      Observers: 
        none
      Pointer: 0x6420e2ce1380
      Container manages memory: true
      Size: 184920
      Capacity: 184920

Image (0x6420ee118250)
  RTTI typeinfo:   itk::Image<double, 3u>
  Reference Count: 1
  Modified Time: 242337
  Debug: Off
  Object Name: 
  Observers: 
    none
  Source: (none)
  Source output name: (none)
  Release Data: Off
  Data Released: False
  Global Release Data: Off
  PipelineMTime: 242158
  UpdateMTime: 242336
  RealTimeStamp: 0 seconds 
  LargestPossibleRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  BufferedRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  RequestedRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  Spacing: [1, 1, 1]
  Origin: [0, 0, 0]
  Direction: 
1 0 0
0 1 0
0 0 1

  IndexToPointMatrix: 
1 0 0
0 1 0
0 0 1

  PointToIndexMatrix: 
1 0 0
0 1 0
0 0 1

  Inverse Direction: 
1 0 0
0 1 0
0 0 1

  PixelContainer: 
    ImportImageContainer (0x6420d41bdfd0)
      RTTI typeinfo:   itk::ImportImageContainer<unsigned long, double>
      Reference Count: 1
      Modified Time: 242333
      Debug: Off
      Object Name: 
      Observers: 
        none
      Pointer: 0x6420e2ce1380
      Container manages memory: true
      Size: 184920
      Capacity: 184920

Thank you again for your prompt responses and help. I look forward to your reply.

HastingsGreer commented 3 weeks ago

This is a real puzzler! I see three possibilities for what is going on:

1) It is generally best practice to warp label images with itk.NearestNeighborInterpolateImageFunction rather than itk.LinearInterpolateImageFunction. It is worth switching both paths (transform and displacement field) to nearest-neighbor interpolation; this may or may not fix the discrepancy.

2) Examining your metadata, I realized that we have not extensively tested our approach on images with resolution much lower than 175 x 175 x 175. This case may be exposing a bug in our code, or the model may be producing a very high-resolution displacement field that is somehow "cheating" by hiding labels between the low-resolution voxels, and the cheating is defeated by forcing the displacement field to have the same resolution as the image.

3) There are some confusing elements in the code you posted. In particular, converting between NumPy arrays and ITK images is tricky, and the swapaxes calls are easy to get wrong. It is also not clear how the posted pieces fit together: for example, warp_image accepts an interpolator argument but then unconditionally overwrites it with itk.LinearInterpolateImageFunction, so labels are always warped with linear interpolation.
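Point 1 in a nutshell: linear interpolation blends neighboring label values, so sampling a label map between voxels can produce values that are not valid labels (or, for a 0/255 mask, gray values whose survival depends on the threshold), while nearest-neighbor interpolation only ever returns labels that exist. A 1D NumPy sketch; `sample_linear` and `sample_nearest` are illustrative helpers, not ITK functions, and the tie-breaking at exact half positions is an assumption that may differ from ITK's.

```python
import numpy as np

labels = np.array([0, 0, 3, 3, 1, 1])  # a multi-label row of voxels


def sample_linear(arr, x):
    """Linear interpolation at continuous position x."""
    return np.interp(x, np.arange(arr.size), arr)


def sample_nearest(arr, x):
    """Nearest-neighbor lookup at continuous position x (round half up here)."""
    i = int(np.clip(np.floor(x + 0.5), 0, arr.size - 1))
    return arr[i]


for x in (1.5, 3.5):
    print(x, sample_linear(labels, x), sample_nearest(labels, x))
# 1.5 1.5 3
# 3.5 2.0 1
```

At x = 3.5 linear interpolation invents label 2 on the boundary between labels 3 and 1, whereas nearest neighbor returns an existing label.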

Would you be willing to email me full runnable code and a pair of images from your table? I know that this may be a data sharing issue, but I have reached a dead end with the information I have. I think I need an example I can run and experiment on to resolve this issue.