BioMedIA / deepali

Image, point set, and surface registration in PyTorch.
https://biomedia.github.io/deepali/
Apache License 2.0

SpatialTransform with respect to different coordinate systems (Axes) #52

Open schuhschuh opened 1 year ago

schuhschuh commented 1 year ago

Currently, a SpatialTransform is a coordinate map defined on normalized coordinates with respect to the reference sampling Grid specified when the spatial transformation object is constructed. A user of the transform therefore has to map coordinates given with respect to other Axes to these normalized coordinates before applying the transformation, and back again afterwards. In contexts where torch.nn.functional.grid_sample is not used, and normalized coordinates are therefore not needed to represent the spatial transformation, one might even want to parameterize the transformation with respect to another coordinate system (e.g., the world coordinate system). These use cases could be supported by adding an axes: Axes property and, optionally, a corresponding keyword argument SpatialTransform.forward(..., axes: Optional[Axes] = None). The property would define the axes with respect to which the coordinate map is parameterized, and the keyword argument would specify the axes with respect to which the input points are defined.

Thanks @qiuhuaqi for this suggestion. Let's discuss here the benefits, or any reasons why we shouldn't support this.
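To make the proposed axes parameterization concrete: for an affine map, switching from normalized to world coordinates is a change of basis. The sketch below is hypothetical (scale_and_center and affine_normalized_to_world are illustrative names, not deepali functions). It assumes an axis-aligned grid (identity direction matrix) and the align_corners=True convention of torch.nn.functional.grid_sample, so that world = s * p + c along each axis. Under these assumptions, y = A p + b in normalized coordinates becomes y = S A S^-1 x + (S b + c - S A S^-1 c) in world coordinates, with S = diag(s).

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]
Vector = List[float]


def scale_and_center(
    size: Sequence[int], spacing: Sequence[float], origin: Sequence[float]
) -> Tuple[Vector, Vector]:
    """Return per-axis s, c such that world = s * normalized + c.

    Hypothetical helper. Assumes an axis-aligned grid (identity direction
    matrix) and the align_corners=True convention: the centers of the first
    and last voxel along each axis map to normalized coordinates -1 and +1.
    """
    s = [(n - 1) * h / 2 for n, h in zip(size, spacing)]
    c = [o + si for o, si in zip(origin, s)]
    return s, c


def affine_normalized_to_world(
    A: Matrix, b: Vector, size: Sequence[int], spacing: Sequence[float], origin: Sequence[float]
) -> Tuple[Matrix, Vector]:
    """Re-express y = A @ p + b (normalized coordinates) as y = A_w @ x + b_w (world)."""
    s, c = scale_and_center(size, spacing, origin)
    d = len(s)
    # A_w = S A S^-1, where S = diag(s)
    A_w = [[s[i] * A[i][j] / s[j] for j in range(d)] for i in range(d)]
    # b_w = S b + c - A_w c
    b_w = [s[i] * b[i] + c[i] - sum(A_w[i][j] * c[j] for j in range(d)) for i in range(d)]
    return A_w, b_w


# Usage: a 5x3 grid with 2 mm x 1 mm spacing whose first voxel center is at (10, 0).
A_w, b_w = affine_normalized_to_world(
    [[1.0, 0.0], [0.0, 1.0]], [0.5, 0.0], size=(5, 3), spacing=(2.0, 1.0), origin=(10.0, 0.0)
)
print(A_w, b_w)  # [[1.0, 0.0], [0.0, 1.0]] [2.0, 0.0]
```

In this example, a translation of 0.5 normalized units along the first axis corresponds to one voxel, i.e., 2 mm, which is why b_w = [2.0, 0.0].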

aschuh-hf commented 1 year ago

While this feature request is different from #53 and may still be desirable for those who want to implement a classic registration that operates directly in world space rather than on normalized point coordinates, a possible workaround is to use SpatialTransform.points() instead of having spatial transforms whose parameterization is with respect to unnormalized coordinate spaces.

See #62.
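Conceptually, the workaround wraps a coordinate map that is defined on normalized coordinates between two conversions: world to normalized before applying it, and normalized to world afterwards. The following is a minimal sketch of that idea only, not the implementation of SpatialTransform.points(). The helper names are hypothetical, and it assumes an axis-aligned grid and the align_corners=True convention (first and last voxel centers at -1 and +1).

```python
from typing import Callable, List, Sequence

Point = List[float]


def world_to_normalized(x: Sequence[float], size, spacing, origin) -> Point:
    # Hypothetical helper; align_corners=True: first/last voxel centers map to -1/+1
    return [2 * (xi - o) / ((n - 1) * h) - 1 for xi, n, h, o in zip(x, size, spacing, origin)]


def normalized_to_world(p: Sequence[float], size, spacing, origin) -> Point:
    # Inverse of world_to_normalized under the same assumptions
    return [o + (pi + 1) * (n - 1) * h / 2 for pi, n, h, o in zip(p, size, spacing, origin)]


def transform_world_points(
    transform: Callable[[Point], Point], points, size, spacing, origin
) -> List[Point]:
    """Apply a coordinate map defined on normalized coordinates to world-space points."""
    result = []
    for x in points:
        p = world_to_normalized(x, size, spacing, origin)  # world -> normalized
        q = transform(p)  # the map itself only ever sees normalized coordinates
        result.append(normalized_to_world(q, size, spacing, origin))  # normalized -> world
    return result


def shift(p: Point) -> Point:
    # Example map: translate by 0.5 normalized units along the first axis
    return [p[0] + 0.5, p[1]]


print(transform_world_points(shift, [[10.0, 0.0], [18.0, 2.0]], (5, 3), (2.0, 1.0), (10.0, 0.0)))
# [[12.0, 0.0], [20.0, 2.0]]
```

This keeps the transform itself agnostic of world space: the conversion happens at the boundary, rather than the parameterization itself being defined with respect to other Axes.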

qiuhuaqi commented 1 year ago

Sorry, I must have missed the review request for #62. I can see this is a good workaround for my use case, which is transforming points given in unnormalised space (world space) with a SpatialTransform.

For image registration operating in world space, we can handle all the spatial normalisation internally to make sure everything is correct when a SpatialTransform is used, so this would indeed be a bit redundant. But IMO it would be somewhat cleaner if the entire SpatialTransform module considered axes. For example, if one wants to collapse the transform into tensor format, I think the SpatialTransform.tensor() method can currently only produce fields in normalised space? It might be enough to simply add axes as an argument to all relevant functions.
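If SpatialTransform.tensor() does return displacements in normalised units, as suspected above, converting them to world units is a per-axis rescaling whose factor depends on the normalised convention in use. The sketch below is hypothetical (the function names are made up and are not deepali's API) and assumes an axis-aligned grid. It follows the align_corners convention of torch.nn.functional.grid_sample: with align_corners=True, -1 and +1 refer to the centers of the corner voxels; with align_corners=False, to the outer edges of the corner voxels.

```python
from typing import List, Sequence

Vector = List[float]


def displacement_normalized_to_world(
    u: Sequence[float], size: Sequence[int], spacing: Sequence[float], align_corners: bool = True
) -> Vector:
    """Rescale a displacement vector from normalized units to world units (e.g., mm).

    Only a per-axis scaling is needed: the grid origin cancels out because a
    displacement is a difference of two positions. Assumes an axis-aligned grid.
    """
    out = []
    for ui, n, h in zip(u, size, spacing):
        # World length covered by the normalized interval [-1, 1] along this axis
        extent = (n - 1) * h if align_corners else n * h
        out.append(ui * extent / 2)
    return out


def field_normalized_to_world(field, size, spacing, align_corners=True) -> List[Vector]:
    """Rescale every vector of a field stored as a flat list of displacement vectors."""
    return [displacement_normalized_to_world(u, size, spacing, align_corners) for u in field]


print(displacement_normalized_to_world([0.5, -1.0], (5, 3), (2.0, 1.0)))  # [2.0, -1.0]
print(displacement_normalized_to_world([0.5, -1.0], (5, 3), (2.0, 1.0), align_corners=False))  # [2.5, -1.5]
```

The same normalized displacement maps to different world displacements under the two conventions, which illustrates why a tensor() output would need to state its axes. A grid with a non-identity direction (orientation) matrix would additionally require rotating each vector.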