BioMedIA / deepali

Image, point set, and surface registration in PyTorch.
https://biomedia.github.io/deepali/
Apache License 2.0

[Feature]: Support PaddingMode NONE in 3D #87

Closed: vasl12 closed this issue 1 year ago

vasl12 commented 1 year ago

The dimensions of a 3D image change after affine registration. Is this related to padding, and if so, how can we fix it?

aschuh-hf commented 1 year ago

Hi Vasilis, could you elaborate on how you apply the affine transformation in order to resample your moving source image?

The PaddingMode only determines how the finite source image tensor is implicitly padded when it is sampled at locations outside the source image domain. It does not affect the sampling grid of the warped output image. That sampling grid is usually determined by the sampling grid of the fixed target image, or by SpatialTransform.grid() as specified when the spatial transform is initialized.
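To make the distinction concrete, here is a minimal 1D sketch in NumPy. It is not deepali's or PyTorch's implementation: the function name sample_1d and its simplified "zeros" and "border" modes are illustrative assumptions ("zeros" here returns exactly 0 outside the domain, whereas torch's grid_sample blends with zeros near the boundary). It shows that the output has one value per sampling point, so its shape follows the sampling grid, while the padding mode only changes the values at points outside the source domain.

```python
import numpy as np


def sample_1d(signal, coords, padding="zeros"):
    """Linearly interpolate `signal` at normalized coords (align_corners=True).

    Hypothetical helper for illustration. The output has one value per entry
    of `coords`, so its shape follows the sampling points, not `signal`.
    `padding` only decides what is returned for coords outside [-1, 1].
    """
    n = signal.shape[0]
    # Map normalized coordinates [-1, 1] to continuous indices [0, n - 1]
    idx = (coords + 1) / 2 * (n - 1)
    if padding == "border":
        # Clamp to the boundary, i.e., repeat the edge values
        idx = np.clip(idx, 0, n - 1)
    inside = (idx >= 0) & (idx <= n - 1)
    lo = np.clip(np.floor(idx).astype(int), 0, n - 1)
    hi = np.clip(lo + 1, 0, n - 1)
    w = idx - np.floor(idx)
    out = (1 - w) * signal[lo] + w * signal[hi]
    if padding == "zeros":
        # Simplified zero padding: points outside the domain get 0
        out = np.where(inside, out, 0.0)
    return out


signal = np.arange(5.0)             # 5 source samples
coords = np.linspace(-1.5, 1.5, 7)  # 7 output points, two outside [-1, 1]
print(sample_1d(signal, coords, "zeros"))   # 7 values
print(sample_1d(signal, coords, "border"))  # 7 values, different at the ends
```

Both calls return 7 values for a 5-sample input; only the two out-of-domain entries differ between the padding modes.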

For ImageTransformer (and likewise deepali.modules.SampleImage), this is the target argument of the transformer's init function.

https://github.com/BioMedIA/deepali/blob/f6dd09fbbed05267d4d00df2c9c951ca66eb9472/src/deepali/spatial/transformer.py#L127

vasl12 commented 1 year ago

Hey Andreas, my sampling grid is determined by the sampling grid of the fixed image, (1, 224, 168, 363), which has the same dimensions as the moving image. After the affine registration, the warped image shape is (1, 225, 169, 361). Do you have any intuition why this might be happening? Thank you for the help :)

aschuh-hf commented 1 year ago

That's odd indeed. What is the code you use to apply the spatial transform?

vasl12 commented 1 year ago

Here is the piece of code for the affine multi-resolution registration and the warping:

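import torch.optim as optim
import torchio as tio
from deepali import spatial

# multi_resolution_registration(), sim_loss, device, tranformed_images and ts
# are defined elsewhere in the user's script (not shown in this issue).
source = tranformed_images[0]['wat'][tio.DATA]  # moving image tensor
target = ts['wat'][tio.DATA]  # fixed image tensor

transform_affine = multi_resolution_registration(
    target=target,
    source=source,
    transform=spatial.FullAffineTransform,
    optimizer=(optim.Adam, {"lr": 1e-3}),
    loss_fn=lambda a, b, _: sim_loss(a, b),
    device=device,
    levels=3,
    iterations=1000
)
transform_affine = transform_affine.cpu()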

jmtzt commented 1 year ago

Hi @aschuh-hf. I'm encountering the same issue. I'm using a deformable transform, as in:

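import torch
from torch import Tensor
from deepali import spatial

# multi_resolution_registration(), loss_fn() and the pyramids, weights and
# tensors used below are defined elsewhere in the user's script (not shown).
transform = multi_resolution_registration(
    target=fixed_pyramid,
    source=moving_pyramid,
    transform=("FFD", {"stride": ffd_stride}),
    optimizer=("Adam", {"lr": lr}),
    loss_fn=loss_fn(w_bending=w_bend,
                    w_reg=w_reg,
                    w_diffusion=w_diff,
                    w_curvature=w_curv,
                    sim_loss=sim_loss),
    levels=len(fixed_pyramid),
    device=device,
    iterations=iterations,
)
with torch.inference_mode():
    # Without a target argument, the output is sampled on transform.grid()
    transformer = spatial.ImageTransformer(transform)
    warped: Tensor = transformer(moving_img_tensor)
    warped_source_seg: Tensor = transformer(moving_seg_og_tensor)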

I've already specified padding=None in the ImageTransformer, but the shape of the warped tensor still changes (moving_img_tensor has shape (1, 189, 233, 197), the warped tensor has shape (1, 193, 233, 201)). I couldn't pinpoint where inside the ImageTransformer this happens. Do you have any intuition on this? Thanks :)

aschuh-hf commented 1 year ago

I see. The change in image size is probably caused by the multi-resolution pyramid. The grid sizes are adjusted so that repeatedly doubling the size from a coarser level to the next finer level yields, at the finest level, the size closest to the given input image size. This enables efficient subdivision of an FFD control point grid.
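All sizes reported in this thread fit a rule in which a level with n points is refined to 2n - 1 points at the next finer level, as in the MNIST pyramid sizes 7, 13, 25 quoted later. The following sketch assumes that rule; it is not deepali's code, and the helper names are hypothetical. It also assumes levels=4 for the FFD example above, which the thread does not state (it uses len(fixed_pyramid)). The sketch does not reproduce which of the two candidate sizes deepali actually picks.

```python
from typing import Tuple


def finest_size(coarsest_size: int, levels: int) -> int:
    """Finest-level size if each refinement maps n points to 2n - 1 (assumed rule)."""
    n = coarsest_size
    for _ in range(levels - 1):
        n = 2 * n - 1  # coarse points coincide with every other fine point
    return n


def is_subdividable(size: int, levels: int) -> bool:
    """True if `size` is reachable from an integer coarsest size via n -> 2n - 1."""
    return (size - 1) % 2 ** (levels - 1) == 0


def nearest_candidates(size: int, levels: int) -> Tuple[int, int]:
    """Reachable finest sizes at or just below, and just above, `size`."""
    step = 2 ** (levels - 1)
    lower = (size - 1) // step * step + 1
    upper = lower if lower == size else lower + step
    return lower, upper


# Affine case (levels=3): input sizes 224, 168, 363 are not reachable,
# while the reported output sizes 225, 169, 361 are.
print([is_subdividable(n, 3) for n in (224, 168, 363)])
print([is_subdividable(n, 3) for n in (225, 169, 361)])
print(nearest_candidates(224, 3))  # candidates around 224
```

In this model an input size survives unchanged only if (size - 1) is divisible by 2^(levels - 1), e.g. 233 with four levels, which matches the unchanged middle dimension in the FFD example.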

The demo function would probably have to be adapted a bit to account for this, so that the output has the expected size. This can be done by passing the desired output image grid as the target argument to the ImageTransformer that applies the obtained transform. Here, transform.grid() is not identical to the input target.grid(), and it is the former that is used when the image transformer is created as ImageTransformer(transform).

Can you try changing the final transformer to ImageTransformer(transform, target=target.grid()) instead?

jmtzt commented 1 year ago

Hi @aschuh-hf, thanks for the quick reply. With this approach, however, I now get the following error:

ValueError: ImageTransformer() 'target' and 'transform' grid must define the same domain

Is there any way to fix this? It looks like the domain is somehow changed by the pyramid here. Thanks

aschuh-hf commented 1 year ago

@jmtzt The Image.pyramid() (and likewise Grid.pyramid()) functions should preserve the domain, i.e., all images in the pyramid have the same Grid.cube().

For example, in the tutorial notebook the original MNIST images have size 28x28, but the *_pyramid images have sizes:

>>> [tuple(im.grid().size()) for im in target_pyramid.values()]
[(25, 25), (13, 13), (7, 7)]

Yet we still get

>>> [im.grid().same_domain_as(grid) for im in target_pyramid.values()]
[True, True, True]

Can you also check for your image whether Grid.same_domain_as() returns True for all levels of the image pyramid created from it? If not, then there may be a bug.
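How can the grid size change while the domain stays the same? The spacing is adjusted. The following sketch uses only NumPy and hypothetical helper names, and it assumes the align_corners=True convention, in which the domain spans from the first to the last sample point, with unit spacing for the 28x28 MNIST input. Under these assumptions, each pyramid level keeps the same extent, and every coarse sample point coincides with every other point of the next finer level, which is what makes subdivision cheap.

```python
import numpy as np


def extent(size, spacing, align_corners=True):
    # align_corners=True: domain spans from the first to the last sample point
    return (size - 1) * spacing if align_corners else size * spacing


def spacing_for(size, domain_extent, align_corners=True):
    # Spacing that makes `size` points cover exactly `domain_extent`
    return domain_extent / (size - 1) if align_corners else domain_extent / size


def sample_points(size, domain_extent, origin=0.0):
    # World positions of the sample points (align_corners=True convention)
    return origin + np.arange(size) * spacing_for(size, domain_extent)


domain = extent(28, 1.0)  # 28 points with unit spacing span 27 units
levels = [25, 13, 7]      # pyramid sizes from the tutorial output above
spacings = [spacing_for(n, domain) for n in levels]
print(spacings)  # spacing grows as the number of points shrinks
```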

Besides this, I'll look into modifying the ImageTransformer to remove the requirement that both grids define the same domain. For now, you would need to apply your transform using lower-level functionality than the ImageTransformer:

import torch
from torch import Tensor

from deepali.core import Grid, grid_transform_points, functional as U
from deepali.data import Image

# Placeholders: set these to your moving image tensor, its sampling grid,
# and the desired output grid (e.g., the fixed target image grid).
source_data: Tensor = source
source_grid: Grid = grid
output_grid: Grid = grid

# Normalized coordinates with respect to output_grid domain
x = output_grid.coords()
# Map to normalized coordinates with respect to transform.grid() domain
x = grid_transform_points(x, output_grid, output_grid.axes(), transform.grid(), transform.axes())
# Apply spatial transformation
#
# Option grid=True is used to indicate that input points are located at regularly sampled grid points,
# which can be more efficient for non-rigid transforms. The result should be the same with grid=False (default).
with torch.inference_mode():
    x = transform(x, grid=True)
# Map to normalized coordinates with respect to moving source image domain
x = grid_transform_points(x, transform.grid(), transform.axes(), source_grid, source_grid.axes())

# Sample moving source image at transformed output grid points
warped_data = U.grid_sample(source_data.unsqueeze(0), x, align_corners=source_grid.align_corners()).squeeze_(0)

warped_image = Image(warped_data, output_grid)

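
The two grid_transform_points calls convert normalized coordinates from one grid's domain to another's. A minimal 1D sketch of that conversion, under stated assumptions: an axis-aligned grid, the align_corners=True convention (-1 and +1 at the first and last sample point), and hypothetical helper names. It ignores grid orientation and the axes conventions passed to grid_transform_points. When both domains are equal, the mapping is the identity.

```python
def normalized_to_world(u, lo, hi):
    # u = -1 maps to lo, u = +1 maps to hi (align_corners=True convention)
    return lo + (u + 1) / 2 * (hi - lo)


def world_to_normalized(x, lo, hi):
    # Inverse of normalized_to_world for the domain [lo, hi]
    return 2 * (x - lo) / (hi - lo) - 1


def map_normalized(u, from_domain, to_domain):
    # Normalized coords w.r.t. from_domain -> world -> normalized w.r.t. to_domain
    return world_to_normalized(normalized_to_world(u, *from_domain), *to_domain)


output_domain = (0.0, 100.0)       # hypothetical output grid extent (world units)
transform_domain = (-10.0, 110.0)  # hypothetical, larger transform grid extent
print(map_normalized(-1.0, output_domain, transform_domain))  # about -0.833
print(map_normalized(0.3, output_domain, output_domain))      # 0.3 (same domain)
```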
aschuh-hf commented 1 year ago

I've just merged a change that allows the sampling grid of the transformed image (the output of ImageTransformer) to define a different normalized coordinate domain in world space than the spatial transform itself.