nipy / nitransforms

a standalone fork of nipy/nibabel#656
https://nitransforms.readthedocs.io
MIT License

Notes on incorporating voxel shift maps and per-volume transforms #184


effigies commented 9 months ago

Voxel shift maps

As mentioned in today's call, I think that incorporating the full fieldmap into nitransforms would be difficult. But thinking a bit more about voxel shifts, I think we can support them fairly straightforwardly as an argument to apply():

import numpy as np
from scipy import ndimage as ndi

from nitransforms.linear import Affine

class TransformBase:

    def apply(
        self,
        spatialimage,
        reference=None,
        voxel_shift_map=None,
        order=3,
        mode="constant",
        cval=0.0,
        prefilter=True,
        output_dtype=None,
    ):
        ...
        # Ignoring transposes and homogeneous coordinates for brevity
        rascoords = self.map(reference.ndcoords)
        # The inverse of the moving image's affine takes RAS coordinates to voxel indices
        voxcoords = Affine(spatialimage.affine).map(rascoords, inverse=True).reshape(
            (reference.ndim, *reference.shape)
        )
        if voxel_shift_map is not None:
            # voxel_shift_map must have shape (reference.ndim, *reference.shape)
            # Alternately, we could accept it in (*reference.shape, reference.ndim) and roll axes
            voxcoords += voxel_shift_map

        data = np.asanyarray(spatialimage.dataobj)
        resampled = ndi.map_coordinates(
            data,
            voxcoords,
            output=output_dtype,
            order=order,
            mode=mode,
            cval=cval,
            prefilter=prefilter,
        )
        return resampled
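
For concreteness, usage might look like the sketch below. This is hypothetical: the voxel_shift_map argument is exactly what this issue proposes, so it does not exist yet, and the synthetic image and random shifts stand in for a real EPI and a fieldmap-derived VSM.

import numpy as np
import nibabel as nb
from nitransforms.linear import Affine

rng = np.random.default_rng(0)
# Synthetic 3D EPI stand-in; a real use would load the image from disk
epi = nb.Nifti1Image(rng.normal(size=(64, 64, 40)).astype("f4"), np.eye(4))

# Voxel shifts, nonzero only along the phase-encoding (j) axis,
# in the (reference.ndim, *reference.shape) layout sketched above
vsm = np.zeros((3, 64, 64, 40), dtype="f4")
vsm[1] = rng.normal(scale=2.0, size=(64, 64, 40))

# Identity head-motion transform, resampling onto the EPI's own grid
moved = Affine(np.eye(4)).apply(epi, reference=epi, voxel_shift_map=vsm, order=1)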

Because map() operates on RAS coordinates and not voxel indices, we cannot apply the voxel shift map in that context, so we probably do not want to make the shift map part of the transform itself.

We specifically do not want to describe voxel shift maps in the world space of the target image. While it might be possible to fit one in at the end of the chain, after the motion-correction transforms, any such solution would be more complicated than the above.

Per-volume transformations

The above discussion works for an individual volume. In order to correctly handle VSMs in a motion-corrected frame, we need TransformChains to become aware that they are involved in a per-volume transform. Unfortunately, right now, TransformChain objects are iterable over transforms, while LinearTransformsMapping objects are iterable over volumes, which at the very least means that straightforward API composition isn't going to work.
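
To illustrate the mismatch with the current API (the constructor arguments here are minimal and purely illustrative):

import numpy as np
from nitransforms.linear import Affine, LinearTransformsMapping
from nitransforms.manip import TransformChain

hmc = LinearTransformsMapping([np.eye(4)] * 3)   # one affine per volume
fmap_like = Affine(np.eye(4))                    # stand-in for a single static transform

# Iterating a LinearTransformsMapping yields one Affine per *volume* (3 here) ...
per_volume = [xfm for xfm in hmc]
# ... while iterating a TransformChain yields its component *transforms* (2 here)
chain = TransformChain(transforms=[fmap_like, hmc])
per_transform = [xfm for xfm in chain]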

Currently, LinearTransformsMapping handles its per-volume iteration inside apply():

https://github.com/nipy/nitransforms/blob/1674e86a73595356eb6a775fd5b5c612952482a0/nitransforms/linear.py#L395-L498

A VSM- and multivolume-aware TransformChain could do what we want in apply(). Another thought is that we could treat transforms as data objects rather than actors. The interface could be:

def apply_transform(
    source: SpatialImage,
    target: Pointset,
    transform: TransformBase,
    shift_map: np.ndarray,
    # map_coordinates args
    ...
) -> np.ndarray:
    ...
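
One attraction of that shape is that per-volume handling could live in a thin driver outside the transform classes. A rough sketch, assuming the apply_transform() stub above gets a real body and that transforms is any iterable yielding one transform per volume (all names here are illustrative, not existing API):

import numpy as np

def apply_series(source4d, target, transforms, shift_map=None, **kwargs):
    """Resample each volume of a 4D image with its own transform (sketch)."""
    data = np.asanyarray(source4d.dataobj)
    resampled = [
        apply_transform(
            source4d.__class__(data[..., t], source4d.affine),  # volume t as a 3D image
            target,
            xfm,
            shift_map=shift_map,
            **kwargs,
        )
        for t, xfm in enumerate(transforms)
    ]
    return np.stack(resampled, axis=-1)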

If we give up on defining apply() correctly for each transform, and let transforms focus on composing and mapping, it might make things cleaner. Just imagining how we might approach chains that include per-volume transforms:

from __future__ import annotations

import itertools
from collections.abc import Iterator

class TransformBase:
    n_transforms: int = 1

    def iter_transforms(self) -> Iterator[TransformBase]:
        """Repeat the current transform as often as required."""
        return itertools.repeat(self)

class AffineSeries(TransformBase):
    # self.series holds one component transform per volume
    @property
    def n_transforms(self) -> int:
        return len(self.series)

    def iter_transforms(self) -> Iterator[TransformBase]:
        """Iterate over the defined series."""
        return iter(self.series)

class TransformChain(TransformBase):
    # self.chain holds the transforms composed by this chain
    @property
    def n_transforms(self) -> int:
        lengths = [xfm.n_transforms for xfm in self.chain if xfm.n_transforms != 1]
        return min(lengths) if lengths else 1

    def iter_transforms(self) -> Iterator[TransformChain]:
        """Iterate over all transforms in the chain simultaneously, stopping with the shortest."""
        return map(TransformChain, zip(*(xfm.iter_transforms() for xfm in self.chain)))