fepegar / torchio

Medical imaging toolkit for deep learning
https://torchio.org
Apache License 2.0

Extract control points for elastic random deformations #442

Closed YarivLevy81 closed 3 years ago

YarivLevy81 commented 3 years ago

🚀 Feature

Extract transformation parameters for random deformations

Motivation (and Pitch)

As in @fepegar's gist, we want to be able to observe the control points (xx, yy) as well as the deformation vectors u, v.

Ideally, I've been thinking about something like this:

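```python
import torchio as tio

transform = tio.RandomElasticDeformation(...)
u, v = transform.get_deformation_vectors()  # proposed method, not part of torchio yet
x, y = transform.get_control_points()  # proposed method, not part of torchio yet
```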

Alternatives

I've already seen that there is a get_params method in RandomElasticDeformation, but I think it generates new parameters on every call. Am I right? That method seems to provide u, v but not x, y.

Additional context

I mainly wondered whether I am missing anything. I went through the documentation, the source code and the gist, but I wasn't 100% sure that the feature is really missing.

If it is missing, I would love to help with a pull request.

fepegar commented 3 years ago

Hi, @YarivLevy81. I am not 100% sure of what you want to get. Here's some code:

```python
In [1]: import torch

In [2]: import torchio as tio

In [3]: subject = tio.datasets.Sheep()

In [4]: transform = tio.RandomElasticDeformation()

In [5]: torch.manual_seed(42)
Out[5]: <torch._C.Generator at 0x7fdab0106990>

In [6]: transformed = transform(subject)
```


From there, you can use the transform history stored in the subject instance to get the control points:

```python
In [12]: deterministic = transformed.history[-1]

In [13]: deterministic
Out[13]:
ElasticDeformation(control_points=[[[[
[...]

In [14]: itk_transform = deterministic.get_bspline_transform(subject.t1.as_sitk(), deterministic.control_points)

In [16]: x, y, z = itk_transform.GetCoefficientImages()

In [17]: x.GetSize()
Out[17]: (7, 7, 7)

In [18]: x.GetOrigin()
Out[18]: (-90.46875, -108.21875, -89.71875)

In [19]: x.GetSpacing()
Out[19]: (30.15625, 39.65625, 30.40625)
```

And that's how you would get the control points.

fepegar commented 3 years ago

We can work together to add a method ElasticDeformation.get_control_points(image), if you think that would be useful.

fepegar commented 3 years ago

This is some of the relevant code: https://github.com/fepegar/torchio/blob/6533a620cf18708fc7275a12b0d2544c3f1f0637/torchio/transforms/augmentation/spatial/random_elastic_deformation.py#L215-L225

I guess the variable names could be better: what's called control_points are actually the displacement values at the control points, not the points themselves.

Anyway, as you can see, you need three things to get the control points: the image, the spline order (which is set to three, the default) and the size of the control points grid.
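To make the role of the spline order and the grid size concrete, here is a minimal NumPy sketch of how a cubic B-spline control-point grid is laid out over a physical domain (roughly, the image supplies the domain the grid must cover). It assumes ITK's convention that the grid has `mesh_size + spline_order` points per axis and starts `(spline_order - 1) / 2` grid spacings before the domain origin, with an identity direction matrix. `bspline_grid_geometry` is a hypothetical helper, not part of torchio or SimpleITK, and the domain values are made up.

```python
import numpy as np


def bspline_grid_geometry(domain_origin, domain_size, mesh_size, spline_order=3):
    """Hypothetical helper: size, origin and spacing of a B-spline control grid.

    Assumes ITK's layout with an identity direction: the physical domain is
    split into `mesh_size` cells per axis, the grid has
    `mesh_size + spline_order` points per axis, and it starts
    (spline_order - 1) / 2 spacings before the domain origin.
    """
    domain_origin = np.asarray(domain_origin, dtype=float)
    domain_size = np.asarray(domain_size, dtype=float)
    mesh_size = np.asarray(mesh_size, dtype=int)
    grid_spacing = domain_size / mesh_size
    grid_origin = domain_origin - grid_spacing * (spline_order - 1) / 2
    grid_size = mesh_size + spline_order
    return grid_size, grid_origin, grid_spacing


# A made-up 120 mm cubic domain starting at the world origin, 4 mesh cells per axis
size, origin, spacing = bspline_grid_geometry((0, 0, 0), (120, 120, 120), (4, 4, 4))
print(size, origin, spacing)  # [7 7 7] [-30. -30. -30.] [30. 30. 30.]
```

Under these assumptions, 4 mesh cells with cubic splines give a 7 × 7 × 7 grid whose first and last points lie one spacing outside the domain, which is why the image alone is not enough to locate the control points.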

YarivLevy81 commented 3 years ago

Thanks, @fepegar.

Sorry if I was unclear. I do think that ElasticDeformation.get_control_points(image) would be useful. My goal is to make a very coarse estimate of the optical flow based on the ElasticDeformation, so I need both the displacement values (control_points) and the points themselves.

What I had in mind is to add something like this in https://github.com/fepegar/torchio/blob/6533a620cf18708fc7275a12b0d2544c3f1f0637/torchio/transforms/augmentation/spatial/random_elastic_deformation.py#L115-L140:

```python
self.control_points = self.get_params(
    self.num_control_points,
    self.max_displacement,
    self.num_locked_borders,
)
```

And then, in the apply_transform method, replace https://github.com/fepegar/torchio/blob/6533a620cf18708fc7275a12b0d2544c3f1f0637/torchio/transforms/augmentation/spatial/random_elastic_deformation.py#L174-L178 with:

```python
arguments = {
    'control_points': self.control_points,
    'max_displacement': self.max_displacement,
    'image_interpolation': self.image_interpolation,
}
```

so that the displacement values can be extracted.

Afterwards, I would still need to extract the points themselves, which, as you mentioned, is doable via get_bspline_transform. Do you think it would be helpful to add an implementation to torchio?

fepegar commented 3 years ago

Sorry, it's still not clear to me what's missing.

> My goal is to make some very coarse estimation of the optical flow, based of the ElasticDeformation

This optical flow, aka displacement field, is ElasticDeformation.control_points. You can get the points using the method above. So it seems that what you're after is already possible. Is your proposal to write a method so it's easier to access the points?

YarivLevy81 commented 3 years ago

I think I misunderstood. I do indeed understand how to get the displacement values now, but how can I extract the origin points from the ElasticDeformation object?

fepegar commented 3 years ago

It's all in this comment.

If you have the size, origin, direction and spacing of an image, you can compute the position of all the points using SimpleITK.Image.TransformIndexToPhysicalPoint. In the following code, x (from the comment above) corresponds to the u values, and you can get the position of each voxel / control point:

```python
In [10]: size_i, size_j, size_k = x.GetSize()

In [11]: points = []

In [13]: for i in range(size_i):
    ...:     for j in range(size_j):
    ...:         for k in range(size_k):
    ...:             point = x.TransformIndexToPhysicalPoint((i, j, k))
    ...:             points.append(point)
    ...:

In [14]: import numpy as np

In [15]: points = np.array(points)

In [16]: points.shape
Out[16]: (343, 3)

In [17]: points[:10]
Out[17]:
array([[ -90.46875, -108.21875,  -89.71875],
       [ -90.46875, -108.21875,  -59.3125 ],
       [ -90.46875, -108.21875,  -28.90625],
       [ -90.46875, -108.21875,    1.5    ],
       [ -90.46875, -108.21875,   31.90625],
       [ -90.46875, -108.21875,   62.3125 ],
       [ -90.46875, -108.21875,   92.71875],
       [ -90.46875,  -68.5625 ,  -89.71875],
       [ -90.46875,  -68.5625 ,  -59.3125 ],
       [ -90.46875,  -68.5625 ,  -28.90625]])
```

As you can see, the position of the first voxel of the coarse displacement field is the origin, -90.46875, -108.21875, -89.71875.
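The triple loop can also be vectorized with NumPy. This sketch assumes SimpleITK's index-to-physical mapping, `point = origin + direction @ (spacing * index)`, and takes the geometry as plain tuples (the kind returned by `GetOrigin()`, `GetSpacing()`, `GetSize()` and `GetDirection()`), so it does not need SimpleITK itself. `grid_points` is a hypothetical name, the direction defaults to identity as an assumption, and the numbers are the ones printed above.

```python
import numpy as np


def grid_points(origin, spacing, size, direction=None):
    """Physical positions of every voxel / control point of a 3D grid.

    Rows follow the same order as the nested i, j, k loop (k varies fastest).
    `direction` is the flattened row-major 3x3 matrix from GetDirection();
    identity is assumed when it is omitted.
    """
    origin = np.asarray(origin, dtype=float)
    spacing = np.asarray(spacing, dtype=float)
    if direction is None:
        direction = np.eye(3)
    direction = np.asarray(direction, dtype=float).reshape(3, 3)
    # All (i, j, k) indices as an (N, 3) array, last index varying fastest
    indices = np.indices(size).reshape(3, -1).T
    # point = origin + D @ (spacing * index), applied to every row at once
    return origin + (indices * spacing) @ direction.T


points = grid_points(
    origin=(-90.46875, -108.21875, -89.71875),
    spacing=(30.15625, 39.65625, 30.40625),
    size=(7, 7, 7),
)
print(points.shape)  # (343, 3)
```

With these values, the first ten rows match the `points[:10]` output of the loop above. Passing the direction explicitly only matters for images that are not axis-aligned.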

These coordinates are in LPS orientation with respect to the patient, i.e., x grows to the left, y grows posteriorly and z grows towards the top of the head.

Does this make sense at all? Let me know if it's not clear.

YarivLevy81 commented 3 years ago

Thank you very much for the explanation @fepegar.

Just to validate that I understood you correctly.

1. The u, v, w displacement scalars correspond to the x, y, z directions, respectively (LPS, because we use DICOM images). To extract them, I only need to call:

```python
import torch
import torchio as tio

subject = ...
transform = tio.RandomElasticDeformation()
torch.manual_seed(42)
transformed = transform(subject)

deterministic = transformed.history[-1]
control_points = deterministic.control_points
u = control_points[..., 0].T
v = control_points[..., 1].T
w = control_points[..., 2].T
```

This gives me a 3-tuple (u, v, w) for each point.

2. The origin points can be extracted using your code above:

```python
import numpy as np

itk_transform = deterministic.get_bspline_transform(subject.t1.as_sitk(), deterministic.control_points)

x, y, z = itk_transform.GetCoefficientImages()
size_i, size_j, size_k = x.GetSize()

points = []
for i in range(size_i):
    for j in range(size_j):
        for k in range(size_k):
            point = x.TransformIndexToPhysicalPoint((i, j, k))
            points.append(point)

points = np.array(points)
```

If I ran the same thing using y.TransformIndexToPhysicalPoint((i, j, k)), I would get the same points, right?

3. Now that I have both the points and u, v, w, the following code, for example, would print every control point with its displacement value:

```python
# Flatten all three grids (not only u) so that u[i], v[i] and w[i] are scalars
u, v, w = u.flatten(), v.flatten(), w.flatten()
for i in range(u.shape[0]):
    print(f'Point -> {points[i]}: u -> {u[i]}, v -> {v[i]}, w -> {w[i]}')
```

```
...
Point -> [-319.6875 -255.625 159.6875]: u -> 0.0, v -> 0.0, w -> 0.0
Point -> [-319.6875 -191.5625 -32.6875]: u -> 0.0, v -> 0.0, w -> 0.0
Point -> [-319.6875 -191.5625 -0.625 ]: u -> 0.0, v -> 0.0, w -> 0.0
Point -> [-319.6875 -191.5625 31.4375]: u -> -4.260419845581055, v -> 1.3956239223480225, w -> -1.8879514932632446
Point -> [-319.6875 -191.5625 63.5 ]: u -> 0.4929801821708679, v -> -4.902942180633545, w -> -4.997309684753418
Point -> [-319.6875 -191.5625 95.5625]: u -> -4.458999156951904, v -> -0.8299738168716431, w -> -6.557217597961426
...
```

The last point corresponds to `[-319.6875 -191.5625   95.5625]` from the origin (LPS) and `(u, v, w) = (-4.458999156951904, -0.8299738168716431, -6.557217597961426)` (LPS).

fepegar commented 3 years ago

Looks good to me :)

Just one small comment: the orientation is LPS because ITK uses that convention (which DICOM also uses). In, e.g., the VTK world, the coordinates would be RAS (I think; orientation stuff is confusing!).
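Going between LPS and RAS only flips the sign of the first two axes (L to R, P to A); the superior axis is shared. A minimal NumPy sketch that works for point coordinates and for displacement vectors such as (u, v, w) alike; `lps_to_ras` is a hypothetical helper name:

```python
import numpy as np

# diag(-1, -1, 1): negate x (L -> R) and y (P -> A), keep z (S)
LPS_TO_RAS = np.diag([-1.0, -1.0, 1.0])


def lps_to_ras(vectors):
    """Convert points or displacement vectors from LPS to RAS.

    Accepts a single 3-vector or an (N, 3) array. The matrix is its own
    inverse, so the same function also converts RAS back to LPS.
    """
    return np.asarray(vectors, dtype=float) @ LPS_TO_RAS.T


print(lps_to_ras([-90.46875, -108.21875, -89.71875]))
```

Because the conversion is linear with no translation, displacements transform exactly like positions.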

Displacements at the borders are zero because of the locked_borders kwarg in RandomElasticDeformation.
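A rough NumPy sketch of border locking, to show why the outer rows of the output above are zero. It assumes that `locked_borders=n` zeroes the outermost `n` layers of control points on every side of the grid, which is how torchio's documentation describes the values 1 and 2; this is an illustration, not torchio's code, and `lock_borders` is a hypothetical name.

```python
import numpy as np


def lock_borders(coarse_field, num_locked_borders):
    """Zero the displacement of the outermost control-point layers.

    `coarse_field` has shape (nx, ny, nz, 3): one displacement vector per
    control point. A copy is returned; the input is left untouched.
    """
    field = coarse_field.copy()
    for i in range(num_locked_borders):
        for axis in range(3):  # the three spatial axes, not the vector axis
            # Index such as field[i, :, :, :] when axis == 0
            first = [slice(None)] * 4
            last = [slice(None)] * 4
            first[axis] = i
            last[axis] = -1 - i
            field[tuple(first)] = 0
            field[tuple(last)] = 0
    return field


rng = np.random.default_rng(0)
field = rng.uniform(-5, 5, size=(7, 7, 7, 3))
locked = lock_borders(field, num_locked_borders=2)
# Only indices 2, 3 and 4 along each axis keep their displacement
print(np.count_nonzero(np.abs(locked).sum(axis=-1)))  # 27
```

With a 7 × 7 × 7 grid and two locked layers, only the central 3 × 3 × 3 block of control points can move.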

YarivLevy81 commented 3 years ago

Thanks! Good to know about the LPS orientation. I'm closing this so as not to confuse anyone.