SuperElastix / SimpleElastix

Multi-lingual medical image registration library
http://simpleelastix.github.io
Apache License 2.0

Trouble understanding basic usage #496

Status: Open. efournie opened this issue 1 year ago.


Hello,

I have trouble understanding the basic usage of SimpleElastix. In the following example, I try to register a moving image Im (small square in the center) to a fixed image If (big square in the upper-left corner) with an affine transform. In the second part, I try to apply the computed transformation to Im using a TransformixImageFilter. The third part does the same using only the deformation field returned by the TransformixImageFilter in the second part, wrapped in a DisplacementFieldTransform and applied with a ResampleImageFilter.

```python
import SimpleITK as sitk
import numpy as np
import matplotlib.pyplot as plt

def test_shape(size, origin):
    # Square outline whose upper-left triangle is filled, so the orientation
    # of the shape is unambiguous. origin is given as [row, column].
    image = np.zeros((512, 512))
    image[origin[0]+size, origin[1]:origin[1]+size] = 1   # bottom edge
    image[origin[0]:origin[0]+size, origin[1]+size] = 1   # right edge
    for xy in range(size):
        image[origin[0]+xy, origin[1]:origin[1]+size-xy] = 1  # filled triangle
    return sitk.GetImageFromArray(image)

If = test_shape(150, [100, 100])  # fixed image: big shape, upper left
Im = test_shape(100, [200, 200])  # moving image: small shape, near the center

# Part 1: affine registration of Im onto If
elastixImageFilter = sitk.ElastixImageFilter()
elastixImageFilter.SetFixedImage(If)
elastixImageFilter.SetMovingImage(Im)
elastixImageFilter.SetParameterMap(sitk.GetDefaultParameterMap("affine"))
elastixImageFilter.Execute()
Tx = elastixImageFilter.GetTransformParameterMap()
result1 = sitk.GetArrayFromImage(elastixImageFilter.GetResultImage())

# Part 2: apply the same transform with transformix and also request the deformation field
transformixImageFilter = sitk.TransformixImageFilter()
transformixImageFilter.SetTransformParameterMap(Tx)
transformixImageFilter.SetMovingImage(Im)
transformixImageFilter.ComputeDeformationFieldOn()
transformixImageFilter.Execute()
deformationField_np = sitk.GetArrayFromImage(transformixImageFilter.GetDeformationField())
result2 = sitk.GetArrayFromImage(transformixImageFilter.GetResultImage())

# Part 3: resample Im using only the deformation field.
# The field is defined on the fixed-image grid, so If serves as the reference image.
deformationField = sitk.GetImageFromArray(deformationField_np.astype(np.float64), isVector=True)
deformationField.SetOrigin(If.GetOrigin())
deformationField.SetSpacing(If.GetSpacing())
transform = sitk.DisplacementFieldTransform(deformationField)
resample = sitk.ResampleImageFilter()
resample.SetReferenceImage(If)
resample.SetTransform(transform)
resample.SetInterpolator(sitk.sitkNearestNeighbor)
result3 = sitk.GetArrayFromImage(resample.Execute(Im))

# Overlays: result pixels = 1, If pixels = 0.5, Im pixels = 0.25
If_np = sitk.GetArrayFromImage(If)
Im_np = sitk.GetArrayFromImage(Im)
_, a = plt.subplots(1, 3)
for ax, result in zip(a, (result1, result2, result3)):
    ax.imshow(result + If_np / 2 + Im_np / 4)
plt.show()
```
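Parts 2 and 3 rely on the same convention, which helps when reading the overlays: the output lives on the fixed-image grid, and for every fixed-image point x the transform (or, in Part 3, x plus the displacement d(x)) gives a point in the moving image whose value is copied to the output. So to bring the shape at index ~200 onto the one at ~100, d(x) near the shape has to be positive, on the order of +100 pixels, which is one thing worth checking in `deformationField_np`. The sketch below illustrates only this backward mapping in plain NumPy; `warp_nearest` is a hypothetical helper, not a SimpleITK function, and it assumes unit spacing, zero origin, identity direction and nearest-neighbour sampling.

```python
import numpy as np

def warp_nearest(moving, field):
    """Hypothetical helper: backward-warp `moving` with a displacement field.

    moving: 2-D array indexed [row, col].
    field:  array of shape (rows, cols, 2) defined on the OUTPUT (fixed) grid,
            holding (dx, dy) = (column, row) displacements, i.e. the component
            order of a 2-D SimpleITK vector image after GetArrayFromImage.
    Assumes unit spacing, zero origin and identity direction.
    """
    out_rows, out_cols = field.shape[:2]
    rows, cols = moving.shape
    yy, xx = np.mgrid[0:out_rows, 0:out_cols]
    # Each output pixel x takes the moving-image value at x + d(x), rounded to the nearest pixel
    src_x = np.rint(xx + field[..., 0]).astype(int)
    src_y = np.rint(yy + field[..., 1]).astype(int)
    inside = (src_x >= 0) & (src_x < cols) & (src_y >= 0) & (src_y < rows)
    out = np.zeros((out_rows, out_cols), dtype=moving.dtype)  # samples outside Im stay 0
    out[inside] = moving[src_y[inside], src_x[inside]]
    return out

# Toy check: a constant displacement of (+1, +1) pulls the bright pixel
# at [3, 3] in the moving image back to [2, 2] in the output.
moving = np.zeros((5, 5))
moving[3, 3] = 1.0
shifted = warp_nearest(moving, np.ones((5, 5, 2)))
print(np.argwhere(shifted == 1))  # [[2 2]]
```

Note the direction: a positive displacement makes the content appear shifted toward smaller indices, which is what moving Im's shape up and to the left onto If requires.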

When I display the resulting images (result pixels = 1, If pixels = 0.5, Im pixels = 0.25), I see that the three result images (deformations of Im) are very close to Im, although I would expect them to be close to If.
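For the result to land on If, the affine transform found in Part 1 must be far from identity. As I understand elastix's convention, its transforms map fixed-image coordinates to moving-image coordinates, and the affine transform is T(x) = A(x - c) + t + c, with `TransformParameters` stored as (a11, a12, a21, a22, tx, ty) and c = `CenterOfRotationPoint`. Under that assumption, plus unit spacing and zero origin (the defaults for images made with `GetImageFromArray`), the ideal parameters for this toy example can be worked out by hand and compared with something like `[float(v) for v in Tx[0]["TransformParameters"]]`. The helpers `expected_affine`, `apply_affine` and `is_near_identity` are hypothetical, and the centre value below is an assumption to be replaced by the one actually stored in `Tx`.

```python
import numpy as np

def expected_affine(fixed_corner, fixed_size, moving_corner, moving_size, center=(0.0, 0.0)):
    """Hypothetical helper: affine parameters [a11, a12, a21, a22, tx, ty] mapping the
    fixed square onto the moving square (fixed -> moving) for T(x) = A (x - c) + t + c.
    Assumes unit spacing and zero origin.
    """
    s = moving_size / fixed_size  # isotropic scale factor
    A = s * np.eye(2)
    b = np.asarray(moving_corner, float) - s * np.asarray(fixed_corner, float)
    c = np.asarray(center, float)
    # A (x - c) + t + c == A x + b  requires  t = b - (I - A) c
    t = b - (np.eye(2) - A) @ c
    return np.concatenate([A.ravel(), t])

def apply_affine(params, center, point):
    """Evaluate T(x) = A (x - c) + t + c for one (x, y) point."""
    A = np.asarray(params[:4], float).reshape(2, 2)
    t = np.asarray(params[4:], float)
    c = np.asarray(center, float)
    return A @ (np.asarray(point, float) - c) + t + c

def is_near_identity(params, tol_matrix=1e-2, tol_translation=1.0):
    """True if the transform barely moves anything (A close to I, t close to 0 pixels)."""
    p = np.asarray(params, float)
    return np.allclose(p[:4], [1, 0, 0, 1], atol=tol_matrix) and np.allclose(p[4:], 0, atol=tol_translation)

# Assumed centre: geometric centre of a 512 x 512 image with unit spacing.
# Read the real value from Tx[0]["CenterOfRotationPoint"] instead of trusting this.
center = (255.5, 255.5)
ideal = expected_affine((100, 100), 150, (200, 200), 100, center)
print(np.round(ideal, 3))                       # diagonal about 0.667, translation about 48.167
print(apply_affine(ideal, center, (100, 100)))  # fixed corner lands on the moving corner (200, 200)
```

Because both shapes have equal row and column origins here, the (x, y) versus [row, col] ordering does not matter. If the parameters elastix returned pass `is_near_identity`, the optimizer essentially did not move, which would be consistent with all three results looking like Im.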

Does anyone have a clue what I am doing wrong?

Thank you for your help!

[Image: SimpleElastixResults]
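To put a number on "very close to Im" instead of judging the overlays by eye, one option is the Dice overlap of each binarized result with If and with Im, e.g. `overlap_report(result1, sitk.GetArrayFromImage(If), sitk.GetArrayFromImage(Im))`. The sketch below is plain NumPy; `dice` and `overlap_report` are hypothetical helpers, the 0.5 threshold is an assumption, and the toy images are filled squares with the same extents as the script's shapes rather than the exact `test_shape` output.

```python
import numpy as np

def dice(a, b, threshold=0.5):
    """Dice coefficient 2|A & B| / (|A| + |B|) of two images binarized at `threshold`."""
    a = np.asarray(a) > threshold
    b = np.asarray(b) > threshold
    total = a.sum() + b.sum()
    if total == 0:
        return 1.0  # both masks empty: count as perfect agreement
    return 2.0 * np.logical_and(a, b).sum() / total

def overlap_report(result, fixed, moving):
    """Hypothetical helper: how much a registration result overlaps If and Im."""
    return {"dice_vs_fixed": dice(result, fixed), "dice_vs_moving": dice(result, moving)}

# Filled squares standing in for the script's shapes (same extents).
fixed = np.zeros((512, 512))
fixed[100:250, 100:250] = 1
moving = np.zeros((512, 512))
moving[200:300, 200:300] = 1

# An "unregistered" result identical to Im: Dice with Im is 1, with If only 2/13 (about 0.154)
print(overlap_report(moving, fixed, moving))
```

A successful registration should push `dice_vs_fixed` toward 1; the situation described above would show the opposite pattern.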